Skip to content

Commit f1e8754

Browse files
committed
Fix TypeError with empty axes in FFT functions
1 parent 1c0abc5 commit f1e8754

2 files changed

Lines changed: 62 additions & 4 deletions

File tree

mkl_fft/_fft_utils.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,10 @@ def _iter_complementary(x, axes, func, kwargs, result):
233233
m_ind = _flat_to_multi(ind, sub_shape)
234234
for k1, k2 in zip(dual_ind, m_ind):
235235
sl[k1] = k2
236+
tsl = tuple(sl)
237+
236238
if np.issubdtype(x.dtype, np.complexfloating):
237-
func(x[tuple(sl)], **kwargs, out=result[tuple(sl)])
239+
func(x[tsl], **kwargs, out=result[tsl])
238240
else:
239241
# For c2c FFT, if the input is real, half of the output is the
240242
# complex conjugate of the other half. Instead of upcasting the
@@ -247,7 +249,7 @@ def _iter_complementary(x, axes, func, kwargs, result):
247249
# array appeared in the second half of the NumPy output array,
248250
# while the equivalent element in the NumPy array was the conjugate
249251
# of the mkl_fft output array.
250-
np.copyto(result[tuple(sl)], func(x[tuple(sl)], **kwargs))
252+
np.copyto(result[tsl], func(x[tsl], **kwargs))
251253

252254
return result
253255

@@ -260,7 +262,6 @@ def _iter_fftnd(
260262
direction=+1,
261263
scale_function=lambda ind: 1.0,
262264
):
263-
a = np.asarray(a)
264265
s, axes = _init_nd_shape_and_axes(a, s, axes)
265266

266267
# Combine the two, but in reverse, to end with the first axis given.
@@ -412,8 +413,20 @@ def _c2c_fftnd_impl(
412413
out=out,
413414
)
414415
else:
416+
x = np.asarray(x)
417+
418+
# Fast path: FFT over no axes is identity (just type conversion, no scaling)
419+
_, xa = _cook_nd_args(x, s, axes)
420+
if len(xa) == 0:
421+
if out is None:
422+
out = x.astype(dtype=_output_dtype(x.dtype), copy=True)
423+
else:
424+
_validate_out_array(out, x, _output_dtype(x.dtype))
425+
np.copyto(out, x)
426+
# No scaling applied - identity transform has no normalization
427+
return out
428+
415429
if _complementary and x.dtype in valid_dtypes:
416-
x = np.asarray(x)
417430
if out is None:
418431
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
419432
else:

mkl_fft/tests/test_fftnd.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,48 @@ def test_out_strided(axes, func):
317317
expected = getattr(np.fft, func)(x, axes=axes, out=out)
318318

319319
assert_allclose(result, expected, strict=True)
320+
321+
322+
@pytest.mark.parametrize(
323+
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
324+
)
325+
@pytest.mark.parametrize("shape", [(3, 4), (5,), (2, 3, 4), (10, 20)])
326+
@pytest.mark.parametrize("norm", [None, "ortho", "forward", "backward"])
327+
@pytest.mark.parametrize("func", ["fftn", "ifftn", "fft2", "ifft2"])
328+
def test_empty_axes(dtype, shape, norm, func):
329+
if np.issubdtype(dtype, np.complexfloating):
330+
x = rnd.random(shape).astype(dtype) + 1j * rnd.random(shape).astype(
331+
dtype
332+
)
333+
else:
334+
x = rnd.random(shape).astype(dtype)
335+
336+
# Test fftn with axes=()
337+
result = getattr(mkl_fft, func)(x, axes=(), norm=norm)
338+
expected = getattr(np.fft, func)(x, axes=(), norm=norm)
339+
340+
rtol, atol = _get_rtol_atol(result)
341+
assert_allclose(result, expected, rtol=rtol, atol=atol, strict=True)
342+
343+
344+
@pytest.mark.parametrize(
345+
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
346+
)
347+
@pytest.mark.parametrize("func", ["fftn", "ifftn", "fft2", "ifft2"])
348+
def test_empty_axes_with_out(dtype, func):
349+
if np.issubdtype(dtype, np.complexfloating):
350+
x = rnd.random((3, 4)).astype(dtype) + 1j * rnd.random((3, 4)).astype(
351+
dtype
352+
)
353+
else:
354+
x = rnd.random((3, 4)).astype(dtype)
355+
356+
out_dtype = np.dtype(dtype).char.upper()
357+
out = np.empty_like(x, dtype=out_dtype)
358+
result = getattr(mkl_fft, func)(x, axes=(), out=out)
359+
expected = getattr(np.fft, func)(x, axes=())
360+
361+
# Result should be written to out
362+
assert result is out
363+
rtol, atol = _get_rtol_atol(result)
364+
assert_allclose(result, expected, rtol=rtol, atol=atol, strict=True)

0 commit comments

Comments
 (0)