22#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
33#ifdef STAN_OPENCL
44
5+ #include < stan/math/opencl/kernel_generator/assignment_ops.hpp>
56#include < stan/math/opencl/kernel_generator/operation_cl.hpp>
67#include < stan/math/opencl/kernel_generator/load.hpp>
78#include < stan/math/opencl/kernel_generator/scalar.hpp>
@@ -19,11 +20,12 @@ namespace math {
1920/* *
2021 * Converts any valid kernel generator expression into an operation. This is an
2122 * overload for operations - a no-op
23+ * @tparam AssignOp ignored
2224 * @tparam T_operation type of the input operation
2325 * @param a an operation
2426 * @return operation
2527 */
26- template <typename T_operation,
28+ template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_operation,
2729 typename = std::enable_if_t <std::is_base_of<
2830 operation_cl_base, std::remove_reference_t <T_operation>>::value>>
2931inline T_operation&& as_operation_cl(T_operation&& a) {
@@ -33,11 +35,13 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
3335/* *
3436 * Converts any valid kernel generator expression into an operation. This is an
3537 * overload for scalars (arithmetic types). It wraps them into \c scalar_.
38+ * @tparam AssignOp ignored
3639 * @tparam T_scalar type of the input scalar
3740 * @param a scalar
3841 * @return \c scalar_ wrapping the input
3942 */
40- template <typename T_scalar, typename = require_arithmetic_t <T_scalar>,
43+ template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_scalar,
44+ typename = require_arithmetic_t <T_scalar>,
4145 require_not_same_t <T_scalar, bool >* = nullptr >
4246inline scalar_<T_scalar> as_operation_cl (const T_scalar a) {
4347 return scalar_<T_scalar>(a);
@@ -47,23 +51,29 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
4751 * Converts any valid kernel generator expression into an operation. This is an
4852 * overload for bool scalars. It wraps them into \c scalar_<char> as \c bool can
4953 * not be used as a type of a kernel argument.
54+ * @tparam AssignOp ignored
5055 * @param a scalar
5156 * @return \c scalar_<char> wrapping the input
5257 */
53- inline scalar_<char > as_operation_cl (const bool a) { return scalar_<char >(a); }
58+ template <assign_op_cl AssignOp = assign_op_cl::equals>
59+ inline scalar_<char > as_operation_cl (const bool a) {
60+ return scalar_<char >(a);
61+ }
5462
5563/* *
5664 * Converts any valid kernel generator expression into an operation. This is an
5765 * overload for \c matrix_cl. It wraps them into into \c load_.
66+ * @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
67+ * is assigned using standard or compound assign.
5868 * @tparam T_matrix_cl \c matrix_cl
5969 * @param a \c matrix_cl
6070 * @return \c load_ wrapping the input
6171 */
62- template <typename T_matrix_cl,
72+ template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_matrix_cl,
6373 typename = require_any_t <is_matrix_cl<T_matrix_cl>,
6474 is_arena_matrix_cl<T_matrix_cl>>>
65- inline load_<T_matrix_cl> as_operation_cl (T_matrix_cl&& a) {
66- return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
75+ inline load_<T_matrix_cl, AssignOp > as_operation_cl (T_matrix_cl&& a) {
76+ return load_<T_matrix_cl, AssignOp >(std::forward<T_matrix_cl>(a));
6777}
6878
6979/* *
@@ -73,12 +83,16 @@ inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
7383 * as_operation_cl_t<T>. If the return value of \c as_operation_cl() would be a
7484 * rvalue reference, the reference is removed, so that a variable of this type
7585 * actually stores the value.
86+ * @tparam T a `matrix_cl` or `Scalar` type
87+ * @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
88+ * is assigned using standard or compound assign.
7689 */
77- template <typename T>
78- using as_operation_cl_t = std::conditional_t <
79- std::is_lvalue_reference<T>::value,
80- decltype (as_operation_cl(std::declval<T>())),
81- std::remove_reference_t <decltype (as_operation_cl(std::declval<T>()))>>;
90+ template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
91+ using as_operation_cl_t
92+ = std::conditional_t <std::is_lvalue_reference<T>::value,
93+ decltype (as_operation_cl<AssignOp>(std::declval<T>())),
94+ std::remove_reference_t <decltype (
95+ as_operation_cl<AssignOp>(std::declval<T>()))>>;
8296
8397/* * @}*/
8498} // namespace math
0 commit comments