Skip to content

Commit 26fc041

Browse files
committed
Fix bug in outerproducts for tensormaps
1 parent 5e51bca commit 26fc041

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

Fastor/tensor_algebra/outerproduct.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ dyadic(const AbstractTensor<Derived0,DIM0> &a, const AbstractTensor<Derived1,DIM
9797
// return a;
9898
// }
9999

100-
template<typename AbstractTensorType0, typename AbstractTensorType1, typename ... AbstractTensorTypes>
100+
template<typename AbstractTensorType0, typename AbstractTensorType1, typename ... AbstractTensorTypes,
101+
enable_if_t_<is_greater_equal_v_<sizeof...(AbstractTensorTypes),1>,bool> >
101102
FASTOR_INLINE
102103
auto
103104
outer(const AbstractTensorType0& a, const AbstractTensorType1& b, const AbstractTensorTypes& ... rest)
@@ -106,7 +107,8 @@ outer(const AbstractTensorType0& a, const AbstractTensorType1& b, const Abstract
106107
return outer(res, rest...);
107108
}
108109

109-
template<typename AbstractTensorType0, typename AbstractTensorType1, typename ... AbstractTensorTypes>
110+
template<typename AbstractTensorType0, typename AbstractTensorType1, typename ... AbstractTensorTypes,
111+
enable_if_t_<is_greater_equal_v_<sizeof...(AbstractTensorTypes),1>,bool> >
110112
FASTOR_INLINE
111113
auto
112114
dyadic(const AbstractTensorType0& a, const AbstractTensorType1& b, const AbstractTensorTypes& ... rest)

tests/test_tensormap/test_tensormap.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ void run() {
8888
FASTOR_EXIT_ASSERT(std::abs(res1.sum() + 28 ) < Tol);
8989
}
9090

91+
// Bug 132 - add 1D tensormaps, outer tensormaps
92+
{
93+
std::array<double,3> a = {1,1,1};
94+
TensorMap<double,3> at(a.data());
95+
96+
std::array<double,3> b = {2,2,2};
97+
TensorMap<double,3> bt(b.data());
98+
99+
Tensor<double,3> ct = bt + at;
100+
FASTOR_EXIT_ASSERT(std::abs(ct.sum() - 9 ) < Tol);
101+
102+
Tensor<double,3,3> dt = outer(at, bt);
103+
FASTOR_EXIT_ASSERT(std::abs(dt.sum() - 18 ) < Tol);
104+
}
105+
91106
print(FGRN(BOLD("All tests passed successfully")));
92107
}
93108

0 commit comments

Comments
 (0)