
    def _forward(self, hidden_states_a, hidden_states_b, temb):
        # ... original block computation (elided) ...
        return hidden_states_a, hidden_states_b

    def forward(self, hidden_states_a, hidden_states_b, temb):
        if self.cuda_graph is None:
            # clone the inputs so the graph has its own static buffers
            hidden_states_a_bk, hidden_states_b_bk, temb_bk = hidden_states_a.clone(), hidden_states_b.clone(), temb.clone()
            self.cuda_graph = torch.cuda.CUDAGraph()

            # static input placeholders
            self.graph_input['hidden_states_a'] = hidden_states_a_bk
            self.graph_input['hidden_states_b'] = hidden_states_b_bk
            self.graph_input['temb'] = temb_bk

            # capture the graph
            with torch.cuda.graph(self.cuda_graph):
                hidden_states_a_bk, hidden_states_b_bk = self._forward(hidden_states_a_bk, hidden_states_b_bk, temb_bk)
            torch.cuda.synchronize()

            # static output placeholders
            self.graph_output['hidden_states_a'] = hidden_states_a_bk
            self.graph_output['hidden_states_b'] = hidden_states_b_bk

            return self.execute_model(hidden_states_a, hidden_states_b, temb)
        else:
            return self.execute_model(hidden_states_a, hidden_states_b, temb)

    def execute_model(self, hidden_states_a, hidden_states_b, temb):
        # copy the real inputs into the static placeholders, replay, and return the static outputs
        self.graph_input['hidden_states_a'].copy_(hidden_states_a)
        self.graph_input['hidden_states_b'].copy_(hidden_states_b)
        self.graph_input['temb'].copy_(temb)
        self.cuda_graph.replay()
        return self.graph_output['hidden_states_a'], self.graph_output['hidden_states_b']

Here is my code. In the original forward function, the inputs are hidden_states_a, hidden_states_b, and temb, and the outputs are hidden_states_a and hidden_states_b. In other words, the first two inputs are also the outputs. I renamed the original forward function to _forward. In the new forward function, self.cuda_graph is None on the first call, so I clone all three input tensors and store the clones as static input placeholders. I then capture the CUDA graph around a call to self._forward and store its outputs as the output placeholders. In execute_model, I copy the real inputs into the input placeholders and simply replay the CUDA graph. Are there any bugs in this approach? After writing it this way, the code runs faster, but the results are incorrect.
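For comparison, this is my understanding of the usual capture/replay pattern based on the PyTorch CUDA graphs documentation. The toy module, shapes, and warm-up count below are made up purely for illustration, not part of my real model:

    import torch

    # Toy module standing in for the real block; name and shapes are illustrative only.
    net = torch.nn.Linear(64, 64).cuda()

    # Static input placeholder that the graph reads from on every replay.
    static_x = torch.zeros(8, 64, device="cuda")

    # Warm-up on a side stream before capture, as recommended in the docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            static_y = net(static_x)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: the tensor produced inside the captured region becomes the static output.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y = net(static_x)

    # Replay with new data: copy into the static input, replay, then read (or clone) the static output.
    real_x = torch.randn(8, 64, device="cuda")
    static_x.copy_(real_x)
    g.replay()
    result = static_y.clone()

One difference I can see is that my version does no warm-up iterations and captures the graph directly inside the first forward call.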

What is going wrong?
