55#include < stan/math/prim/err/check_size_match.hpp>
66#include < stan/math/prim/meta/is_kernel_expression.hpp>
77#include < stan/math/opencl/kernel_generator/name_generator.hpp>
8+ #include < stan/math/opencl/kernel_generator/assignment_ops.hpp>
89#include < stan/math/opencl/kernel_generator/as_operation_cl.hpp>
910#include < stan/math/opencl/kernel_generator/calc_if.hpp>
1011#include < stan/math/opencl/kernel_generator/check_cl.hpp>
@@ -334,13 +335,34 @@ class results_cl {
334335 == sizeof ...(T_expressions)>>
335336 void operator +=(const expressions_cl<T_expressions...>& exprs) {
336337 index_apply<sizeof ...(T_expressions)>([this , &exprs](auto ... Is) {
337- auto tmp = std::tuple_cat (make_assignment_pair (
338+ auto tmp = std::tuple_cat (make_assignment_pair<assignment_ops_cl::plus_equals> (
338339 std::get<Is>(results_), std::get<Is>(exprs.expressions_ ))...);
339340 index_apply<std::tuple_size<decltype (tmp)>::value>(
340341 [this , &tmp](auto ... Is2) {
341342 assignment_impl (std::make_tuple (std::make_pair (
342- std::get<Is2>(tmp).first ,
343- std::get<Is2>(tmp).first + std::get<Is2>(tmp).second )...));
343+ std::get<Is2>(tmp).first , std::get<Is2>(tmp).second )...));
344+ });
345+ });
346+ }
347+
348+ /* *
349+ * Incrementing \c results_ object by \c expressions_cl object
350+ * executes the kernel that evaluates expressions and increments results by
351+ * those expressions.
352+ * @tparam T_expressions types of expressions
353+ * @param exprs expressions
354+ */
355+ template <typename ... T_expressions,
356+ typename = std::enable_if_t <sizeof ...(T_results)
357+ == sizeof ...(T_expressions)>>
358+ void operator -=(const expressions_cl<T_expressions...>& exprs) {
359+ index_apply<sizeof ...(T_expressions)>([this , &exprs](auto ... Is) {
360+ auto tmp = std::tuple_cat (make_assignment_pair<assignment_ops_cl::minus_equals>(
361+ std::get<Is>(results_), std::get<Is>(exprs.expressions_ ))...);
362+ index_apply<std::tuple_size<decltype (tmp)>::value>(
363+ [this , &tmp](auto ... Is2) {
364+ assignment_impl (std::make_tuple (std::make_pair (
365+ std::get<Is2>(tmp).first , std::get<Is2>(tmp).second )...));
344366 });
345367 });
346368 }
@@ -426,7 +448,7 @@ class results_cl {
426448 + parts.reduction_2d +
427449 " }\n " ;
428450 }
429- return src;
451+ return src;
430452 }
431453
432454 /* *
@@ -529,16 +551,16 @@ class results_cl {
529551 * @param expression expression
530552 * @return a tuple of pair of result and expression
531553 */
532- template <typename T_result, typename T_expression,
554+ template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
533555 require_all_not_t <is_without_output<T_expression>,
534556 conjunction<internal::is_scalar_check<T_result>,
535557 std::is_arithmetic<std::decay_t <
536558 T_expression>>>>* = nullptr >
537559 static auto make_assignment_pair (T_result&& result,
538560 T_expression&& expression) {
539561 return std::make_tuple (
540- std::pair<as_operation_cl_t <T_result>, as_operation_cl_t <T_expression>>(
541- as_operation_cl (std::forward<T_result>(result)),
562+ std::pair<as_operation_cl_t <T_result, AssignOp >, as_operation_cl_t <T_expression>>(
563+ as_operation_cl<AssignOp> (std::forward<T_result>(result)),
542564 as_operation_cl (std::forward<T_expression>(expression))));
543565 }
544566
@@ -548,7 +570,7 @@ class results_cl {
548570 * @param expression expression
549571 * @return a tuple of pair of result and expression
550572 */
551- template <typename T_result, typename T_expression,
573+ template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
552574 require_t <is_without_output<T_expression>>* = nullptr >
553575 static auto make_assignment_pair (T_result&& result,
554576 T_expression&& expression) {
@@ -562,7 +584,7 @@ class results_cl {
562584 * @param pass bool scalar
563585 * @return an empty tuple
564586 */
565- template <typename T_check, typename T_pass,
587+ template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_check, typename T_pass,
566588 require_t <internal::is_scalar_check<T_check>>* = nullptr ,
567589 require_integral_t <T_pass>* = nullptr >
568590 static std::tuple<> make_assignment_pair (T_check&& result, T_pass&& pass) {
0 commit comments