Skip to content

Commit c16efc5

Browse files
committed
Fix ambiguity in scalar_type_finder for tensor maps
1 parent aefce47 commit c16efc5

1 file changed

Lines changed: 24 additions & 13 deletions

File tree

Fastor/tensor/TensorTraits.h

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,24 @@ template<typename T, size_t N>
3636
struct scalar_type_finder<Tensor<T,N>> {
3737
using type = T;
3838
};
39+
40+
template<typename T, size_t ... Rest>
41+
struct scalar_type_finder<TensorMap<T,Rest...>> {
42+
using type = T;
43+
};
44+
// This specific specialisation is needed to avoid ambiguity for vectors
45+
template<typename T, size_t N>
46+
struct scalar_type_finder<TensorMap<T,N>> {
47+
using type = T;
48+
};
49+
3950
template<template <class,size_t> class UnaryExpr, typename Expr, size_t DIMS>
4051
struct scalar_type_finder<UnaryExpr<Expr,DIMS>> {
4152
using type = typename scalar_type_finder<Expr>::type;
4253
};
4354
template<template <class,class,size_t> class Expr, typename TLhs, typename TRhs, size_t DIMS>
4455
struct scalar_type_finder<Expr<TLhs,TRhs,DIMS>> {
45-
using type = typename std::conditional<std::is_arithmetic<TLhs>::value,
56+
using type = typename std::conditional<is_primitive_v_<TLhs>,
4657
typename scalar_type_finder<TRhs>::type, typename scalar_type_finder<TLhs>::type>::type;
4758
};
4859
template<template<typename,typename,typename,size_t> class TensorFixedViewExpr,
@@ -58,11 +69,6 @@ template<template<typename,size_t...> class TensorType, typename T, size_t ...Re
5869
struct scalar_type_finder<TensorFixedViewExprnD<TensorType<T,Rest...>,Fseqs...>> {
5970
using type = T;
6071
};
61-
62-
template<typename T, size_t ... Rest>
63-
struct scalar_type_finder<TensorMap<T,Rest...>> {
64-
using type = T;
65-
};
6672
//--------------------------------------------------------------------------------------------------------------------//
6773

6874

@@ -82,14 +88,24 @@ template<typename T, size_t N>
8288
struct tensor_type_finder<Tensor<T,N>> {
8389
using type = Tensor<T,N>;
8490
};
91+
92+
template<typename T, size_t ... Rest>
93+
struct tensor_type_finder<TensorMap<T,Rest...>> {
94+
using type = Tensor<T,Rest...>;
95+
};
96+
// This specific specialisation is needed to avoid ambiguity for vectors
97+
template<typename T, size_t N>
98+
struct tensor_type_finder<TensorMap<T,N>> {
99+
using type = Tensor<T,N>;
100+
};
101+
85102
template<template<typename,size_t> class UnaryExpr, typename Expr, size_t DIM>
86103
struct tensor_type_finder<UnaryExpr<Expr,DIM>> {
87104
using type = typename tensor_type_finder<Expr>::type;
88105
};
89106
template<template<class,class,size_t> class BinaryExpr, typename TLhs, typename TRhs, size_t DIMS>
90107
struct tensor_type_finder<BinaryExpr<TLhs,TRhs,DIMS>> {
91-
// using type = typename tensor_type_finder<TLhs>::type;
92-
using type = typename std::conditional<std::is_arithmetic<TLhs>::value,
108+
using type = typename std::conditional<is_primitive_v_<TLhs>,
93109
typename tensor_type_finder<TRhs>::type, typename tensor_type_finder<TLhs>::type>::type;
94110
};
95111
template<template<typename,typename,typename,size_t> class TensorFixedViewExpr,
@@ -105,11 +121,6 @@ template<template<typename,size_t...> class TensorType, typename T, size_t ...Re
105121
struct tensor_type_finder<TensorFixedViewExprnD<TensorType<T,Rest...>,Fseqs...>> {
106122
using type = TensorType<T,Rest...>;
107123
};
108-
109-
template<typename T, size_t ... Rest>
110-
struct tensor_type_finder<TensorMap<T,Rest...>> {
111-
using type = Tensor<T,Rest...>;
112-
};
113124
//--------------------------------------------------------------------------------------------------------------------//
114125

115126

0 commit comments

Comments
 (0)