Skip to content

Commit bfcd222

Browse files
committed
adds tests for the compound assignments
1 parent fdb6d34 commit bfcd222

6 files changed

Lines changed: 142 additions & 53 deletions

File tree

stan/math/opencl/kernel_generator/as_operation_cl.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace math {
2424
* @param a an operation
2525
* @return operation
2626
*/
27-
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_operation,
27+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_operation,
2828
typename = std::enable_if_t<std::is_base_of<
2929
operation_cl_base, std::remove_reference_t<T_operation>>::value>>
3030
inline T_operation&& as_operation_cl(T_operation&& a) {
@@ -38,7 +38,7 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
3838
* @param a scalar
3939
* @return \c scalar_ wrapping the input
4040
*/
41-
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_scalar, typename = require_arithmetic_t<T_scalar>,
41+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_scalar, typename = require_arithmetic_t<T_scalar>,
4242
require_not_same_t<T_scalar, bool>* = nullptr>
4343
inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
4444
return scalar_<T_scalar>(a);
@@ -51,7 +51,7 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
5151
* @param a scalar
5252
* @return \c scalar_<char> wrapping the input
5353
*/
54-
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals>
54+
template <assign_op_cl AssignOp = assign_op_cl::equals>
5555
inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
5656

5757
/**
@@ -61,7 +61,7 @@ inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
6161
* @param a \c matrix_cl
6262
* @return \c load_ wrapping the input
6363
*/
64-
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_matrix_cl,
64+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_matrix_cl,
6565
typename = require_any_t<is_matrix_cl<T_matrix_cl>,
6666
is_arena_matrix_cl<T_matrix_cl>>>
6767
inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
@@ -76,7 +76,7 @@ inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
7676
* rvalue reference, the reference is removed, so that a variable of this type
7777
* actually stores the value.
7878
*/
79-
template <typename T, assignment_ops_cl AssignOp = assignment_ops_cl::equals>
79+
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
8080
using as_operation_cl_t = std::conditional_t<
8181
std::is_lvalue_reference<T>::value,
8282
decltype(as_operation_cl<AssignOp>(std::declval<T>())),

stan/math/opencl/kernel_generator/assignment_ops.hpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,60 +9,56 @@ namespace math {
99
/**
1010
* Ops that decide the type of assignment for LHS operations
1111
*/
12-
enum class assignment_ops_cl {equals, plus_equals, minus_equals, divide_equals};
12+
enum class assign_op_cl {equals, plus_equals, minus_equals, divide_equals, multiply_equals};
1313

14+
namespace internal {
1415
/**
1516
* @param value A static constexpr const char* member for printing assignment ops
1617
*/
17-
template <assignment_ops_cl assign_op>
18-
struct assignment_op_str;
18+
template <assign_op_cl assign_op>
19+
struct assignment_op_str_impl;
1920

2021
template <>
21-
struct assignment_op_str<assignment_ops_cl::equals> {
22+
struct assignment_op_str_impl<assign_op_cl::equals> {
2223
static constexpr const char* value = " = ";
2324
};
2425

2526
template <>
26-
struct assignment_op_str<assignment_ops_cl::plus_equals> {
27+
struct assignment_op_str_impl<assign_op_cl::plus_equals> {
2728
static constexpr const char* value = " += ";
2829
};
2930

3031
template <>
31-
struct assignment_op_str<assignment_ops_cl::minus_equals> {
32-
static constexpr const char* value = " *= ";
32+
struct assignment_op_str_impl<assign_op_cl::minus_equals> {
33+
static constexpr const char* value = " -= ";
3334
};
3435

3536
template <>
36-
struct assignment_op_str<assignment_ops_cl::divide_equals> {
37+
struct assignment_op_str_impl<assign_op_cl::divide_equals> {
3738
static constexpr const char* value = " /= ";
3839
};
3940

41+
template <>
42+
struct assignment_op_str_impl<assign_op_cl::multiply_equals> {
43+
static constexpr const char* value = " *= ";
44+
};
4045

41-
namespace internal {
4246
template <typename, typename = void>
43-
struct has_assignment_op_str : std::false_type {};
47+
struct assignment_op_str : assignment_op_str_impl<assign_op_cl::equals> {};
4448

4549
template <typename T>
46-
struct has_assignment_op_str<T, void_t<decltype(T::assignment_op)>> : std::true_type {};
50+
struct assignment_op_str<T, void_t<decltype(T::assignment_op)>> : assignment_op_str_impl<T::assignment_op> {};
4751

4852
} // namespace internal
4953

50-
/**
51-
* @tparam T A type that does not have an `assignment_op` static constexpr member type
52-
* @return A constexpr const char* equal to `" = "`
53-
*/
54-
template <typename T, std::enable_if_t<!internal::has_assignment_op_str<std::decay_t<T>>::value>* = nullptr>
55-
inline constexpr const char* assignment_op() noexcept {
56-
return " = ";
57-
}
5854

5955
/**
6056
* @tparam T A type that has an `assignment_op` static constexpr member type
6157
* @return The types assignment op as a constexpr const char*
6258
*/
63-
template <typename T, std::enable_if_t<internal::has_assignment_op_str<T>::value>* = nullptr>
59+
template <typename T>
6460
inline constexpr const char* assignment_op() noexcept {
65-
return assignment_op_str<std::decay_t<T>::assignment_op>::value;
61+
return internal::assignment_op_str<std::decay_t<T>>::value;
6662
}
6763

6864
}

stan/math/opencl/kernel_generator/load.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace math {
2727
* @tparam T \c matrix_cl
2828
* @tparam AssignOp tells higher level operations whether the final operation should be an assignment or a type of compound assignment.
2929
*/
30-
template <typename T, assignment_ops_cl AssignOp = assignment_ops_cl::equals>
30+
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
3131
class load_
3232
: public operation_cl_lhs<load_<T, AssignOp>,
3333
typename std::remove_reference_t<T>::type> {
@@ -36,7 +36,7 @@ class load_
3636

3737
public:
3838

39-
static constexpr assignment_ops_cl assignment_op = AssignOp;
39+
static constexpr assign_op_cl assignment_op = AssignOp;
4040
using Scalar = typename std::remove_reference_t<T>::type;
4141
using base = operation_cl<load_<T, AssignOp>, Scalar>;
4242
using base::var_name_;

stan/math/opencl/kernel_generator/multi_result_kernel.hpp

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -323,19 +323,19 @@ class results_cl {
323323
});
324324
}
325325

326-
/**
326+
/**
327327
* Incrementing \c results_ object by \c expressions_cl object
328328
* executes the kernel that evaluates expressions and increments results by
329329
* those expressions.
330330
* @tparam T_expressions types of expressions
331331
* @param exprs expressions
332332
*/
333-
template <typename... T_expressions,
333+
template <assign_op_cl AssignOp = assign_op_cl::plus_equals, typename... T_expressions,
334334
typename = std::enable_if_t<sizeof...(T_results)
335335
== sizeof...(T_expressions)>>
336-
void operator+=(const expressions_cl<T_expressions...>& exprs) {
336+
void compound_assignment_impl(const expressions_cl<T_expressions...>& exprs) {
337337
index_apply<sizeof...(T_expressions)>([this, &exprs](auto... Is) {
338-
auto tmp = std::tuple_cat(make_assignment_pair<assignment_ops_cl::plus_equals>(
338+
auto tmp = std::tuple_cat(make_assignment_pair<AssignOp>(
339339
std::get<Is>(results_), std::get<Is>(exprs.expressions_))...);
340340
index_apply<std::tuple_size<decltype(tmp)>::value>(
341341
[this, &tmp](auto... Is2) {
@@ -345,6 +345,20 @@ class results_cl {
345345
});
346346
}
347347

348+
/**
349+
* Incrementing \c results_ object by \c expressions_cl object
350+
* executes the kernel that evaluates expressions and increments results by
351+
* those expressions.
352+
* @tparam T_expressions types of expressions
353+
* @param exprs expressions
354+
*/
355+
template <typename... T_expressions,
356+
typename = std::enable_if_t<sizeof...(T_results)
357+
== sizeof...(T_expressions)>>
358+
void operator+=(const expressions_cl<T_expressions...>& exprs) {
359+
compound_assignment_impl<assign_op_cl::plus_equals>(exprs);
360+
}
361+
348362
/**
349363
* Incrementing \c results_ object by \c expressions_cl object
350364
* executes the kernel that evaluates expressions and increments results by
@@ -356,15 +370,35 @@ class results_cl {
356370
typename = std::enable_if_t<sizeof...(T_results)
357371
== sizeof...(T_expressions)>>
358372
void operator-=(const expressions_cl<T_expressions...>& exprs) {
359-
index_apply<sizeof...(T_expressions)>([this, &exprs](auto... Is) {
360-
auto tmp = std::tuple_cat(make_assignment_pair<assignment_ops_cl::minus_equals>(
361-
std::get<Is>(results_), std::get<Is>(exprs.expressions_))...);
362-
index_apply<std::tuple_size<decltype(tmp)>::value>(
363-
[this, &tmp](auto... Is2) {
364-
assignment_impl(std::make_tuple(std::make_pair(
365-
std::get<Is2>(tmp).first, std::get<Is2>(tmp).second)...));
366-
});
367-
});
373+
compound_assignment_impl<assign_op_cl::minus_equals>(exprs);
374+
}
375+
376+
/**
377+
* Incrementing \c results_ object by \c expressions_cl object
378+
* executes the kernel that evaluates expressions and increments results by
379+
* those expressions.
380+
* @tparam T_expressions types of expressions
381+
* @param exprs expressions
382+
*/
383+
template <typename... T_expressions,
384+
typename = std::enable_if_t<sizeof...(T_results)
385+
== sizeof...(T_expressions)>>
386+
void operator/=(const expressions_cl<T_expressions...>& exprs) {
387+
compound_assignment_impl<assign_op_cl::divide_equals>(exprs);
388+
}
389+
390+
/**
391+
* Incrementing \c results_ object by \c expressions_cl object
392+
* executes the kernel that evaluates expressions and increments results by
393+
* those expressions.
394+
* @tparam T_expressions types of expressions
395+
* @param exprs expressions
396+
*/
397+
template <typename... T_expressions,
398+
typename = std::enable_if_t<sizeof...(T_results)
399+
== sizeof...(T_expressions)>>
400+
void operator*=(const expressions_cl<T_expressions...>& exprs) {
401+
compound_assignment_impl<assign_op_cl::multiply_equals>(exprs);
368402
}
369403

370404
/**
@@ -551,7 +585,7 @@ class results_cl {
551585
* @param expression expression
552586
* @return a tuple of pair of result and expression
553587
*/
554-
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
588+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_result, typename T_expression,
555589
require_all_not_t<is_without_output<T_expression>,
556590
conjunction<internal::is_scalar_check<T_result>,
557591
std::is_arithmetic<std::decay_t<
@@ -570,7 +604,7 @@ class results_cl {
570604
* @param expression expression
571605
* @return a tuple of pair of result and expression
572606
*/
573-
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
607+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_result, typename T_expression,
574608
require_t<is_without_output<T_expression>>* = nullptr>
575609
static auto make_assignment_pair(T_result&& result,
576610
T_expression&& expression) {
@@ -584,7 +618,7 @@ class results_cl {
584618
* @param pass bool scalar
585619
* @return an empty tuple
586620
*/
587-
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_check, typename T_pass,
621+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_check, typename T_pass,
588622
require_t<internal::is_scalar_check<T_check>>* = nullptr,
589623
require_integral_t<T_pass>* = nullptr>
590624
static std::tuple<> make_assignment_pair(T_check&& result, T_pass&& pass) {

stan/math/opencl/rev/adjoint_results.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class adjoint_results_cl : protected results_cl<T_results...> {
4141
index_apply<sizeof...(T_expressions)>([&](auto... Is) {
4242
auto scalars = std::tuple_cat(select_scalar_assignments(
4343
std::get<Is>(this->results_), std::get<Is>(exprs.expressions_))...);
44-
auto nonscalars_tmp = std::tuple_cat(select_nonscalar_plusequals(
44+
auto nonscalars_tmp = std::tuple_cat(select_nonscalar_assignments<assign_op_cl::plus_equals>(
4545
std::get<Is>(this->results_), std::get<Is>(exprs.expressions_))...);
4646

4747
index_apply<std::tuple_size<decltype(nonscalars_tmp)>::value>(
@@ -55,7 +55,7 @@ class adjoint_results_cl : protected results_cl<T_results...> {
5555
// evaluate all expressions
5656
this->assignment_impl(std::tuple_cat(
5757
nonscalars,
58-
this->template make_assignment_pair<assignment_ops_cl::plus_equals>(
58+
this->template make_assignment_pair<assign_op_cl::plus_equals>(
5959
std::get<2>(std::get<Is_scal>(scalars)),
6060
sum_2d(std::get<1>(std::get<Is_scal>(scalars))))...));
6161

@@ -109,12 +109,12 @@ class adjoint_results_cl : protected results_cl<T_results...> {
109109
* @return pair of result and expression or empty tuple (if the result is
110110
* check or the expression is `calc_if<false,T>`.
111111
*/
112-
template <typename T_result, typename T_expression,
112+
template <assign_op_cl AssignOp, typename T_result, typename T_expression,
113113
require_not_stan_scalar_t<T_result>* = nullptr,
114114
require_st_var<T_result>* = nullptr>
115-
auto select_nonscalar_plusequals(T_result&& result,
115+
auto select_nonscalar_assignments(T_result&& result,
116116
T_expression&& expression) {
117-
return results_cl<T_results...>::template make_assignment_pair<assignment_ops_cl::plus_equals>(
117+
return results_cl<T_results...>::template make_assignment_pair<AssignOp>(
118118
result.adj(), std::forward<T_expression>(expression));
119119
}
120120
/**
@@ -126,11 +126,11 @@ class adjoint_results_cl : protected results_cl<T_results...> {
126126
* @param expression expression
127127
* @return empty tuple
128128
*/
129-
template <
129+
template <assign_op_cl AssignOp,
130130
typename T_result, typename T_expression,
131131
std::enable_if_t<is_stan_scalar<T_result>::value
132132
|| !is_var<scalar_type_t<T_result>>::value>* = nullptr>
133-
auto select_nonscalar_plusequals(T_result&& result,
133+
auto select_nonscalar_assignments(T_result&& result,
134134
T_expression&& expression) {
135135
return std::make_tuple();
136136
}

test/unit/math/opencl/kernel_generator/assignment_ops_test.cpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,79 @@
88
#include <gtest/gtest.h>
99
#include <string>
1010

11-
TEST(KernelGenerator, assign_ops) {
11+
TEST(KernelGenerator, plus_equals) {
1212
using stan::math::matrix_cl;
1313
using stan::math::var_value;
1414
using stan::math::var;
1515
using stan::math::to_matrix_cl;
16+
using stan::math::from_matrix_cl;
1617
Eigen::MatrixXd A = Eigen::MatrixXd::Random(10, 10);
1718
Eigen::MatrixXd B = Eigen::MatrixXd::Random(10, 10);
1819
Eigen::MatrixXd C = Eigen::MatrixXd::Random(10, 10);
19-
C += A + B;
2020
matrix_cl<double> A_cl = to_matrix_cl(A);
2121
matrix_cl<double> B_cl = to_matrix_cl(B);
2222
matrix_cl<double> C_cl = to_matrix_cl(C);
23+
C += A + B;
2324
results(C_cl) += expressions(A_cl + B_cl);
25+
Eigen::MatrixXd C_cl_host = from_matrix_cl(C_cl);
26+
EXPECT_MATRIX_EQ(C_cl_host, C)
27+
}
28+
29+
TEST(KernelGenerator, minus_equals) {
30+
using stan::math::matrix_cl;
31+
using stan::math::var_value;
32+
using stan::math::var;
33+
using stan::math::to_matrix_cl;
34+
using stan::math::from_matrix_cl;
35+
Eigen::MatrixXd A = Eigen::MatrixXd::Random(10, 10);
36+
Eigen::MatrixXd B = Eigen::MatrixXd::Random(10, 10);
37+
Eigen::MatrixXd C = Eigen::MatrixXd::Random(10, 10);
38+
matrix_cl<double> A_cl = to_matrix_cl(A);
39+
matrix_cl<double> B_cl = to_matrix_cl(B);
40+
matrix_cl<double> C_cl = to_matrix_cl(C);
41+
C -= A + B;
42+
results(C_cl) -= expressions(A_cl + B_cl);
43+
Eigen::MatrixXd C_cl_host = from_matrix_cl(C_cl);
44+
EXPECT_MATRIX_EQ(C_cl_host, C)
45+
}
46+
47+
TEST(KernelGenerator, divide_equals) {
48+
using stan::math::matrix_cl;
49+
using stan::math::var_value;
50+
using stan::math::var;
51+
using stan::math::to_matrix_cl;
52+
using stan::math::from_matrix_cl;
53+
Eigen::MatrixXd A = Eigen::MatrixXd::Random(10, 10);
54+
Eigen::MatrixXd B = Eigen::MatrixXd::Random(10, 10);
55+
Eigen::MatrixXd C = Eigen::MatrixXd::Random(10, 10);
56+
matrix_cl<double> A_cl = to_matrix_cl(A);
57+
matrix_cl<double> B_cl = to_matrix_cl(B);
58+
matrix_cl<double> C_cl = to_matrix_cl(C);
59+
C.array() /= A.array() + B.array();
60+
results(C_cl) /= expressions(A_cl + B_cl);
61+
Eigen::MatrixXd C_cl_host = from_matrix_cl(C_cl);
62+
EXPECT_MATRIX_EQ(C_cl_host, C)
63+
}
64+
65+
TEST(KernelGenerator, times_equals) {
66+
using stan::math::matrix_cl;
67+
using stan::math::var_value;
68+
using stan::math::var;
69+
using stan::math::to_matrix_cl;
70+
using stan::math::from_matrix_cl;
71+
72+
Eigen::MatrixXd A = Eigen::MatrixXd::Random(10, 10);
73+
Eigen::MatrixXd B = Eigen::MatrixXd::Random(10, 10);
74+
Eigen::MatrixXd C = Eigen::MatrixXd::Random(10, 10);
75+
matrix_cl<double> A_cl = to_matrix_cl(A);
76+
matrix_cl<double> B_cl = to_matrix_cl(B);
77+
matrix_cl<double> C_cl = to_matrix_cl(C);
78+
C.array() *= A.array() + B.array();
79+
results(C_cl) *= expressions(A_cl + B_cl);
80+
Eigen::MatrixXd C_cl_host = from_matrix_cl(C_cl);
81+
EXPECT_MATRIX_EQ(C_cl_host, C)
2482
}
2583

2684

85+
2786
#endif

0 commit comments

Comments
 (0)