Skip to content

Commit fcb0f3e

Browse files
asgloverAustin Glover
andauthored
Pytorch Determinism Warning (#155)
* add determinism alerts * add determinism tests --------- Co-authored-by: Austin Glover <austin_glover@berkeley.com>
1 parent 8d30ca2 commit fcb0f3e

2 files changed

Lines changed: 85 additions & 4 deletions

File tree

openequivariance/extension/libtorch_tp_jit.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,11 @@ torch::Tensor jit_conv_forward(
464464
check_tensor(rows, {nnz}, k.idx_dtype, "rows");
465465
check_tensor(cols, {nnz}, k.idx_dtype, "cols");
466466

467-
if (k.deterministic)
467+
if (k.deterministic){
468468
check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
469-
469+
} else {
470+
at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_forward");
471+
}
470472
if (k.shared_weights)
471473
check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
472474
else
@@ -519,8 +521,11 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_backward(
519521
check_tensor(rows, {nnz}, k.idx_dtype, "rows");
520522
check_tensor(cols, {nnz}, k.idx_dtype, "cols");
521523

522-
if (k.deterministic)
524+
if (k.deterministic){
523525
check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
526+
} else {
527+
at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_backward");
528+
}
524529

525530
if (k.shared_weights)
526531
check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
@@ -587,8 +592,11 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_doubl
587592
check_tensor(rows, {nnz}, k.idx_dtype, "rows");
588593
check_tensor(cols, {nnz}, k.idx_dtype, "cols");
589594

590-
if (k.deterministic)
595+
if (k.deterministic) {
591596
check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
597+
} else {
598+
at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_double_backward");
599+
}
592600

593601
if (k.shared_weights) {
594602
check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");

tests/torch_determinism_test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
import torch
3+
4+
from openequivariance import TPProblem, TensorProductConv
5+
6+
from e3nn import o3
7+
from torch_geometric import EdgeIndex
8+
9+
10+
@pytest.fixture
11+
def gen():
12+
return torch.Generator(device="cuda")
13+
14+
15+
@pytest.fixture
16+
def edge_index():
17+
return EdgeIndex(
18+
data=[
19+
[0, 1, 1, 2], # Receiver
20+
[1, 0, 2, 1], # Sender
21+
],
22+
sparse_size=(3, 4),
23+
device="cuda",
24+
dtype=torch.long,
25+
)
26+
27+
28+
@pytest.fixture
29+
def tpp():
30+
X_ir = o3.Irreps("1x2e")
31+
Y_ir = o3.Irreps("1x3e")
32+
Z_ir = o3.Irreps("1x2e")
33+
instructions = [(0, 0, 0, "uvu", True)]
34+
return TPProblem(
35+
X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
36+
)
37+
38+
39+
@pytest.fixture
40+
def conv_buffers(edge_index, tpp, gen):
41+
X = torch.rand(
42+
edge_index.num_rows, tpp.irreps_in1.dim, device="cuda", generator=gen
43+
)
44+
Y = torch.rand(
45+
edge_index.num_cols, tpp.irreps_in2.dim, device="cuda", generator=gen
46+
)
47+
W = torch.rand(edge_index.num_cols, tpp.weight_numel, device="cuda", generator=gen)
48+
return (X, Y, W, edge_index[0], edge_index[1])
49+
50+
51+
@pytest.fixture
52+
def tp_conv(tpp):
53+
return TensorProductConv(tpp, deterministic=False)
54+
55+
56+
def test_no_response(tp_conv, conv_buffers):
57+
torch.use_deterministic_algorithms(False)
58+
tp_conv(*conv_buffers)
59+
60+
61+
def test_warning(tp_conv, conv_buffers, capfd):
62+
torch.use_deterministic_algorithms(True, warn_only=True)
63+
tp_conv(*conv_buffers)
64+
65+
captured = capfd.readouterr()
66+
assert "Warning" in captured.err
67+
assert "does not have a deterministic implementation" in captured.err
68+
69+
70+
def test_error(tp_conv, conv_buffers):
71+
torch.use_deterministic_algorithms(True, warn_only=False)
72+
with pytest.raises(RuntimeError):
73+
tp_conv(*conv_buffers)

0 commit comments

Comments
 (0)