def _forward(self, hidden_states_a, hidden_states_b, temb):
    # ... original block computation (elided in the question)
    return hidden_states_a, hidden_states_b

def forward(self, hidden_states_a, hidden_states_b, temb):
    if self.cuda_graph is None:
        hidden_states_a_bk, hidden_states_b_bk, temb_bk = hidden_states_a.clone(), hidden_states_b.clone(), temb.clone()
        self.cuda_graph = torch.cuda.CUDAGraph()
        # input static placeholders
        self.graph_input['hidden_states_a'] = hidden_states_a_bk
        self.graph_input['hidden_states_b'] = hidden_states_b_bk
        self.graph_input['temb'] = temb_bk
        # capture the graph
        with torch.cuda.graph(self.cuda_graph):
            hidden_states_a_bk, hidden_states_b_bk = self._forward(hidden_states_a_bk, hidden_states_b_bk, temb_bk)
        torch.cuda.synchronize()
        # output placeholders
        self.graph_output['hidden_states_a'] = hidden_states_a_bk
        self.graph_output['hidden_states_b'] = hidden_states_b_bk
        return self.execute_model(hidden_states_a, hidden_states_b, temb)
    else:
        return self.execute_model(hidden_states_a, hidden_states_b, temb)

def execute_model(self, hidden_states_a, hidden_states_b, temb):
    self.graph_input['hidden_states_a'].copy_(hidden_states_a)
    self.graph_input['hidden_states_b'].copy_(hidden_states_b)
    self.graph_input['temb'].copy_(temb)
    self.cuda_graph.replay()
    return self.graph_output['hidden_states_a'], self.graph_output['hidden_states_b']
Here is my code. In the original forward function, the inputs are hidden_states_a, hidden_states_b, and temb, and the outputs are hidden_states_a and hidden_states_b; that is, the first two inputs are also the outputs. I renamed the original forward function to _forward. In the new forward function, on the first call self.cuda_graph is None, so I clone all three input tensors and store the clones as static input placeholders. I then capture the CUDA graph around self._forward and store its outputs as output placeholders. In execute_model, I copy the new inputs into the input placeholders and simply replay the CUDA graph. Are there any bugs in this approach? Written this way the code runs faster, but the results are incorrect. What is going wrong?
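For completeness, here is a minimal sketch of how the replayed output could be compared against the eager _forward path; the function name, tensor shapes, and tolerance below are placeholders rather than my real test harness, and block is assumed to be an instance of the module above with cuda_graph, graph_input, and graph_output initialized in __init__:

import torch

def check_graph_vs_eager(block, shape=(2, 64, 128), device="cuda"):
    # Random test inputs; the real shapes/dtypes come from the actual model.
    a = torch.randn(shape, device=device)
    b = torch.randn(shape, device=device)
    temb = torch.randn(shape[0], shape[-1], device=device)

    # Eager reference, computed on copies so the graph path sees identical values.
    ref_a, ref_b = block._forward(a.clone(), b.clone(), temb.clone())

    # First call captures the graph and replays it; second call only replays.
    # Clone the returned tensors so later replays cannot overwrite them.
    out1 = [t.clone() for t in block.forward(a.clone(), b.clone(), temb.clone())]
    out2 = [t.clone() for t in block.forward(a.clone(), b.clone(), temb.clone())]

    print("capture call matches eager:",
          torch.allclose(ref_a, out1[0], atol=1e-5),
          torch.allclose(ref_b, out1[1], atol=1e-5))
    print("replay call matches eager:",
          torch.allclose(ref_a, out2[0], atol=1e-5),
          torch.allclose(ref_b, out2[1], atol=1e-5))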