Skip to content

Commit 7dcd887

Browse files
committed
More bugfixes.
1 parent 379fd28 commit 7dcd887

3 files changed

Lines changed: 6 additions & 4 deletions

File tree

openequivariance/openequivariance/core/ComputationSchedule.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def __init__(self, input, mult_threshold):
301301
path_normalization="none",
302302
internal_weights=False,
303303
shared_weights=input.shared_weights,
304+
layout=input.layout,
304305
)
305306

306307
assert self.output.weight_numel == input.weight_numel
@@ -595,6 +596,7 @@ def calculate_backward_smem(
595596
path_normalization="none",
596597
internal_weights=False,
597598
shared_weights=config.shared_weights,
599+
layout=config.layout,
598600
)
599601

600602
weight_offset = 0

openequivariance/openequivariance/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,11 @@ def transpose_irrep_layout(
240240
if src_layout == "ir_mul" and dst_layout == "mul_ir":
241241
out[..., seg.start : seg.stop] = block.reshape(
242242
*block.shape[:-1], dim, mul
243-
).reshape(*block.shape[:-1], mul * dim)
243+
).swapaxes(-1, -2).reshape(*block.shape[:-1], mul * dim)
244244
elif src_layout == "mul_ir" and dst_layout == "ir_mul":
245245
out[..., seg.start : seg.stop] = block.reshape(
246246
*block.shape[:-1], mul, dim
247-
).reshape(*block.shape[:-1], dim * mul)
247+
).swapaxes(-1, -2).reshape(*block.shape[:-1], dim * mul)
248248
else:
249249
raise ValueError(
250250
f"Unsupported layout transpose: {src_layout} -> {dst_layout}"

tests/batch_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,14 +273,14 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
273273
return tp, tp.config
274274

275275

276-
class TestMulIrLayoutMACE(TPCorrectness):
276+
class TestIrMulLayoutMACE(TPCorrectness):
277277
production_model_tpps = mace_problems()
278278

279279
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
280280
def problem(self, request, dtype):
281281
problem = request.param.clone()
282282
problem.irrep_dtype, problem.weight_dtype = dtype, dtype
283-
problem.layout = "mul_ir"
283+
problem.layout = "ir_mul"
284284
return problem
285285

286286

0 commit comments

Comments
 (0)