
Commit 496f81d

MACE Integration Utilities (#96)
* Added torch datatype converter function for convenience.
* Added `weight_numel` to torch wrappers.
* Minor fix to `TensorProduct`.
* Reverted `mace_driver` for now.
* Updated the README (installation instructions).
1 parent 1f9c53e commit 496f81d

6 files changed: 41 additions & 19 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -25,6 +25,7 @@ nvidia-mathdx*
 .vscode/*
 *.ncu-rep
 mace_dev
+mace_oeq_integration
 valid_indices*
 
 *.xyz
```

README.md

Lines changed: 22 additions & 18 deletions
````diff
@@ -4,23 +4,25 @@
 [[Supported Tensor Products]](#tensor-products-we-accelerate)
 [[Citation and Acknowledgements]](#citation-and-acknowledgements)
 
-OpenEquivariance is a kernel generator for the Clebsch-Gordon tensor product,
+OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product,
 a key kernel in rotation-equivariant deep neural networks.
 It implements some of the tensor products
-that [e3nn](https://e3nn.org/) supports that are
+that [e3nn](https://e3nn.org/) supports
 commonly found in graph neural networks
 (e.g. [Nequip](https://github.com/mir-group/nequip) or
-[MACE](https://github.com/ACEsuit/mace)). To get started, install our package via
+[MACE](https://github.com/ACEsuit/mace)). To get
+started, ensure that you have GCC 9+ on your system
+and install our package via
 
 ```bash
 pip install git+https://github.com/PASSIONLab/OpenEquivariance
 ```
 
-We provide up to an order of magnitude acceleration over e3nn
-and up to ~2x speedup over
-[NVIDIA cuEquivariance](https://github.com/NVIDIA/cuEquivariance),
-which has a closed-source kernel package. We also offer fused
-equivariant graph convolutions that can reduce
+We provide up to an order of magnitude acceleration over e3nn and perform on par with the latest
+version of [NVIDIA cuEquivariance](https://github.com/NVIDIA/cuEquivariance),
+which has a closed-source kernel package.
+We also offer fused equivariant graph
+convolutions that can reduce
 computation and memory consumption significantly.
 
 We currently support NVIDIA GPUs and just added beta support on AMD GPUs for
@@ -124,22 +126,24 @@ print(torch.norm(Z))
 arbitrary order.
 
 ## Installation
-We currently support Linux systems only. We recommend that you use
-`conda` or `mamba` to set up a Python environment for installation.
-
-After activating an environment of your choice, run
+We currently support Linux systems only.
+Before installation and the first library import,
+ensure that the command
+`c++ --version` returns GCC 9+; if not, set the
+`CC` and `CXX` environment variables to point to
+valid compilers. On NERSC Perlmutter,
+`module load gcc` will set up your environment
+correctly.
+
+To install, run
 ```bash
 pip install git+https://github.com/PASSIONLab/OpenEquivariance
 ```
 After installation, the very first library
-import will trigger a build of a C++ extension we use.
+import will trigger a build of a C++ extension we use,
+which takes longer than usual.
 All subsequent imports will not retrigger compilation.
 
-If you encounter problems with installation, let us
-know by filing a bug and try a development build (see
-below). After installation, you should be able
-to run the example above.
-
 ## Replicating our benchmarks
 To run our benchmark suite, you'll also need the following packages:
 - `e3nn`,
````
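The compiler requirement added to the README can be checked up front before running `pip install`. A minimal sketch; the fallback compiler paths are illustrative assumptions, not part of the commit:

```shell
# Check that the default C++ compiler is GCC 9+, as the updated README
# requires, before installing OpenEquivariance.
if command -v c++ >/dev/null 2>&1; then
    c++ --version | head -n 1
else
    # Illustrative paths: point CC/CXX at a valid GCC 9+ toolchain.
    export CC=/usr/bin/gcc
    export CXX=/usr/bin/g++
fi
COMPILER_CHECKED=yes
```

On systems with environment modules (e.g. NERSC Perlmutter), `module load gcc` achieves the same result.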

openequivariance/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,6 +5,7 @@
 from openequivariance.implementations.e3nn_lite import TPProblem, Irreps
 from openequivariance.implementations.TensorProduct import TensorProduct
 from openequivariance.implementations.convolution.TensorProductConv import TensorProductConv
+from openequivariance.implementations.utils import torch_to_oeq_dtype
 
 __version__ = version("openequivariance")
```

openequivariance/implementations/TensorProduct.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,6 +10,7 @@ class TensorProduct(torch.nn.Module, LoopUnrollTP):
     def __init__(self, problem, torch_op=True):
         torch.nn.Module.__init__(self)
         LoopUnrollTP.__init__(self, problem, torch_op)
+        self.weight_numel = problem.weight_numel
 
     @staticmethod
     def name():
```
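Exposing `weight_numel` on the wrapper lets downstream code (e.g. a MACE integration) size flat weight buffers without reaching into the underlying problem. For e3nn-style `uvw` instructions, the weight count is the product of the three multiplicities; a hedged sketch, with a hypothetical helper name:

```python
def uvw_weight_count(mul_in1, mul_in2, mul_out):
    # In "uvw" connection mode, every (u, v, w) multiplicity triple
    # carries its own learnable weight.
    return mul_in1 * mul_in2 * mul_out

# e.g. a 32 x 32 -> 16 multiplicity instruction:
print(uvw_weight_count(32, 32, 16))  # 16384
```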

openequivariance/implementations/convolution/TensorProductConv.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -15,6 +15,7 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=Fals
                          torch_op=torch_op, deterministic=deterministic)
 
         self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device='cuda')
+        self.weight_numel = self.config.weight_numel
 
         if extlib.TORCH_COMPILE:
             self.forward = self.forward_deterministic if deterministic else self.forward_atomic
```
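The last context line above binds the forward implementation once at construction time instead of branching on every call. A minimal standalone sketch of that dispatch pattern; the class and return values here are illustrative, not from the commit:

```python
class ConvDispatch:
    """Bind the forward path once in __init__, as TensorProductConv does."""

    def __init__(self, deterministic=False):
        # Instance-attribute assignment shadows any class-level `forward`,
        # so later calls skip the per-call branch entirely.
        self.forward = self.forward_deterministic if deterministic else self.forward_atomic

    def forward_deterministic(self, x):
        return ("deterministic", x)

    def forward_atomic(self, x):
        return ("atomic", x)

print(ConvDispatch(deterministic=True).forward(0)[0])  # deterministic
```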

openequivariance/implementations/utils.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -70,4 +70,18 @@ def filter_and_analyze_problem(problem):
     result = {
         "is_uvw": problem.instructions[0].connection_mode == "uvw",
     }
-    return result
+    return result
+
+def torch_to_oeq_dtype(torch_dtype):
+    global torch
+    import torch
+
+    """
+    Converts torch dtype to oeq dtype
+    """
+    if torch_dtype == torch.float32:
+        return np.float32
+    elif torch_dtype == torch.float64:
+        return np.float64
+    else:
+        raise ValueError("Unsupported torch dtype!")
```
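The committed converter imports torch lazily so that importing openequivariance does not itself require torch. A torch-free sketch of the same mapping; the standalone function name and the string-keyed lookup are assumptions for illustration (torch dtypes stringify as `torch.float32`, etc.):

```python
import numpy as np

# oeq represents precision with numpy dtypes, so the converter is a small map.
_OEQ_DTYPES = {
    "torch.float32": np.float32,
    "torch.float64": np.float64,
}

def to_oeq_dtype(torch_dtype):
    # str() accepts either a real torch.dtype or its string form,
    # keeping this sketch runnable without torch installed.
    try:
        return _OEQ_DTYPES[str(torch_dtype)]
    except KeyError:
        raise ValueError(f"Unsupported torch dtype: {torch_dtype!r}") from None

print(to_oeq_dtype("torch.float64") is np.float64)  # True
```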
