@@ -353,12 +353,33 @@ def test_empty_axes_with_out(dtype, func):
353353 else :
354354 x = rnd .random ((3 , 4 )).astype (dtype )
355355
356- # For axes=(), output dtype should match input dtype (identity transform)
356+ # NumPy ignores out parameter when axes=() and returns input
357357 out = np .empty_like (x , dtype = dtype )
358358 result = getattr (mkl_fft , func )(x , axes = (), out = out )
359- expected = getattr (np .fft , func )(x , axes = ())
359+ expected = getattr (np .fft , func )(x , axes = (), out = out )
360360
361- # Result should be written to out
362- assert result is out
363- rtol , atol = _get_rtol_atol (result )
364- assert_allclose (result , expected , rtol = rtol , atol = atol , strict = True )
361+ # Result should be the input array (out parameter ignored)
362+ assert result is x , f"{ func } with axes=() should return input (ignore out)"
363+ assert expected is x , "NumPy should also return input"
364+ assert result is expected
365+
366+
367+ @pytest .mark .parametrize (
368+ "dtype" , [np .float32 , np .float64 , np .complex64 , np .complex128 ]
369+ )
370+ @pytest .mark .parametrize ("func" , ["fftn" , "ifftn" , "fft2" , "ifft2" ])
371+ def test_empty_axes_returns_same_object (dtype , func ):
372+ if np .issubdtype (dtype , np .complexfloating ):
373+ x = rnd .random ((3 , 4 )).astype (dtype ) + 1j * rnd .random ((3 , 4 )).astype (
374+ dtype
375+ )
376+ else :
377+ x = rnd .random ((3 , 4 )).astype (dtype )
378+
379+ # Without out parameter, should return the same object
380+ result = getattr (mkl_fft , func )(x , axes = ())
381+
382+ # Verify it's the exact same object (identity check)
383+ assert (
384+ result is x
385+ ), f"{ func } with axes=() should return the same object, not a copy"
0 commit comments