Skip to content

Commit 44ecf0e

Browse files
committed
Compacted everything.
1 parent 4806748 commit 44ecf0e

2 files changed

Lines changed: 20 additions & 20 deletions

File tree

openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def double_backward_cpu(
3636
grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
3737
)
3838

39-
a_np = a.detach().cpu().numpy()
40-
b_np = b.detach().cpu().numpy()
41-
c_np = c.detach().cpu().numpy()
42-
d_np = d.detach().cpu().numpy()
43-
44-
return (a_np, b_np, c_np, d_np)
39+
return (
40+
a.detach().cpu().numpy(),
41+
b.detach().cpu().numpy(),
42+
c.detach().cpu().numpy(),
43+
d.detach().cpu().numpy(),
44+
)
4545

4646

4747
class NumpyDoubleBackwardMixinConv:
@@ -89,9 +89,9 @@ def double_backward_cpu(
8989
grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
9090
)
9191

92-
a_np = a.detach().cpu().numpy()
93-
b_np = b.detach().cpu().numpy()
94-
c_np = c.detach().cpu().numpy()
95-
d_np = d.detach().cpu().numpy()
96-
97-
return (a_np, b_np, c_np, d_np)
92+
return (
93+
a.detach().cpu().numpy(),
94+
b.detach().cpu().numpy(),
95+
c.detach().cpu().numpy(),
96+
d.detach().cpu().numpy(),
97+
)

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
import numpy as np
2-
import torch
3-
1+
from openequivariance.core.LoopUnrollTP import LoopUnrollTP
42
from openequivariance import TPProblem
53
from openequivariance._torch import extlib
6-
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
4+
import torch
5+
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
6+
from openequivariance.benchmark.logging_utils import getLogger
77
from openequivariance._torch.utils import (
8-
enum_to_torch_dtype,
98
reorder_torch,
109
string_to_tensor,
10+
enum_to_torch_dtype,
1111
)
12-
from openequivariance.benchmark.logging_utils import getLogger
13-
from openequivariance.core.LoopUnrollTP import LoopUnrollTP
14-
from openequivariance.core.utils import dtype_to_enum, torch_to_oeq_dtype
12+
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
13+
14+
import numpy as np
1515

1616
logger = getLogger()
1717

0 commit comments

Comments
 (0)