@@ -43,11 +43,12 @@ def _check_norm(norm):
4343 )
4444
4545
46- def _check_shapes_for_direct (xs , shape , axes ):
46+ def _check_shapes_for_direct (xs , shape , axes , check_complimentary = False ):
4747 if len (axes ) > 7 : # Intel MKL supports up to 7D
4848 return False
49- if not (len (xs ) == len (shape )):
50- # full-dimensional transform
49+ if not (len (xs ) == len (shape )) and not check_complimentary :
50+ # full-dimensional transform is required for direct,
51+ # but less than full is OK for complimentary.
5152 return False
5253 if not (len (set (axes )) == len (axes )):
5354 # repeated axes
@@ -63,18 +64,6 @@ def _check_shapes_for_direct(xs, shape, axes):
6364 return True
6465
6566
66- def _check_shapes_equiv_s_none (s , shape , axes ):
67- for si , ai in zip (s , axes ):
68- try :
69- sh_ai = shape [ai ]
70- except IndexError :
71- raise ValueError ("Invalid axis (%d) specified" % ai )
72-
73- if si != sh_ai :
74- return False
75- return True
76-
77-
7867def _compute_fwd_scale (norm , n , shape ):
7968 _check_norm (norm )
8069 if norm in (None , "backward" ):
@@ -394,7 +383,7 @@ def _c2c_fftnd_impl(
394383 if direction not in [- 1 , + 1 ]:
395384 raise ValueError ("Direction of FFT should +1 or -1" )
396385
397- s_equiv_to_none = s is None
386+ _complementary = s is None
398387 valid_dtypes = [np .complex64 , np .complex128 , np .float32 , np .float64 ]
399388 # _direct_fftnd requires complex type, and full-dimensional transform
400389 if isinstance (x , np .ndarray ) and x .size != 0 and x .ndim > 1 :
@@ -407,8 +396,10 @@ def _c2c_fftnd_impl(
407396 _direct = True
408397 # See if s matches the shape of x along the given axes.
409398 # If it does, we can use _iter_complementary rather than _iter_fftnd.
410- if _check_shapes_equiv_s_none (xs , x .shape , xa ):
411- s_equiv_to_none = True
399+ if _check_shapes_for_direct (
400+ xs , x .shape , xa , check_complimentary = True
401+ ):
402+ _complementary = True
412403 _direct = _direct and x .dtype in valid_dtypes
413404 else :
414405 _direct = False
@@ -421,7 +412,7 @@ def _c2c_fftnd_impl(
421412 out = out ,
422413 )
423414 else :
424- if s_equiv_to_none and x .dtype in valid_dtypes :
415+ if _complementary and x .dtype in valid_dtypes :
425416 x = np .asarray (x )
426417 if out is None :
427418 res = np .empty_like (x , dtype = _output_dtype (x .dtype ))
0 commit comments