Skip to content

Commit 7105978

Browse files
asgloverAustin Glover
andauthored
adding e3tools problems to production problems (#147)
* adding e3tools problems to production problems * make weights external (for now) * distribute tests appropriately into shared and non-shared --------- Co-authored-by: Austin Glover <austin_glover@berkeley.com>
1 parent 54efff0 commit 7105978

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

openequivariance/benchmark/problems.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,27 @@ def nequip_problems():
149149
),
150150
]
151151
]
152+
153+
154+
def e3tools_problems():
155+
return [
156+
FCTPP(in1, in2, out, label=label, shared_weights=sw, internal_weights=iw)
157+
for (in1, in2, out, label, sw, iw) in [
158+
(
159+
"64x0e+16x1o",
160+
"1x0e+1x1o",
161+
"80x0e+16x1o",
162+
"e3tools_conv",
163+
False,
164+
False,
165+
),
166+
(
167+
"64x0e+16x1o",
168+
"1x0e+1x1o",
169+
"64x0e+16x1o",
170+
"e3tools_transformer",
171+
True,
172+
False, # Should be true, we don't support currently
173+
),
174+
]
175+
]

tests/conv_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from openequivariance.benchmark.problems import (
1313
mace_problems,
1414
diffdock_problems,
15+
e3tools_problems,
1516
)
1617

1718

@@ -123,7 +124,9 @@ def test_tp_double_bwd(self, conv_object, graph):
123124

124125

125126
class TestProductionModels(ConvCorrectness):
126-
production_model_tpps = mace_problems() + diffdock_problems()
127+
production_model_tpps = (
128+
mace_problems() + diffdock_problems() + [e3tools_problems()[0]]
129+
)
127130

128131
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
129132
def problem(self, request, dtype):
@@ -219,7 +222,7 @@ def problem(self, request, dtype):
219222

220223

221224
class TestAtomicSharedWeights(ConvCorrectness):
222-
problems = [mace_problems()[0], diffdock_problems()[0]]
225+
problems = [mace_problems()[0], diffdock_problems()[0], e3tools_problems()[1]]
223226

224227
def thresh(self, direction):
225228
return {

0 commit comments

Comments
 (0)