Skip to content

Commit 4735501

Browse files
lecoanetkburns
authored andcommitted
Modified polar component operators so they work on tensors.
1 parent e3cab89 commit 4735501

1 file changed

Lines changed: 30 additions & 38 deletions

File tree

dedalus/core/basis.py

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5974,28 +5974,24 @@ class PolarAzimuthalComponent(operators.AzimuthalComponent):
59745974
basis_type = IntervalBasis
59755975

59765976
def subproblem_matrix(self, subproblem):
5977-
# I'm not sure how to generalize this to higher order tensors, since we do
5978-
# not have spin_weights for the S1 basis.
5979-
matrix = np.array([[1,0]])
5977+
operand = self.args[0]
5978+
input_dim = len(operand.tensorsig)
5979+
output_dim = len(self.tensorsig)
5980+
matrix = []
5981+
for output in range(2**output_dim):
5982+
index_out = np.unravel_index(output, [2]*output_dim)
5983+
matrix_row = []
5984+
for input in range(2**input_dim):
5985+
index_in = np.unravel_index(input, [2]*input_dim)
5986+
if tuple(index_in[:self.index] + index_in[self.index+1:]) == index_out and index_in[self.index] == 0:
5987+
matrix_row.append(1)
5988+
else:
5989+
matrix_row.append(0)
5990+
matrix.append(matrix_row)
5991+
matrix = np.array(matrix)
59805992
if self.dtype == np.float64:
59815993
# Block-diag for sin/cos parts for real dtype
59825994
matrix = sparse.kron(matrix, sparse.eye(2))
5983-
5984-
# operand = self.args[0]
5985-
# basis = self.domain.get_basis(self.coordsys)
5986-
# S_in = basis.spin_weights(operand.tensorsig)
5987-
# S_out = basis.spin_weights(self.tensorsig)
5988-
#
5989-
# matrix = []
5990-
# for spinindex_out, spintotal_out in np.ndenumerate(S_out):
5991-
# matrix_row = []
5992-
# for spinindex_in, spintotal_in in np.ndenumerate(S_in):
5993-
# if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2:
5994-
# matrix_row.append( 1 )
5995-
# else:
5996-
# matrix_row.append( 0 )
5997-
# matrix.append(matrix_row)
5998-
# matrix = np.array(matrix)
59995995
return matrix
60005996

60015997
def operate(self, out):
@@ -6012,28 +6008,24 @@ class PolarRadialComponent(operators.RadialComponent):
60126008
basis_type = IntervalBasis
60136009

60146010
def subproblem_matrix(self, subproblem):
6015-
# I'm not sure how to generalize this to higher order tensors, since we do
6016-
# not have spin_weights for the S1 basis.
6017-
matrix = np.array([[0,1]])
6011+
operand = self.args[0]
6012+
input_dim = len(operand.tensorsig)
6013+
output_dim = len(self.tensorsig)
6014+
matrix = []
6015+
for output in range(2**output_dim):
6016+
index_out = np.unravel_index(output, [2]*output_dim)
6017+
matrix_row = []
6018+
for input in range(2**input_dim):
6019+
index_in = np.unravel_index(input, [2]*input_dim)
6020+
if tuple(index_in[:self.index] + index_in[self.index+1:]) == index_out and index_in[self.index] == 1:
6021+
matrix_row.append(1)
6022+
else:
6023+
matrix_row.append(0)
6024+
matrix.append(matrix_row)
6025+
matrix = np.array(matrix)
60186026
if self.dtype == np.float64:
60196027
# Block-diag for sin/cos parts for real dtype
60206028
matrix = sparse.kron(matrix, sparse.eye(2))
6021-
6022-
# operand = self.args[0]
6023-
# basis = self.domain.get_basis(self.coordsys)
6024-
# S_in = basis.spin_weights(operand.tensorsig)
6025-
# S_out = basis.spin_weights(self.tensorsig)
6026-
#
6027-
# matrix = []
6028-
# for spinindex_out, spintotal_out in np.ndenumerate(S_out):
6029-
# matrix_row = []
6030-
# for spinindex_in, spintotal_in in np.ndenumerate(S_in):
6031-
# if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2:
6032-
# matrix_row.append( 1 )
6033-
# else:
6034-
# matrix_row.append( 0 )
6035-
# matrix.append(matrix_row)
6036-
# matrix = np.array(matrix)
60376029
return matrix
60386030

60396031
def operate(self, out):

0 commit comments

Comments
 (0)