4545
4646# pylint: disable=no-name-in-module
4747from .dpnp_algo import (
48- dpnp_inner ,
4948 dpnp_kron ,
5049)
5150from .dpnp_utils import (
@@ -218,43 +217,92 @@ def einsum_path(*args, **kwargs):
218217 return call_origin (numpy .einsum_path , * args , ** kwargs )
219218
220219
221- def inner (x1 , x2 , ** kwargs ):
220+ def inner (a , b ):
222221 """
223222 Returns the inner product of two arrays.
224223
225224 For full documentation refer to :obj:`numpy.inner`.
226225
227- Limitations
228- -----------
229- Parameters `x1` and `x2` are supported as :obj:`dpnp.ndarray`.
230- Keyword argument `kwargs` is currently unsupported.
231- Otherwise the functions will be executed sequentially on CPU.
226+ Parameters
227+ ----------
228+ a : {dpnp.ndarray, usm_ndarray, scalar}
229+ First input array. Both inputs `a` and `b` can not be scalars
230+ at the same time.
231+ b : {dpnp.ndarray, usm_ndarray, scalar}
232+ Second input array. Both inputs `a` and `b` can not be scalars
233+ at the same time.
234+
235+ Returns
236+ -------
237+ out : dpnp.ndarray
238+ If either `a` or `b` is a scalar, the shape of the returned arrays
239+ matches that of the array between `a` and `b`, whichever is an array.
240+ If `a` and `b` are both 1-D arrays then a 0-d array is returned;
241+ otherwise an array with a shape as
242+ ``out.shape = (*a.shape[:-1], *b.shape[:-1])`` is returned.
243+
232244
233245 See Also
234246 --------
235- :obj:`dpnp.einsum` : Evaluates the Einstein summation convention
236- on the operands.
237- :obj:`dpnp.dot` : Returns the dot product of two arrays.
238- :obj:`dpnp.tensordot` : Compute tensor dot product along specified axes.
239- Input array data types are limited by supported DPNP :ref:`Data types`.
247+ :obj:`dpnp.einsum` : Einstein summation convention..
248+ :obj:`dpnp.dot` : Generalised matrix product,
249+ using second last dimension of `b`.
250+ :obj:`dpnp.tensordot` : Sum products over arbitrary axes.
240251
241252 Examples
242253 --------
254+ # Ordinary inner product for vectors
255+
243256 >>> import dpnp as np
244- >>> a = np.array([1,2, 3])
257+ >>> a = np.array([1, 2, 3])
245258 >>> b = np.array([0, 1, 0])
246- >>> result = np.inner(a, b)
247- >>> [x for x in result]
248- [2]
259+ >>> np.inner(a, b)
260+ array(2)
261+
262+ # Some multidimensional examples
263+
264+ >>> a = np.arange(24).reshape((2,3,4))
265+ >>> b = np.arange(4)
266+ >>> c = np.inner(a, b)
267+ >>> c.shape
268+ (2, 3)
269+ >>> c
270+ array([[ 14, 38, 62],
271+ [86, 110, 134]])
272+
273+ >>> a = np.arange(2).reshape((1,1,2))
274+ >>> b = np.arange(6).reshape((3,2))
275+ >>> c = np.inner(a, b)
276+ >>> c.shape
277+ (1, 1, 3)
278+ >>> c
279+ array([[[1, 3, 5]]])
280+
281+ An example where `b` is a scalar
282+
283+ >>> np.inner(np.eye(2), 7)
284+ array([[7., 0.],
285+ [0., 7.]])
249286
250287 """
251288
252- x1_desc = dpnp .get_dpnp_descriptor (x1 , copy_when_nondefault_queue = False )
253- x2_desc = dpnp .get_dpnp_descriptor (x2 , copy_when_nondefault_queue = False )
254- if x1_desc and x2_desc and not kwargs :
255- return dpnp_inner (x1_desc , x2_desc ).get_pyobj ()
289+ dpnp .check_supported_arrays_type (a , b , scalar_type = True )
290+
291+ if dpnp .isscalar (a ) or dpnp .isscalar (b ):
292+ return dpnp .multiply (a , b )
293+
294+ if a .ndim == 0 or b .ndim == 0 :
295+ return dpnp .multiply (a , b )
296+
297+ if a .shape [- 1 ] != b .shape [- 1 ]:
298+ raise ValueError (
299+ "shape of input arrays is not similar at the last axis."
300+ )
301+
302+ if a .ndim == 1 and b .ndim == 1 :
303+ return dpnp_dot (a , b )
256304
257- return call_origin ( numpy . inner , x1 , x2 , ** kwargs )
305+ return dpnp . tensordot ( a , b , axes = ( - 1 , - 1 ) )
258306
259307
260308def kron (x1 , x2 ):
@@ -567,16 +615,20 @@ def tensordot(a, b, axes=2):
567615
568616 dpnp .check_supported_arrays_type (a , b , scalar_type = True )
569617
570- if dpnp .isscalar (a ):
571- a = dpnp .array (a , sycl_queue = b .sycl_queue , usm_type = b .usm_type )
572- elif dpnp .isscalar (b ):
573- b = dpnp .array (b , sycl_queue = a .sycl_queue , usm_type = a .usm_type )
618+ if dpnp .isscalar (a ) or dpnp .isscalar (b ):
619+ if not isinstance (axes , int ) or axes != 0 :
620+ raise ValueError (
621+ "One of the inputs is scalar, axes should be zero."
622+ )
623+ return dpnp .multiply (a , b )
574624
575625 try :
576626 iter (axes )
577627 except Exception as e : # pylint: disable=broad-exception-caught
578628 if not isinstance (axes , int ):
579629 raise TypeError ("Axes must be an integer." ) from e
630+ if axes < 0 :
631+ raise ValueError ("Axes must be a nonnegative integer." ) from e
580632 axes_a = tuple (range (- axes , 0 ))
581633 axes_b = tuple (range (0 , axes ))
582634 else :
@@ -590,6 +642,15 @@ def tensordot(a, b, axes=2):
590642 if len (axes_a ) != len (axes_b ):
591643 raise ValueError ("Axes length mismatch." )
592644
645+ # Make the axes non-negative
646+ a_ndim = a .ndim
647+ b_ndim = b .ndim
648+ axes_a = normalize_axis_tuple (axes_a , a_ndim , "axis_a" )
649+ axes_b = normalize_axis_tuple (axes_b , b_ndim , "axis_b" )
650+
651+ if a .ndim == 0 or b .ndim == 0 :
652+ return dpnp .multiply (a , b )
653+
593654 a_shape = a .shape
594655 b_shape = b .shape
595656 for axis_a , axis_b in zip (axes_a , axes_b ):
@@ -598,12 +659,6 @@ def tensordot(a, b, axes=2):
598659 "shape of input arrays is not similar at requested axes."
599660 )
600661
601- # Make the axes non-negative
602- a_ndim = a .ndim
603- b_ndim = b .ndim
604- axes_a = normalize_axis_tuple (axes_a , a_ndim , "axis" )
605- axes_b = normalize_axis_tuple (axes_b , b_ndim , "axis" )
606-
607662 # Move the axes to sum over, to the end of "a"
608663 notin = tuple (k for k in range (a_ndim ) if k not in axes_a )
609664 newaxes_a = notin + axes_a
0 commit comments