1- # Examples from the README
1+ # ruff: noqa: E402
2+ # Examples from the README
23import logging
34from openequivariance .benchmark .logging_utils import getLogger
5+
46logger = getLogger ()
57logger .setLevel (logging .ERROR )
68
7- # UVU Tensor Product
9+ # UVU Tensor Product
810# ===============================
911import torch
1012import e3nn .o3 as o3
1113
12- gen = torch .Generator (device = ' cuda' )
14+ gen = torch .Generator (device = " cuda" )
1315
1416batch_size = 1000
15- X_ir , Y_ir , Z_ir = o3 .Irreps ("1x2e" ), o3 .Irreps ("1x3e" ), o3 .Irreps ("1x2e" )
16- X = torch .rand (batch_size , X_ir .dim , device = ' cuda' , generator = gen )
17- Y = torch .rand (batch_size , Y_ir .dim , device = ' cuda' , generator = gen )
17+ X_ir , Y_ir , Z_ir = o3 .Irreps ("1x2e" ), o3 .Irreps ("1x3e" ), o3 .Irreps ("1x2e" )
18+ X = torch .rand (batch_size , X_ir .dim , device = " cuda" , generator = gen )
19+ Y = torch .rand (batch_size , Y_ir .dim , device = " cuda" , generator = gen )
1820
19- instructions = [(0 , 0 , 0 , "uvu" , True )]
21+ instructions = [(0 , 0 , 0 , "uvu" , True )]
2022
21- tp_e3nn = o3 .TensorProduct (X_ir , Y_ir , Z_ir , instructions ,
22- shared_weights = False , internal_weights = False ).to ('cuda' )
23- W = torch .rand (batch_size , tp_e3nn .weight_numel , device = 'cuda' , generator = gen )
23+ tp_e3nn = o3 .TensorProduct (
24+ X_ir , Y_ir , Z_ir , instructions , shared_weights = False , internal_weights = False
25+ ).to ("cuda" )
26+ W = torch .rand (batch_size , tp_e3nn .weight_numel , device = "cuda" , generator = gen )
2427
2528Z = tp_e3nn (X , Y , W )
2629print (torch .norm (Z ))
2932# ===============================
3033import openequivariance as oeq
3134
32- problem = oeq .TPProblem (X_ir , Y_ir , Z_ir , instructions , shared_weights = False , internal_weights = False )
35+ problem = oeq .TPProblem (
36+ X_ir , Y_ir , Z_ir , instructions , shared_weights = False , internal_weights = False
37+ )
3338tp_fast = oeq .TensorProduct (problem , torch_op = True )
3439
35- Z = tp_fast (X , Y , W ) # Reuse X, Y, W from earlier
40+ Z = tp_fast (X , Y , W ) # Reuse X, Y, W from earlier
3641print (torch .norm (Z ))
3742# ===============================
3843
4449
4550# Receiver, sender indices for message passing GNN
4651edge_index = EdgeIndex (
47- [[0 , 1 , 1 , 2 ], # Receiver
48- [1 , 0 , 2 , 1 ]], # Sender
49- device = 'cuda' ,
50- dtype = torch .long )
51-
52- X = torch .rand (node_ct , X_ir .dim , device = 'cuda' , generator = gen )
53- Y = torch .rand (nonzero_ct , Y_ir .dim , device = 'cuda' , generator = gen )
54- W = torch .rand (nonzero_ct , problem .weight_numel , device = 'cuda' , generator = gen )
55-
56- tp_conv = oeq .TensorProductConv (problem , torch_op = True , deterministic = False ) # Reuse problem from earlier
57- Z = tp_conv .forward (X , Y , W , edge_index [0 ], edge_index [1 ]) # Z has shape [node_ct, z_ir.dim]
52+ [
53+ [0 , 1 , 1 , 2 ], # Receiver
54+ [1 , 0 , 2 , 1 ], # Sender
55+ ],
56+ device = "cuda" ,
57+ dtype = torch .long ,
58+ )
59+
60+ X = torch .rand (node_ct , X_ir .dim , device = "cuda" , generator = gen )
61+ Y = torch .rand (nonzero_ct , Y_ir .dim , device = "cuda" , generator = gen )
62+ W = torch .rand (nonzero_ct , problem .weight_numel , device = "cuda" , generator = gen )
63+
64+ tp_conv = oeq .TensorProductConv (
65+ problem , torch_op = True , deterministic = False
66+ ) # Reuse problem from earlier
67+ Z = tp_conv .forward (
68+ X , Y , W , edge_index [0 ], edge_index [1 ]
69+ ) # Z has shape [node_ct, z_ir.dim]
5870print (torch .norm (Z ))
5971# ===============================
6072
6173# ===============================
62- _ , sender_perm = edge_index .sort_by ("col" ) # Sort by sender index
63- edge_index , receiver_perm = edge_index .sort_by ("row" ) # Sort by receiver index
74+ _ , sender_perm = edge_index .sort_by ("col" ) # Sort by sender index
75+ edge_index , receiver_perm = edge_index .sort_by ("row" ) # Sort by receiver index
6476
6577# Now we can use the faster deterministic algorithm
66- tp_conv = oeq .TensorProductConv (problem , torch_op = True , deterministic = True )
67- Z = tp_conv .forward (X , Y [receiver_perm ], W [receiver_perm ], edge_index [0 ], edge_index [1 ], sender_perm )
78+ tp_conv = oeq .TensorProductConv (problem , torch_op = True , deterministic = True )
79+ Z = tp_conv .forward (
80+ X , Y [receiver_perm ], W [receiver_perm ], edge_index [0 ], edge_index [1 ], sender_perm
81+ )
6882print (torch .norm (Z ))
69- # ===============================
83+ # ===============================
0 commit comments