@@ -147,6 +147,23 @@ def test_trace_explicit(basis, N, dealias, dtype, layout):
147147 assert np .allclose (g [layout ], np .trace (f [layout ]))
148148
149149
150+ @pytest .mark .parametrize ('basis' , [build_FF , build_FC , build_CC , build_FFF , build_FFC ])
151+ @pytest .mark .parametrize ('N' , N_range )
152+ @pytest .mark .parametrize ('dealias' , dealias_range )
153+ @pytest .mark .parametrize ('dtype' , dtype_range )
154+ @pytest .mark .parametrize ('layout' , ['c' , 'g' ])
155+ def test_trace_rank3_explicit (basis , N , dealias , dtype , layout ):
156+ """Test explicit evaluation of trace operator for correctness."""
157+ c , d , b , r = basis (N , dealias , dtype )
158+ # Random tensor field
159+ f = d .TensorField ((c ,c ,c ), bases = b )
160+ f .fill_random (layout = 'g' )
161+ # Evaluate trace
162+ f .change_layout (layout )
163+ g = d3 .trace (f ).evaluate ()
164+ assert np .allclose (g [layout ], np .trace (f [layout ]))
165+
166+
150167@pytest .mark .parametrize ('basis' , [build_FF , build_FC , build_CC , build_FFF , build_FFC ])
151168@pytest .mark .parametrize ('N' , N_range )
152169@pytest .mark .parametrize ('dealias' , dealias_range )
@@ -170,6 +187,29 @@ def test_trace_implicit(basis, N, dealias, dtype):
170187 assert np .allclose (u ['c' ], f ['c' ])
171188
172189
190+ @pytest .mark .parametrize ('basis' , [build_FF , build_FC , build_CC , build_FFF , build_FFC ])
191+ @pytest .mark .parametrize ('N' , N_range )
192+ @pytest .mark .parametrize ('dealias' , dealias_range )
193+ @pytest .mark .parametrize ('dtype' , dtype_range )
194+ def test_trace_rank3_implicit (basis , N , dealias , dtype ):
195+ """Test implicit evaluation of trace operator for correctness."""
196+ c , d , b , r = basis (N , dealias , dtype )
197+ # Random scalar field
198+ f = d .VectorField (c , bases = b )
199+ f .fill_random (layout = 'g' )
200+ # Trace LBVP
201+ u = d .VectorField (c , bases = b )
202+ I = d .TensorField ((c ,c ))
203+ dim = len (r )
204+ for i in range (dim ):
205+ I ['g' ][i ,i ] = 1
206+ problem = d3 .LBVP ([u ], namespace = locals ())
207+ problem .add_equation ("trace(I*u) = dim*f" )
208+ solver = problem .build_solver ()
209+ solver .solve ()
210+ assert np .allclose (u ['c' ], f ['c' ])
211+
212+
173213@pytest .mark .parametrize ('basis' , [build_FF , build_FC , build_CC , build_FFF , build_FFC ])
174214@pytest .mark .parametrize ('N' , N_range )
175215@pytest .mark .parametrize ('dealias' , dealias_range )
0 commit comments