@@ -619,40 +619,33 @@ def calculate_backward_smem(
             smem=self.memory_per_warp * warps_per_block,
         )
 
-    def reorder_weights(self, weights_in, direction, has_batch_dim):
+    def weight_reordering_info(self, weights_in, has_batch_dim):
         """
-        Reorders weights from the canonical e3nn form to the
-        form that LoopUnrollTP can ingest. Can also reorder the parameters
-        of a dense neural network layer that produces the weight matrix.
-
-        If has_batch_dim is true, the first dimension of the input weight matrix
-        is treated as the batch dimension.
+        Calculates all shapes, slices, and permutation info to reorder
+        weights.
         """
-        import torch  # TODO-someday: no need to specialize this to PyTorch
+        batch_dim = weights_in.shape[0]
+        reorder_specs = []
 
-        weights_out = torch.zeros_like(weights_in)
-        assert direction in ["forward", "backward"]
         for i, child_inst in enumerate(self.problem_splitter.new_instructions):
             parent_start, parent_end = (
                 child_inst.parent_weights_start,
                 child_inst.parent_weights_end,
             )
             parent_shape = list(child_inst.parent_weights_shape)
+            parent_range = [slice(parent_start, parent_end)]
 
             child_start, child_end, child_shape = (
                 self.updated_config.weight_range_and_shape_for_instruction(i)
             )
-
-            parent_range, child_range = (
-                [slice(parent_start, parent_end)],
-                [slice(child_start, child_end)],
-            )
+            child_range = [slice(child_start, child_end)]
+
             weights_subrange = child_inst.weights_subrange
-            batch_dim = weights_in.shape[0]
+
             reshape_size = [-1]
             transpose_perm = None
-
             connection_mode = self.updated_config.instructions[i].connection_mode
+
             if connection_mode == "uvu":
                 transpose_perm = [1, 0]
             elif connection_mode == "uvw":
@@ -662,50 +655,27 @@ def reorder_weights(self, weights_in, direction, has_batch_dim):
                 child_range = [slice(0, batch_dim)] + child_range
                 parent_range = [slice(0, batch_dim)] + parent_range
                 parent_shape = [batch_dim] + parent_shape
+
                 child_shape = [batch_dim] + list(child_shape)
                 weights_subrange = [slice(0, batch_dim)] + child_inst.weights_subrange
                 reshape_size = [batch_dim] + reshape_size
-                transpose_perm = [0] + [i + 1 for i in transpose_perm]
-
-            if direction == "forward":
-                sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[
-                    tuple(weights_subrange)
-                ]
-                weights_out[tuple(child_range)] = sliced_weights.permute(
-                    transpose_perm
-                ).reshape(reshape_size)
-            elif direction == "backward":
-                transpose_child_shape = [child_shape[i] for i in transpose_perm]
-                sliced_weights = (
-                    weights_in[tuple(child_range)]
-                    .reshape(transpose_child_shape)
-                    .permute(transpose_perm)
-                )
-                weights_out[tuple(parent_range)].reshape(parent_shape)[
-                    tuple(weights_subrange)
-                ] = sliced_weights.flatten().reshape(child_shape)
-
-        return weights_out
-
-    def reorder_weights_numpy(self, weights_in, direction, has_batch_dim):
-        import torch
-
-        weights_in = torch.from_numpy(weights_in.copy())
-        result = self.reorder_weights(weights_in, direction, has_batch_dim)
-        return result.detach().cpu().numpy().copy()
-
-    def reorder_weights_from_e3nn(self, weights_in, has_batch_dim):
-        import torch
-
-        if isinstance(weights_in, np.ndarray):
-            return self.reorder_weights_numpy(weights_in, "forward", has_batch_dim)
-        elif isinstance(weights_in, torch.Tensor):
-            return self.reorder_weights(weights_in, "forward", has_batch_dim)
-
-    def reorder_weights_to_e3nn(self, weights_in, has_batch_dim):
-        import torch
-
-        if isinstance(weights_in, np.ndarray):
-            return self.reorder_weights_numpy(weights_in, "backward", has_batch_dim)
-        elif isinstance(weights_in, torch.Tensor):
-            return self.reorder_weights(weights_in, "backward", has_batch_dim)
+
+                if transpose_perm is not None:
+                    transpose_perm = [0] + [k + 1 for k in transpose_perm]
+
+            transpose_child_shape = None
+            if transpose_perm is not None:
+                transpose_child_shape = [child_shape[k] for k in transpose_perm]
+
+            reorder_specs.append({
+                "parent_range": tuple(parent_range),
+                "parent_shape": parent_shape,
+                "weights_subrange": tuple(weights_subrange),
+                "child_range": tuple(child_range),
+                "child_shape": child_shape,
+                "transpose_perm": transpose_perm,
+                "reshape_size": reshape_size,
+                "transpose_child_shape": transpose_child_shape,
+            })
+
+        return reorder_specs
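
For context, a minimal sketch of how the returned reorder_specs might be consumed, assuming PyTorch tensors as in the removed code. The helper name reorder_weights_forward is hypothetical; it reproduces the forward (e3nn-to-kernel) branch of the removed reorder_weights, while the backward direction would similarly use the transpose_child_shape and child_shape entries to invert the mapping.

```python
import torch

def reorder_weights_forward(weights_in, reorder_specs):
    # Hypothetical consumer of the spec dicts produced by weight_reordering_info;
    # mirrors the forward branch of the removed reorder_weights method.
    weights_out = torch.zeros_like(weights_in)
    for spec in reorder_specs:
        # Slice the parent weight block and view it in its canonical shape,
        # then pick out the subrange belonging to this child instruction.
        sliced = weights_in[spec["parent_range"]].reshape(spec["parent_shape"])[
            spec["weights_subrange"]
        ]
        # Apply the per-instruction permutation, if any, and flatten the
        # result into the destination slice for this child instruction.
        if spec["transpose_perm"] is not None:
            sliced = sliced.permute(spec["transpose_perm"])
        weights_out[spec["child_range"]] = sliced.reshape(spec["reshape_size"])
    return weights_out
```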