Skip to content

Commit 961ed0b

Browse files
authored
Modified CartesianTrace to work for tensors with rank >2. Added rank 3 tests to test_cartesian_operators.py which pass with the changes to operators.py, but did not pass before. (#292)
1 parent 70bd15e commit 961ed0b

2 files changed

Lines changed: 45 additions & 2 deletions

File tree

dedalus/core/operators.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,8 +1829,11 @@ class CartesianTrace(Trace):
18291829
def subproblem_matrix(self, subproblem):
18301830
dim = self.coordsys.dim
18311831
trace = np.ravel(np.eye(dim))
1832-
# Assume all components have the same n_size
1833-
eye = sparse.identity(subproblem.coeff_size(self.domain), self.dtype, format='csr')
1832+
# Kronecker up identity for remaining tensor components
1833+
n_eye = prod(cs.dim for cs in self.tensorsig)
1834+
# Kronecker up identity for coeff size
1835+
n_eye *= subproblem.coeff_size(self.domain)
1836+
eye = sparse.identity(n_eye, self.dtype, format='csr')
18341837
matrix = sparse.kron(trace, eye)
18351838
return matrix
18361839

dedalus/tests/test_cartesian_operators.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,23 @@ def test_trace_explicit(basis, N, dealias, dtype, layout):
147147
assert np.allclose(g[layout], np.trace(f[layout]))
148148

149149

150+
@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
151+
@pytest.mark.parametrize('N', N_range)
152+
@pytest.mark.parametrize('dealias', dealias_range)
153+
@pytest.mark.parametrize('dtype', dtype_range)
154+
@pytest.mark.parametrize('layout', ['c', 'g'])
155+
def test_trace_rank3_explicit(basis, N, dealias, dtype, layout):
156+
"""Test explicit evaluation of trace operator for correctness."""
157+
c, d, b, r = basis(N, dealias, dtype)
158+
# Random tensor field
159+
f = d.TensorField((c,c,c), bases=b)
160+
f.fill_random(layout='g')
161+
# Evaluate trace
162+
f.change_layout(layout)
163+
g = d3.trace(f).evaluate()
164+
assert np.allclose(g[layout], np.trace(f[layout]))
165+
166+
150167
@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
151168
@pytest.mark.parametrize('N', N_range)
152169
@pytest.mark.parametrize('dealias', dealias_range)
@@ -170,6 +187,29 @@ def test_trace_implicit(basis, N, dealias, dtype):
170187
assert np.allclose(u['c'], f['c'])
171188

172189

190+
@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
191+
@pytest.mark.parametrize('N', N_range)
192+
@pytest.mark.parametrize('dealias', dealias_range)
193+
@pytest.mark.parametrize('dtype', dtype_range)
194+
def test_trace_rank3_implicit(basis, N, dealias, dtype):
195+
"""Test implicit evaluation of trace operator for correctness."""
196+
c, d, b, r = basis(N, dealias, dtype)
197+
# Random scalar field
198+
f = d.VectorField(c, bases=b)
199+
f.fill_random(layout='g')
200+
# Trace LBVP
201+
u = d.VectorField(c, bases=b)
202+
I = d.TensorField((c,c))
203+
dim = len(r)
204+
for i in range(dim):
205+
I['g'][i,i] = 1
206+
problem = d3.LBVP([u], namespace=locals())
207+
problem.add_equation("trace(I*u) = dim*f")
208+
solver = problem.build_solver()
209+
solver.solve()
210+
assert np.allclose(u['c'], f['c'])
211+
212+
173213
@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
174214
@pytest.mark.parametrize('N', N_range)
175215
@pytest.mark.parametrize('dealias', dealias_range)

0 commit comments

Comments
 (0)