Skip to content

Commit 98b1964

Browse files
committed
rename s_equiv_to_none and consolidate shape-checking logic
1 parent 2b05835 commit 98b1964

1 file changed

Lines changed: 10 additions & 19 deletions

File tree

mkl_fft/_fft_utils.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@ def _check_norm(norm):
4343
)
4444

4545

46-
def _check_shapes_for_direct(xs, shape, axes):
46+
def _check_shapes_for_direct(xs, shape, axes, check_complimentary=False):
4747
if len(axes) > 7: # Intel MKL supports up to 7D
4848
return False
49-
if not (len(xs) == len(shape)):
50-
# full-dimensional transform
49+
if not (len(xs) == len(shape)) and not check_complimentary:
50+
# full-dimensional transform is required for direct,
51+
# but less than full is OK for complimentary.
5152
return False
5253
if not (len(set(axes)) == len(axes)):
5354
# repeated axes
@@ -63,18 +64,6 @@ def _check_shapes_for_direct(xs, shape, axes):
6364
return True
6465

6566

66-
def _check_shapes_equiv_s_none(s, shape, axes):
67-
for si, ai in zip(s, axes):
68-
try:
69-
sh_ai = shape[ai]
70-
except IndexError:
71-
raise ValueError("Invalid axis (%d) specified" % ai)
72-
73-
if si != sh_ai:
74-
return False
75-
return True
76-
77-
7867
def _compute_fwd_scale(norm, n, shape):
7968
_check_norm(norm)
8069
if norm in (None, "backward"):
@@ -394,7 +383,7 @@ def _c2c_fftnd_impl(
394383
if direction not in [-1, +1]:
395384
raise ValueError("Direction of FFT should +1 or -1")
396385

397-
s_equiv_to_none = s is None
386+
_complementary = s is None
398387
valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
399388
# _direct_fftnd requires complex type, and full-dimensional transform
400389
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
@@ -407,8 +396,10 @@ def _c2c_fftnd_impl(
407396
_direct = True
408397
# See if s matches the shape of x along the given axes.
409398
# If it does, we can use _iter_complementary rather than _iter_fftnd.
410-
if _check_shapes_equiv_s_none(xs, x.shape, xa):
411-
s_equiv_to_none = True
399+
if _check_shapes_for_direct(
400+
xs, x.shape, xa, check_complimentary=True
401+
):
402+
_complementary = True
412403
_direct = _direct and x.dtype in valid_dtypes
413404
else:
414405
_direct = False
@@ -421,7 +412,7 @@ def _c2c_fftnd_impl(
421412
out=out,
422413
)
423414
else:
424-
if s_equiv_to_none and x.dtype in valid_dtypes:
415+
if _complementary and x.dtype in valid_dtypes:
425416
x = np.asarray(x)
426417
if out is None:
427418
res = np.empty_like(x, dtype=_output_dtype(x.dtype))

0 commit comments

Comments
 (0)