Assume I have a network created as follows:
p = torch.nn.Sequential(torch.nn.Linear(self.inputSize, self.outputSize))
I know that I can print the network with:
print(p)
and get:
Sequential(
  (0): Linear(in_features=22, out_features=3, bias=True)
)
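Here is a self-contained version of the snippet above. It is a sketch that assumes inputSize is 22 and outputSize is 3, inferred from the printed output, because self.inputSize and self.outputSize only exist inside the asker's class.

```python
import torch

# Assumption: inputSize=22 and outputSize=3, inferred from the output shown above.
input_size, output_size = 22, 3
p = torch.nn.Sequential(torch.nn.Linear(input_size, output_size))

# nn.Module.__repr__ prints the container name, then each child indented by two spaces.
print(p)
```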
I want to pretty-print the network, so I can use
for name, module in p.named_children():
    print(f'{name:>10} {module}')
to print each network layer's name. For the example network above I'd get:
0 Linear(in_features=22, out_features=3, bias=True)
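Below is a runnable sketch of that loop, using the same assumed sizes (22 inputs, 3 outputs). The :>10 format spec right-aligns the name in a 10-character field, so the real line starts with nine spaces before the 0. That padding appears to have been lost in the output quoted above.

```python
import torch

# Assumed sizes (22 inputs, 3 outputs), taken from the output shown above.
p = torch.nn.Sequential(torch.nn.Linear(22, 3))

# named_children() yields (name, module) pairs for the direct children only;
# f'{module}' falls back to the module's repr.
lines = [f'{name:>10} {module}' for name, module in p.named_children()]
for line in lines:
    print(line)
```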
But how do I get the 'Sequential' part? It's the nn module container class, so is the only way to do this to dissect the class name returned by type() (<class 'torch.nn.modules.container.Sequential'>)?
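For comparison, the string-dissection approach the question mentions could be sketched as follows. This is an illustration, not a recommended method. It works, but it depends on the text format of a class's repr. The class object returned by type() already carries the same pieces as its standard __module__ and __qualname__ attributes.

```python
import torch

p = torch.nn.Sequential(torch.nn.Linear(22, 3))

cls = type(p)
# The string form from the question: "<class 'torch.nn.modules.container.Sequential'>"
as_text = str(cls)
# Dissecting the string: take the quoted dotted path, keep the last component.
dissected = as_text.split("'")[1].rsplit(".", 1)[-1]

# The same information, read directly from attributes of the class object.
print(dissected, cls.__module__, cls.__qualname__)
```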
Asked Mar 11 at 17:22 by JKomp; edited Mar 11 at 17:34 by James.

1 Answer
Answer (score 1): You can get the class of the top-level container of the network using the .__class__ attribute. To get just the name, use .__class__.__name__.
print(p.__class__.__name__)
for name, module in p.named_children():
    print(f' {name:<2} {module}')
Prints:
Sequential
0 Linear(in_features=22, out_features=3, bias=True)
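The same idea extends to nested containers. Below is a minimal sketch of a recursive printer. pretty_lines is a hypothetical helper name, not a PyTorch API. It prints each module's __class__.__name__ indented by nesting depth. It appends the module's extra_repr() only when that is non-empty, so containers such as Sequential show just their class name. (extra_repr() is the argument summary that nn.Module's repr uses, e.g. in_features=22, ... for Linear.)

```python
from typing import List

from torch import nn


def pretty_lines(module: nn.Module, name: str = "", depth: int = 0) -> List[str]:
    """Hypothetical helper (not part of PyTorch): one line per module.

    Each line is indented two spaces per nesting level. The class name comes
    from module.__class__.__name__; extra_repr() is appended in parentheses
    only when it is non-empty (e.g. for Linear, not for Sequential).
    """
    label = f"{name}: " if name else ""
    text = f"{'  ' * depth}{label}{module.__class__.__name__}"
    details = module.extra_repr()
    if details:
        text += f"({details})"
    lines = [text]
    # named_children() yields only direct children, so recurse to reach nested layers.
    for child_name, child in module.named_children():
        lines.extend(pretty_lines(child, child_name, depth + 1))
    return lines


nested = nn.Sequential(nn.Linear(22, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 3)))
print("\n".join(pretty_lines(nested)))
```

nn.Module.named_modules() would walk the whole tree in a single call. However, it reports dotted names such as "2.0" rather than an explicit depth, which is why this sketch recurses over named_children() instead.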