@@ -5974,28 +5974,24 @@ class PolarAzimuthalComponent(operators.AzimuthalComponent):
59745974 basis_type = IntervalBasis
59755975
59765976 def subproblem_matrix (self , subproblem ):
5977- # I'm not sure how to generalize this to higher order tensors, since we do
5978- # not have spin_weights for the S1 basis.
5979- matrix = np .array ([[1 ,0 ]])
5977+ operand = self .args [0 ]
5978+ input_dim = len (operand .tensorsig )
5979+ output_dim = len (self .tensorsig )
5980+ matrix = []
5981+ for output in range (2 ** output_dim ):
5982+ index_out = np .unravel_index (output , [2 ]* output_dim )
5983+ matrix_row = []
5984+ for input in range (2 ** input_dim ):
5985+ index_in = np .unravel_index (input , [2 ]* input_dim )
5986+ if tuple (index_in [:self .index ] + index_in [self .index + 1 :]) == index_out and index_in [self .index ] == 0 :
5987+ matrix_row .append (1 )
5988+ else :
5989+ matrix_row .append (0 )
5990+ matrix .append (matrix_row )
5991+ matrix = np .array (matrix )
59805992 if self .dtype == np .float64 :
59815993 # Block-diag for sin/cos parts for real dtype
59825994 matrix = sparse .kron (matrix , sparse .eye (2 ))
5983-
5984- # operand = self.args[0]
5985- # basis = self.domain.get_basis(self.coordsys)
5986- # S_in = basis.spin_weights(operand.tensorsig)
5987- # S_out = basis.spin_weights(self.tensorsig)
5988- #
5989- # matrix = []
5990- # for spinindex_out, spintotal_out in np.ndenumerate(S_out):
5991- # matrix_row = []
5992- # for spinindex_in, spintotal_in in np.ndenumerate(S_in):
5993- # if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2:
5994- # matrix_row.append( 1 )
5995- # else:
5996- # matrix_row.append( 0 )
5997- # matrix.append(matrix_row)
5998- # matrix = np.array(matrix)
59995995 return matrix
60005996
60015997 def operate (self , out ):
@@ -6012,28 +6008,24 @@ class PolarRadialComponent(operators.RadialComponent):
60126008 basis_type = IntervalBasis
60136009
60146010 def subproblem_matrix (self , subproblem ):
6015- # I'm not sure how to generalize this to higher order tensors, since we do
6016- # not have spin_weights for the S1 basis.
6017- matrix = np .array ([[0 ,1 ]])
6011+ operand = self .args [0 ]
6012+ input_dim = len (operand .tensorsig )
6013+ output_dim = len (self .tensorsig )
6014+ matrix = []
6015+ for output in range (2 ** output_dim ):
6016+ index_out = np .unravel_index (output , [2 ]* output_dim )
6017+ matrix_row = []
6018+ for input in range (2 ** input_dim ):
6019+ index_in = np .unravel_index (input , [2 ]* input_dim )
6020+ if tuple (index_in [:self .index ] + index_in [self .index + 1 :]) == index_out and index_in [self .index ] == 1 :
6021+ matrix_row .append (1 )
6022+ else :
6023+ matrix_row .append (0 )
6024+ matrix .append (matrix_row )
6025+ matrix = np .array (matrix )
60186026 if self .dtype == np .float64 :
60196027 # Block-diag for sin/cos parts for real dtype
60206028 matrix = sparse .kron (matrix , sparse .eye (2 ))
6021-
6022- # operand = self.args[0]
6023- # basis = self.domain.get_basis(self.coordsys)
6024- # S_in = basis.spin_weights(operand.tensorsig)
6025- # S_out = basis.spin_weights(self.tensorsig)
6026- #
6027- # matrix = []
6028- # for spinindex_out, spintotal_out in np.ndenumerate(S_out):
6029- # matrix_row = []
6030- # for spinindex_in, spintotal_in in np.ndenumerate(S_in):
6031- # if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2:
6032- # matrix_row.append( 1 )
6033- # else:
6034- # matrix_row.append( 0 )
6035- # matrix.append(matrix_row)
6036- # matrix = np.array(matrix)
60376029 return matrix
60386030
60396031 def operate (self , out ):
0 commit comments