@@ -1,75 +1,97 @@
 import torch
 
+
 class NumpyDoubleBackwardMixin:
-    '''
-    Adds a Numpy double backward method to any TensorProduct
+    """
+    Adds a Numpy double backward method to any TensorProduct
     with the forward pass defined in PyTorch and the relevant
-    derivatives registered.
-    '''
-    def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad):
+    derivatives registered.
+    """
+
+    def double_backward_cpu(
+        self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad
+    ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True)
-        in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True)
-        weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda')
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda')
-        weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda')
+        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
+        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
         out_torch = self.forward(in1_torch, in2_torch, weights_torch)
 
         in1_grad, in2_grad, weights_grad = torch.autograd.grad(
             outputs=out_torch,
             inputs=[in1_torch, in2_torch, weights_torch],
             grad_outputs=out_grad_torch,
             create_graph=True,
-            retain_graph=True
+            retain_graph=True,
         )
 
         a, b, c, d = torch.autograd.grad(
             outputs=[in1_grad, in2_grad, weights_grad],
             inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
-            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch]
+            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
         )
 
-        return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy()
+        return (
+            a.detach().cpu().numpy(),
+            b.detach().cpu().numpy(),
+            c.detach().cpu().numpy(),
+            d.detach().cpu().numpy(),
+        )
 
 
 class NumpyDoubleBackwardMixinConv:
-    '''
+    """
     Similar, but for fused graph convolution.
-    '''
-    def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph):
+    """
+
+    def double_backward_cpu(
+        self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph
+    ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True)
-        in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True)
-        weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda')
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda')
-        weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda')
+        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
+        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
 
         torch_rows = torch.tensor(graph.rows, device="cuda")
         torch_cols = torch.tensor(graph.cols, device="cuda")
         torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda")
 
-        out_torch = self.forward(in1_torch, in2_torch, weights_torch, torch_rows, torch_cols, torch_transpose_perm)
+        out_torch = self.forward(
+            in1_torch,
+            in2_torch,
+            weights_torch,
+            torch_rows,
+            torch_cols,
+            torch_transpose_perm,
+        )
 
         in1_grad, in2_grad, weights_grad = torch.autograd.grad(
             outputs=out_torch,
             inputs=[in1_torch, in2_torch, weights_torch],
             grad_outputs=out_grad_torch,
             create_graph=True,
-            retain_graph=True
+            retain_graph=True,
         )
 
         a, b, c, d = torch.autograd.grad(
             outputs=[in1_grad, in2_grad, weights_grad],
             inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
-            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch]
+            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
        )
 
-        return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy()
-
-
+        return (
+            a.detach().cpu().numpy(),
+            b.detach().cpu().numpy(),
+            c.detach().cpu().numpy(),
+            d.detach().cpu().numpy(),
+        )
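
For context, a minimal usage sketch (not part of this commit): the mixin's contract is that the host class defines a differentiable PyTorch forward(in1, in2, weights) and a truthy torch_op attribute; double_backward_cpu then takes NumPy arrays and returns NumPy double-backward gradients with respect to in1, in2, weights, and out_grad. The ToyTP class and its einsum forward below are hypothetical stand-ins for a real TensorProduct, and a CUDA device is required since the mixin stages all tensors on the GPU.

import numpy as np
import torch


class ToyTP(NumpyDoubleBackwardMixin):
    # Any differentiable bilinear op satisfies the contract; a real
    # TensorProduct forward would go here.
    torch_op = True

    def forward(self, in1, in2, weights):
        return torch.einsum("bi,bj,ij->bi", in1, in2, weights)


tp = ToyTP()
rng = np.random.default_rng(0)
in1, in2 = rng.standard_normal((4, 3)), rng.standard_normal((4, 5))
weights = rng.standard_normal((3, 5))
out_grad = rng.standard_normal((4, 3))       # upstream gradient w.r.t. the output
in1_dgrad = rng.standard_normal((4, 3))      # seed gradients for the second
in2_dgrad = rng.standard_normal((4, 5))      # differentiation, one per
weights_dgrad = rng.standard_normal((3, 5))  # first-order gradient

# Double-backward gradients w.r.t. in1, in2, weights, and out_grad, as NumPy arrays.
a, b, c, d = tp.double_backward_cpu(
    in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad
)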