Skip to content

Commit 7b36ab3

Browse files
authored
Merge pull request #293 from IntelPython/return-input-array-on-empty-axes
Return input array unchanged when axes=() (ignore out parameter)
2 parents 5a5f1c2 + 0d0bd05 commit 7b36ab3

2 files changed

Lines changed: 32 additions & 14 deletions

File tree

mkl_fft/_fft_utils.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -519,16 +519,13 @@ def _c2c_fftnd_impl(
519519
else:
520520
x = np.asarray(x)
521521

522-
# Fast path: FFT over no axes is complete identity (preserve dtype)
522+
# Fast path: FFT over no axes is a complete identity operation.
523+
# Returns the input unchanged (same object, no copy), preserving
524+
# dtype and avoiding any FFT computation. The out parameter is
525+
# ignored to match NumPy behavior.
523526
_, xa = _cook_nd_args(x, s, axes)
524527
if len(xa) == 0:
525-
if out is None:
526-
out = x.copy()
527-
else:
528-
_validate_out_array(out, x, x.dtype)
529-
np.copyto(out, x)
530-
# No scaling applied - identity transform has no normalization
531-
return out
528+
return x
532529

533530
if _complementary and x.dtype in valid_dtypes:
534531
if out is None:

mkl_fft/tests/test_fftnd.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,33 @@ def test_empty_axes_with_out(dtype, func):
353353
else:
354354
x = rnd.random((3, 4)).astype(dtype)
355355

356-
# For axes=(), output dtype should match input dtype (identity transform)
356+
# NumPy ignores out parameter when axes=() and returns input
357357
out = np.empty_like(x, dtype=dtype)
358358
result = getattr(mkl_fft, func)(x, axes=(), out=out)
359-
expected = getattr(np.fft, func)(x, axes=())
359+
expected = getattr(np.fft, func)(x, axes=(), out=out)
360360

361-
# Result should be written to out
362-
assert result is out
363-
rtol, atol = _get_rtol_atol(result)
364-
assert_allclose(result, expected, rtol=rtol, atol=atol, strict=True)
361+
# Result should be the input array (out parameter ignored)
362+
assert result is x, f"{func} with axes=() should return input (ignore out)"
363+
assert expected is x, "NumPy should also return input"
364+
assert result is expected
365+
366+
367+
@pytest.mark.parametrize(
368+
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
369+
)
370+
@pytest.mark.parametrize("func", ["fftn", "ifftn", "fft2", "ifft2"])
371+
def test_empty_axes_returns_same_object(dtype, func):
372+
if np.issubdtype(dtype, np.complexfloating):
373+
x = rnd.random((3, 4)).astype(dtype) + 1j * rnd.random((3, 4)).astype(
374+
dtype
375+
)
376+
else:
377+
x = rnd.random((3, 4)).astype(dtype)
378+
379+
# Without out parameter, should return the same object
380+
result = getattr(mkl_fft, func)(x, axes=())
381+
382+
# Verify it's the exact same object (identity check)
383+
assert (
384+
result is x
385+
), f"{func} with axes=() should return the same object, not a copy"

0 commit comments

Comments
 (0)