Skip to content

Commit be485f4

Browse files
authored
Merge pull request #2943 from stan-dev/fix/plusequals-assign-opencl
Fix Aliasing issue in OpenCL
2 parents eb3b5d7 + e027860 commit be485f4

13 files changed

Lines changed: 424 additions & 48 deletions

File tree

stan/math/opencl/kernel_generator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
109109
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
110110
#include <stan/math/opencl/kernel_generator/type_str.hpp>
111-
111+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
112112
#include <stan/math/opencl/kernel_generator/as_column_vector_or_scalar.hpp>
113113
#include <stan/math/opencl/kernel_generator/load.hpp>
114114
#include <stan/math/opencl/kernel_generator/scalar.hpp>

stan/math/opencl/kernel_generator/as_operation_cl.hpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
33
#ifdef STAN_OPENCL
44

5+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
56
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
67
#include <stan/math/opencl/kernel_generator/load.hpp>
78
#include <stan/math/opencl/kernel_generator/scalar.hpp>
@@ -19,11 +20,12 @@ namespace math {
1920
/**
2021
* Converts any valid kernel generator expression into an operation. This is an
2122
* overload for operations - a no-op
23+
* @tparam AssignOp ignored
2224
* @tparam T_operation type of the input operation
2325
* @param a an operation
2426
* @return operation
2527
*/
26-
template <typename T_operation,
28+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_operation,
2729
typename = std::enable_if_t<std::is_base_of<
2830
operation_cl_base, std::remove_reference_t<T_operation>>::value>>
2931
inline T_operation&& as_operation_cl(T_operation&& a) {
@@ -33,11 +35,13 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
3335
/**
3436
* Converts any valid kernel generator expression into an operation. This is an
3537
* overload for scalars (arithmetic types). It wraps them into \c scalar_.
38+
* @tparam AssignOp ignored
3639
* @tparam T_scalar type of the input scalar
3740
* @param a scalar
3841
* @return \c scalar_ wrapping the input
3942
*/
40-
template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
43+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_scalar,
44+
typename = require_arithmetic_t<T_scalar>,
4145
require_not_same_t<T_scalar, bool>* = nullptr>
4246
inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
4347
return scalar_<T_scalar>(a);
@@ -47,23 +51,29 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
4751
* Converts any valid kernel generator expression into an operation. This is an
4852
* overload for bool scalars. It wraps them into \c scalar_<char> as \c bool can
4953
* not be used as a type of a kernel argument.
54+
* @tparam AssignOp ignored
5055
* @param a scalar
5156
* @return \c scalar_<char> wrapping the input
5257
*/
53-
inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
58+
template <assign_op_cl AssignOp = assign_op_cl::equals>
59+
inline scalar_<char> as_operation_cl(const bool a) {
60+
return scalar_<char>(a);
61+
}
5462

5563
/**
5664
* Converts any valid kernel generator expression into an operation. This is an
5765
* overload for \c matrix_cl. It wraps them into into \c load_.
66+
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
67+
* is assigned using standard or compound assign.
5868
* @tparam T_matrix_cl \c matrix_cl
5969
* @param a \c matrix_cl
6070
* @return \c load_ wrapping the input
6171
*/
62-
template <typename T_matrix_cl,
72+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_matrix_cl,
6373
typename = require_any_t<is_matrix_cl<T_matrix_cl>,
6474
is_arena_matrix_cl<T_matrix_cl>>>
65-
inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
66-
return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
75+
inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
76+
return load_<T_matrix_cl, AssignOp>(std::forward<T_matrix_cl>(a));
6777
}
6878

6979
/**
@@ -73,12 +83,16 @@ inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
7383
* as_operation_cl_t<T>. If the return value of \c as_operation_cl() would be a
7484
* rvalue reference, the reference is removed, so that a variable of this type
7585
* actually stores the value.
86+
* @tparam T a `matrix_cl` or `Scalar` type
87+
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
88+
* is assigned using standard or compound assign.
7689
*/
77-
template <typename T>
78-
using as_operation_cl_t = std::conditional_t<
79-
std::is_lvalue_reference<T>::value,
80-
decltype(as_operation_cl(std::declval<T>())),
81-
std::remove_reference_t<decltype(as_operation_cl(std::declval<T>()))>>;
90+
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
91+
using as_operation_cl_t
92+
= std::conditional_t<std::is_lvalue_reference<T>::value,
93+
decltype(as_operation_cl<AssignOp>(std::declval<T>())),
94+
std::remove_reference_t<decltype(
95+
as_operation_cl<AssignOp>(std::declval<T>()))>>;
8296

8397
/** @}*/
8498
} // namespace math
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
2+
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
3+
#ifdef STAN_OPENCL
4+
#include <stan/math/prim/meta/is_detected.hpp>
5+
6+
namespace stan {
7+
namespace math {
8+
9+
/**
10+
* Ops that decide the type of assignment for LHS operations
11+
*/
12+
enum class assign_op_cl {
13+
equals,
14+
plus_equals,
15+
minus_equals,
16+
divide_equals,
17+
multiply_equals
18+
};
19+
20+
namespace internal {
21+
/**
22+
* @param value A static constexpr const char* member for printing assignment
23+
* ops
24+
*/
25+
template <assign_op_cl assign_op>
26+
struct assignment_op_str_impl;
27+
28+
template <>
29+
struct assignment_op_str_impl<assign_op_cl::equals> {
30+
static constexpr const char* value = " = ";
31+
};
32+
33+
template <>
34+
struct assignment_op_str_impl<assign_op_cl::plus_equals> {
35+
static constexpr const char* value = " += ";
36+
};
37+
38+
template <>
39+
struct assignment_op_str_impl<assign_op_cl::minus_equals> {
40+
static constexpr const char* value = " -= ";
41+
};
42+
43+
template <>
44+
struct assignment_op_str_impl<assign_op_cl::divide_equals> {
45+
static constexpr const char* value = " /= ";
46+
};
47+
48+
template <>
49+
struct assignment_op_str_impl<assign_op_cl::multiply_equals> {
50+
static constexpr const char* value = " *= ";
51+
};
52+
53+
template <typename, typename = void>
54+
struct assignment_op_str : assignment_op_str_impl<assign_op_cl::equals> {};
55+
56+
template <typename T>
57+
struct assignment_op_str<T, void_t<decltype(T::assignment_op)>>
58+
: assignment_op_str_impl<T::assignment_op> {};
59+
60+
} // namespace internal
61+
62+
/**
63+
* @tparam T A type that has an `assignment_op` static constexpr member type
64+
* @return The types assignment op as a constexpr const char*
65+
*/
66+
template <typename T>
67+
inline constexpr const char* assignment_op() noexcept {
68+
return internal::assignment_op_str<std::decay_t<T>>::value;
69+
}
70+
71+
} // namespace math
72+
} // namespace stan
73+
#endif
74+
#endif

stan/math/opencl/kernel_generator/load.hpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include <stan/math/opencl/matrix_cl.hpp>
66
#include <stan/math/opencl/matrix_cl_view.hpp>
7+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
8+
79
#include <stan/math/opencl/kernel_generator/type_str.hpp>
810
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
911
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
@@ -23,17 +25,20 @@ namespace math {
2325
/**
2426
* Represents an access to a \c matrix_cl in kernel generator expressions
2527
* @tparam T \c matrix_cl
28+
* @tparam AssignOp tells higher level operations whether the final operation
29+
* should be an assignment or a type of compound assignment.
2630
*/
27-
template <typename T>
31+
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
2832
class load_
29-
: public operation_cl_lhs<load_<T>,
33+
: public operation_cl_lhs<load_<T, AssignOp>,
3034
typename std::remove_reference_t<T>::type> {
3135
protected:
3236
T a_;
3337

3438
public:
39+
static constexpr assign_op_cl assignment_op = AssignOp;
3540
using Scalar = typename std::remove_reference_t<T>::type;
36-
using base = operation_cl<load_<T>, Scalar>;
41+
using base = operation_cl<load_<T, AssignOp>, Scalar>;
3742
using base::var_name_;
3843
static_assert(disjunction<is_matrix_cl<T>, is_arena_matrix_cl<T>>::value,
3944
"load_: argument a must be a matrix_cl<T>!");
@@ -51,9 +56,13 @@ class load_
5156
* Creates a deep copy of this expression.
5257
* @return copy of \c *this
5358
*/
54-
inline load_<T&> deep_copy() & { return load_<T&>(a_); }
55-
inline load_<const T&> deep_copy() const& { return load_<const T&>(a_); }
56-
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }
59+
inline load_<T&, AssignOp> deep_copy() & { return load_<T&, AssignOp>(a_); }
60+
inline load_<const T&, AssignOp> deep_copy() const& {
61+
return load_<const T&, AssignOp>(a_);
62+
}
63+
inline load_<T, AssignOp> deep_copy() && {
64+
return load_<T, AssignOp>(std::forward<T>(a_));
65+
}
5766

5867
/**
5968
* Generates kernel code for this expression.
@@ -327,6 +336,7 @@ class load_
327336
}
328337
}
329338
};
339+
330340
/** @}*/
331341
} // namespace math
332342
} // namespace stan

0 commit comments

Comments
 (0)