Skip to content

Commit 9de117e

Browse files
committed
More refactoring.
1 parent 95783ef commit 9de117e

20 files changed

Lines changed: 396 additions & 381 deletions

openequivariance/openequivariance/_torch/CUETensorProduct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from openequivariance.core.TensorProductBase import TensorProductBase
88
from openequivariance.core.e3nn_lite import TPProblem
9-
from openequivariance.benchmark.logging_utils import getLogger
9+
from openequivariance.benchmark.logging import getLogger
1010
from openequivariance.benchmark.tpp_creation_utils import (
1111
ChannelwiseTPP,
1212
FullyConnectedTPProblem,

openequivariance/openequivariance/_torch/E3NNTensorProduct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from openequivariance.core.TensorProductBase import TensorProductBase
1313
from openequivariance.core.e3nn_lite import TPProblem
14-
from openequivariance.benchmark.logging_utils import getLogger
14+
from openequivariance.benchmark.logging import getLogger
1515
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
1616

1717
TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from openequivariance._torch import extlib
44
import torch
55
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
6-
from openequivariance.benchmark.logging_utils import getLogger
6+
from openequivariance.benchmark.logging import getLogger
77
from openequivariance._torch.utils import (
88
reorder_torch,
99
string_to_tensor,

openequivariance/openequivariance/_torch/TensorProductConv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
enum_to_torch_dtype,
2424
)
2525

26-
from openequivariance.benchmark.logging_utils import getLogger
26+
from openequivariance.benchmark.logging import getLogger
2727
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
2828

2929
logger = getLogger()

openequivariance/openequivariance/_torch/extlib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010

11-
from openequivariance.benchmark.logging_utils import getLogger
11+
from openequivariance.benchmark.logging import getLogger
1212

1313
oeq_root = str(Path(__file__).parent.parent.parent)
1414

openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
import numpy as np
77

88
import openequivariance as oeq
9-
from openequivariance.benchmark.logging_utils import getLogger
9+
from openequivariance.benchmark.correctness import (
10+
correctness_backward_conv,
11+
correctness_double_backward_conv,
12+
correctness_forward_conv,
13+
)
14+
from openequivariance.benchmark.logging import getLogger
1015
from openequivariance.core.ConvolutionBase import CoordGraph
1116
from openequivariance.benchmark.benchmark_utils import NpEncoder
1217

@@ -90,7 +95,8 @@ def run(
9095

9196
if direction == "forward":
9297
if correctness:
93-
correctness = conv.test_correctness_forward(
98+
correctness = correctness_forward_conv(
99+
conv,
94100
graph,
95101
thresh=self.correctness_threshold,
96102
prng_seed=self.prng_seed,
@@ -105,7 +111,8 @@ def run(
105111

106112
if direction == "backward":
107113
if correctness:
108-
correctness = conv.test_correctness_backward(
114+
correctness = correctness_backward_conv(
115+
conv,
109116
graph,
110117
thresh=self.correctness_threshold,
111118
prng_seed=self.prng_seed,
@@ -120,8 +127,9 @@ def run(
120127

121128
if direction == "double_backward":
122129
if correctness:
123-
correctness = conv.test_correctness_double_backward(
124-
self.graph,
130+
correctness = correctness_double_backward_conv(
131+
conv,
132+
graph,
125133
thresh=self.correctness_threshold,
126134
prng_seed=self.prng_seed,
127135
reference_implementation=self.reference_impl,

openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from openequivariance._torch.extlib import DeviceProp
1111
from openequivariance.core.TensorProductBase import TensorProductBase
1212

13-
from openequivariance.benchmark.logging_utils import getLogger, bcolors
13+
from openequivariance.benchmark.logging import getLogger, bcolors
1414
from openequivariance.core.e3nn_lite import TPProblem
15-
from openequivariance.benchmark.correctness_utils import (
15+
from openequivariance.benchmark.correctness import (
1616
correctness_forward,
1717
correctness_backward,
1818
correctness_double_backward,

openequivariance/openequivariance/benchmark/benchmark_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from openequivariance.core.TensorProductBase import TensorProductBase
1616
from openequivariance.core.e3nn_lite import TPProblem
1717
from openequivariance._torch.CUETensorProduct import CUETensorProduct
18-
from openequivariance.benchmark.logging_utils import getLogger, bcolors
18+
from openequivariance.benchmark.logging import getLogger, bcolors
1919

2020
logger = getLogger()
2121

0 commit comments

Comments
 (0)