
I am wondering if anyone has had positive experiences generating visualizations of Audio/Speech models using torchview.

In fact, best practices for generating such diagrams with other model architectures (Whisper, HuBERT, etc.), or even with other visualization tools, are welcome.

My question: how can I control which layers and layer types are shown?

So far, the best results I have are reproduced as follows:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from torchview import draw_graph
from transformers import (
    Wav2Vec2Processor,
    AutoConfig,
)

model_name_or_path = "facebook/wav2vec2-large-xlsr-53"
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path)

'''
'my_model' below is loaded from a class derived from
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel,
using an AutoConfig config object.
'''
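
# Hypothetical stand-in so this snippet runs end-to-end (an assumption
# for illustration, not my actual class): build a plain Wav2Vec2Model
# from the AutoConfig.
from transformers import Wav2Vec2Model

config = AutoConfig.from_pretrained(model_name_or_path)
my_model = Wav2Vec2Model(config)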

input_size = processor.feature_extractor.sampling_rate  # i.e. one second of audio at the model's sampling rate

for depth_value in range(3):
    font_props = {'family': 'serif', 'color': 'darkblue', 'size': 15}
    graph = draw_graph(
        my_model,
        input_size=(1, input_size),
        hide_module_functions=False,
        depth=depth_value,
        hide_inner_tensors=True,
        expand_nested=True,
        show_shapes=True,
        device='meta',   # 'meta', 'cpu', 'cuda'
        graph_dir='BT',
        roll=True,
        # rollout_modules=False,
        # rollout_activations=False,
        # hide_module_names=True,
        # hide_node_names=True,
        # ignored_modules=["Dropout", "mean", "transpose"],
        # collapse_residual=True,
        # collapse_batch_norm=True,
        # group_modules=True,
        # group_convolutions=True,
        # filter_layers=['Conv1d', 'Linear'],
        # exclude_layers=['Dropout'],
    )
    graph.visual_graph.render(filename="model_image", format='png')

    img = mpimg.imread("model_image.png")
    plt.figure(figsize=(10, 40))  # Adjust size
    plt.title('Level ' + str(depth_value), fontdict=font_props)
    plt.imshow(img)
    plt.axis("off")  # Hide axes
    plt.show()

The commented-out arguments in the 'draw_graph' call were suggested by AI/LLM-based help, but they do not work at all: draw_graph does not accept them.
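
A quick way to check which keyword arguments the installed torchview version actually accepts is to print the function signature with the standard-library inspect module:

import inspect

from torchview import draw_graph

# Lists the parameters this torchview release really exposes, which
# makes hallucinated arguments easy to spot before calling draw_graph.
print(inspect.signature(draw_graph))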

Other suggestions include editing a copy of the Wav2Vec2Model itself before passing it to draw_graph (see the sketch below). But that approach apparently leaves little room for flexibly reusing the same code with other models in the future, and, as of today, it is simpler and faster for me to just edit the SVG output.
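
For reference, a minimal sketch of that kind of edit, under two assumptions: the model exposes its transformer stack as encoder.layers (true for a plain Wav2Vec2Model; a custom subclass may nest it differently), and working on a deep copy keeps the original model intact:

import copy

import torch.nn as nn

model_copy = copy.deepcopy(my_model)  # leave the original untouched

# Keep only the first transformer block to shrink the diagram; with
# roll=True repeated blocks are collapsed anyway, so the picture stays
# representative. NOTE: 'encoder.layers' is an assumption based on the
# plain Wav2Vec2Model layout.
model_copy.encoder.layers = model_copy.encoder.layers[:1]

# Swap every Dropout for Identity so those layers (mostly) disappear
# from the rendered graph.
for module in model_copy.modules():
    for child_name, child in module.named_children():
        if isinstance(child, nn.Dropout):
            setattr(module, child_name, nn.Identity())

graph = draw_graph(model_copy, input_size=(1, input_size), depth=3,
                   expand_nested=True, show_shapes=True, device='meta')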
