@@ -36,13 +36,24 @@ template<typename T, size_t N>
3636struct scalar_type_finder <Tensor<T,N>> {
3737 using type = T;
3838};
39+
40+ template <typename T, size_t ... Rest>
41+ struct scalar_type_finder <TensorMap<T,Rest...>> {
42+ using type = T;
43+ };
44+ // This specific specialisation is needed to avoid ambiguity for vectors
45+ template <typename T, size_t N>
46+ struct scalar_type_finder <TensorMap<T,N>> {
47+ using type = T;
48+ };
49+
3950template <template <class ,size_t > class UnaryExpr , typename Expr, size_t DIMS>
4051struct scalar_type_finder <UnaryExpr<Expr,DIMS>> {
4152 using type = typename scalar_type_finder<Expr>::type;
4253};
4354template <template <class ,class ,size_t > class Expr , typename TLhs, typename TRhs, size_t DIMS>
4455struct scalar_type_finder <Expr<TLhs,TRhs,DIMS>> {
45- using type = typename std::conditional<std::is_arithmetic <TLhs>::value ,
56+ using type = typename std::conditional<is_primitive_v_ <TLhs>,
4657 typename scalar_type_finder<TRhs>::type, typename scalar_type_finder<TLhs>::type>::type;
4758};
4859template <template <typename ,typename ,typename ,size_t > class TensorFixedViewExpr ,
@@ -58,11 +69,6 @@ template<template<typename,size_t...> class TensorType, typename T, size_t ...Re
5869struct scalar_type_finder <TensorFixedViewExprnD<TensorType<T,Rest...>,Fseqs...>> {
5970 using type = T;
6071};
61-
62- template <typename T, size_t ... Rest>
63- struct scalar_type_finder <TensorMap<T,Rest...>> {
64- using type = T;
65- };
6672// --------------------------------------------------------------------------------------------------------------------//
6773
6874
@@ -82,14 +88,24 @@ template<typename T, size_t N>
8288struct tensor_type_finder <Tensor<T,N>> {
8389 using type = Tensor<T,N>;
8490};
91+
92+ template <typename T, size_t ... Rest>
93+ struct tensor_type_finder <TensorMap<T,Rest...>> {
94+ using type = Tensor<T,Rest...>;
95+ };
96+ // This specific specialisation is needed to avoid ambiguity for vectors
97+ template <typename T, size_t N>
98+ struct tensor_type_finder <TensorMap<T,N>> {
99+ using type = Tensor<T,N>;
100+ };
101+
85102template <template <typename ,size_t > class UnaryExpr , typename Expr, size_t DIM>
86103struct tensor_type_finder <UnaryExpr<Expr,DIM>> {
87104 using type = typename tensor_type_finder<Expr>::type;
88105};
89106template <template <class ,class ,size_t > class BinaryExpr , typename TLhs, typename TRhs, size_t DIMS>
90107struct tensor_type_finder <BinaryExpr<TLhs,TRhs,DIMS>> {
91- // using type = typename tensor_type_finder<TLhs>::type;
92- using type = typename std::conditional<std::is_arithmetic<TLhs>::value,
108+ using type = typename std::conditional<is_primitive_v_<TLhs>,
93109 typename tensor_type_finder<TRhs>::type, typename tensor_type_finder<TLhs>::type>::type;
94110};
95111template <template <typename ,typename ,typename ,size_t > class TensorFixedViewExpr ,
@@ -105,11 +121,6 @@ template<template<typename,size_t...> class TensorType, typename T, size_t ...Re
105121struct tensor_type_finder <TensorFixedViewExprnD<TensorType<T,Rest...>,Fseqs...>> {
106122 using type = TensorType<T,Rest...>;
107123};
108-
109- template <typename T, size_t ... Rest>
110- struct tensor_type_finder <TensorMap<T,Rest...>> {
111- using type = Tensor<T,Rest...>;
112- };
113124// --------------------------------------------------------------------------------------------------------------------//
114125
115126
0 commit comments