Skip to content

Commit fdb6d34

Browse files
committed
Fixes adjoint accumulation for reverse mode where aliasing can occur. Creates an assignment op tag that is used by adjoint_results to do a += instead of a = into the adjoint matrix
1 parent 5091bf9 commit fdb6d34

13 files changed

Lines changed: 275 additions & 104 deletions

File tree

stan/math/opencl/kernel_generator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
109109
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
110110
#include <stan/math/opencl/kernel_generator/type_str.hpp>
111-
111+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
112112
#include <stan/math/opencl/kernel_generator/as_column_vector_or_scalar.hpp>
113113
#include <stan/math/opencl/kernel_generator/load.hpp>
114114
#include <stan/math/opencl/kernel_generator/scalar.hpp>

stan/math/opencl/kernel_generator/as_operation_cl.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
33
#ifdef STAN_OPENCL
44

5+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
56
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
67
#include <stan/math/opencl/kernel_generator/load.hpp>
78
#include <stan/math/opencl/kernel_generator/scalar.hpp>
@@ -23,7 +24,7 @@ namespace math {
2324
* @param a an operation
2425
* @return operation
2526
*/
26-
template <typename T_operation,
27+
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_operation,
2728
typename = std::enable_if_t<std::is_base_of<
2829
operation_cl_base, std::remove_reference_t<T_operation>>::value>>
2930
inline T_operation&& as_operation_cl(T_operation&& a) {
@@ -37,7 +38,7 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
3738
* @param a scalar
3839
* @return \c scalar_ wrapping the input
3940
*/
40-
template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
41+
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_scalar, typename = require_arithmetic_t<T_scalar>,
4142
require_not_same_t<T_scalar, bool>* = nullptr>
4243
inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
4344
return scalar_<T_scalar>(a);
@@ -50,6 +51,7 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
5051
* @param a scalar
5152
* @return \c scalar_<char> wrapping the input
5253
*/
54+
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals>
5355
inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
5456

5557
/**
@@ -59,11 +61,11 @@ inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
5961
* @param a \c matrix_cl
6062
* @return \c load_ wrapping the input
6163
*/
62-
template <typename T_matrix_cl,
64+
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_matrix_cl,
6365
typename = require_any_t<is_matrix_cl<T_matrix_cl>,
6466
is_arena_matrix_cl<T_matrix_cl>>>
65-
inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
66-
return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
67+
inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
68+
return load_<T_matrix_cl, AssignOp>(std::forward<T_matrix_cl>(a));
6769
}
6870

6971
/**
@@ -74,11 +76,11 @@ inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
7476
* rvalue reference, the reference is removed, so that a variable of this type
7577
* actually stores the value.
7678
*/
77-
template <typename T>
79+
template <typename T, assignment_ops_cl AssignOp = assignment_ops_cl::equals>
7880
using as_operation_cl_t = std::conditional_t<
7981
std::is_lvalue_reference<T>::value,
80-
decltype(as_operation_cl(std::declval<T>())),
81-
std::remove_reference_t<decltype(as_operation_cl(std::declval<T>()))>>;
82+
decltype(as_operation_cl<AssignOp>(std::declval<T>())),
83+
std::remove_reference_t<decltype(as_operation_cl<AssignOp>(std::declval<T>()))>>;
8284

8385
/** @}*/
8486
} // namespace math
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
2+
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
3+
#ifdef STAN_OPENCL
4+
#include <stan/math/prim/meta/is_detected.hpp>
5+
6+
namespace stan {
7+
namespace math {
8+
9+
/**
10+
* Ops that decide the type of assignment for LHS operations
11+
*/
12+
enum class assignment_ops_cl {equals, plus_equals, minus_equals, divide_equals};
13+
14+
/**
15+
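* Maps an assignment op to its textual form. Each specialization exposes a
* static constexpr const char* member \c value for printing the assignment op.
* @tparam assign_op the assignment op to print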
16+
*/
17+
template <assignment_ops_cl assign_op>
18+
struct assignment_op_str;
19+
20+
template <>
21+
struct assignment_op_str<assignment_ops_cl::equals> {
22+
static constexpr const char* value = " = ";
23+
};
24+
25+
template <>
26+
struct assignment_op_str<assignment_ops_cl::plus_equals> {
27+
static constexpr const char* value = " += ";
28+
};
29+
30+
template <>
31+
struct assignment_op_str<assignment_ops_cl::minus_equals> {
32+
// Text of the compound subtraction operator; results_cl::operator-= selects this tag.
static constexpr const char* value = " -= ";
33+
};
34+
35+
template <>
36+
struct assignment_op_str<assignment_ops_cl::divide_equals> {
37+
static constexpr const char* value = " /= ";
38+
};
39+
40+
41+
namespace internal {
42+
template <typename, typename = void>
43+
struct has_assignment_op_str : std::false_type {};
44+
45+
template <typename T>
46+
struct has_assignment_op_str<T, void_t<decltype(T::assignment_op)>> : std::true_type {};
47+
48+
} // namespace internal
49+
50+
/**
51+
* @tparam T A type that does not have an `assignment_op` static constexpr member
52+
* @return A constexpr const char* equal to `" = "`
53+
*/
54+
template <typename T, std::enable_if_t<!internal::has_assignment_op_str<std::decay_t<T>>::value>* = nullptr>
55+
inline constexpr const char* assignment_op() noexcept {
56+
return " = ";
57+
}
58+
59+
/**
60+
* @tparam T A type that has an `assignment_op` static constexpr member
61+
* @return The type's assignment op as a constexpr const char*
62+
*/
63+
// Detect the tag on the decayed type so that reference and cv-qualified T select
// this overload; the fallback overload and the body below already decay T.
template <typename T, std::enable_if_t<internal::has_assignment_op_str<std::decay_t<T>>::value>* = nullptr>
64+
inline constexpr const char* assignment_op() noexcept {
65+
return assignment_op_str<std::decay_t<T>::assignment_op>::value;
66+
}
67+
68+
}
69+
}
70+
#endif
71+
#endif

stan/math/opencl/kernel_generator/load.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include <stan/math/opencl/matrix_cl.hpp>
66
#include <stan/math/opencl/matrix_cl_view.hpp>
7+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
8+
79
#include <stan/math/opencl/kernel_generator/type_str.hpp>
810
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
911
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
@@ -23,17 +25,20 @@ namespace math {
2325
/**
2426
* Represents an access to a \c matrix_cl in kernel generator expressions
2527
* @tparam T \c matrix_cl
28+
* @tparam AssignOp tells higher level operations whether the final operation should be an assignment or a type of compound assignment.
2629
*/
27-
template <typename T>
30+
template <typename T, assignment_ops_cl AssignOp = assignment_ops_cl::equals>
2831
class load_
29-
: public operation_cl_lhs<load_<T>,
32+
: public operation_cl_lhs<load_<T, AssignOp>,
3033
typename std::remove_reference_t<T>::type> {
3134
protected:
3235
T a_;
3336

3437
public:
38+
39+
static constexpr assignment_ops_cl assignment_op = AssignOp;
3540
using Scalar = typename std::remove_reference_t<T>::type;
36-
using base = operation_cl<load_<T>, Scalar>;
41+
using base = operation_cl<load_<T, AssignOp>, Scalar>;
3742
using base::var_name_;
3843
static_assert(disjunction<is_matrix_cl<T>, is_arena_matrix_cl<T>>::value,
3944
"load_: argument a must be a matrix_cl<T>!");
@@ -51,9 +56,9 @@ class load_
5156
* Creates a deep copy of this expression.
5257
* @return copy of \c *this
5358
*/
54-
inline load_<T&> deep_copy() & { return load_<T&>(a_); }
55-
inline load_<const T&> deep_copy() const& { return load_<const T&>(a_); }
56-
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }
59+
inline load_<T&, AssignOp> deep_copy() & { return load_<T&, AssignOp>(a_); }
60+
inline load_<const T&, AssignOp> deep_copy() const& { return load_<const T&, AssignOp>(a_); }
61+
inline load_<T, AssignOp> deep_copy() && { return load_<T, AssignOp>(std::forward<T>(a_)); }
5762

5863
/**
5964
* Generates kernel code for this expression.
@@ -327,6 +332,7 @@ class load_
327332
}
328333
}
329334
};
335+
330336
/** @}*/
331337
} // namespace math
332338
} // namespace stan

stan/math/opencl/kernel_generator/multi_result_kernel.hpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/prim/err/check_size_match.hpp>
66
#include <stan/math/prim/meta/is_kernel_expression.hpp>
77
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
8+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
89
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
910
#include <stan/math/opencl/kernel_generator/calc_if.hpp>
1011
#include <stan/math/opencl/kernel_generator/check_cl.hpp>
@@ -334,13 +335,34 @@ class results_cl {
334335
== sizeof...(T_expressions)>>
335336
void operator+=(const expressions_cl<T_expressions...>& exprs) {
336337
index_apply<sizeof...(T_expressions)>([this, &exprs](auto... Is) {
337-
auto tmp = std::tuple_cat(make_assignment_pair(
338+
auto tmp = std::tuple_cat(make_assignment_pair<assignment_ops_cl::plus_equals>(
338339
std::get<Is>(results_), std::get<Is>(exprs.expressions_))...);
339340
index_apply<std::tuple_size<decltype(tmp)>::value>(
340341
[this, &tmp](auto... Is2) {
341342
assignment_impl(std::make_tuple(std::make_pair(
342-
std::get<Is2>(tmp).first,
343-
std::get<Is2>(tmp).first + std::get<Is2>(tmp).second)...));
343+
std::get<Is2>(tmp).first, std::get<Is2>(tmp).second)...));
344+
});
345+
});
346+
}
347+
348+
/**
349+
* Decrementing \c results_ object by \c expressions_cl object
350+
* executes the kernel that evaluates expressions and decrements results by
351+
* those expressions.
352+
* @tparam T_expressions types of expressions
353+
* @param exprs expressions
354+
*/
355+
template <typename... T_expressions,
356+
typename = std::enable_if_t<sizeof...(T_results)
357+
== sizeof...(T_expressions)>>
358+
void operator-=(const expressions_cl<T_expressions...>& exprs) {
359+
index_apply<sizeof...(T_expressions)>([this, &exprs](auto... Is) {
360+
auto tmp = std::tuple_cat(make_assignment_pair<assignment_ops_cl::minus_equals>(
361+
std::get<Is>(results_), std::get<Is>(exprs.expressions_))...);
362+
index_apply<std::tuple_size<decltype(tmp)>::value>(
363+
[this, &tmp](auto... Is2) {
364+
assignment_impl(std::make_tuple(std::make_pair(
365+
std::get<Is2>(tmp).first, std::get<Is2>(tmp).second)...));
344366
});
345367
});
346368
}
@@ -426,7 +448,7 @@ class results_cl {
426448
+ parts.reduction_2d +
427449
"}\n";
428450
}
429-
return src;
451+
return src;
430452
}
431453

432454
/**
@@ -529,16 +551,16 @@ class results_cl {
529551
* @param expression expression
530552
* @return a tuple of pair of result and expression
531553
*/
532-
template <typename T_result, typename T_expression,
554+
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
533555
require_all_not_t<is_without_output<T_expression>,
534556
conjunction<internal::is_scalar_check<T_result>,
535557
std::is_arithmetic<std::decay_t<
536558
T_expression>>>>* = nullptr>
537559
static auto make_assignment_pair(T_result&& result,
538560
T_expression&& expression) {
539561
return std::make_tuple(
540-
std::pair<as_operation_cl_t<T_result>, as_operation_cl_t<T_expression>>(
541-
as_operation_cl(std::forward<T_result>(result)),
562+
std::pair<as_operation_cl_t<T_result, AssignOp>, as_operation_cl_t<T_expression>>(
563+
as_operation_cl<AssignOp>(std::forward<T_result>(result)),
542564
as_operation_cl(std::forward<T_expression>(expression))));
543565
}
544566

@@ -548,7 +570,7 @@ class results_cl {
548570
* @param expression expression
549571
* @return a tuple of pair of result and expression
550572
*/
551-
template <typename T_result, typename T_expression,
573+
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
552574
require_t<is_without_output<T_expression>>* = nullptr>
553575
static auto make_assignment_pair(T_result&& result,
554576
T_expression&& expression) {
@@ -562,7 +584,7 @@ class results_cl {
562584
* @param pass bool scalar
563585
* @return an empty tuple
564586
*/
565-
template <typename T_check, typename T_pass,
587+
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_check, typename T_pass,
566588
require_t<internal::is_scalar_check<T_check>>* = nullptr,
567589
require_integral_t<T_pass>* = nullptr>
568590
static std::tuple<> make_assignment_pair(T_check&& result, T_pass&& pass) {

stan/math/opencl/kernel_generator/operation_cl.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <stan/math/prim/meta.hpp>
66
#include <stan/math/prim/err/check_nonnegative.hpp>
7+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
78
#include <stan/math/opencl/kernel_generator/type_str.hpp>
89
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
910
#include <stan/math/opencl/matrix_cl_view.hpp>
@@ -74,6 +75,24 @@ struct kernel_parts {
7475
}
7576
};
7677

78+
std::ostream& operator<<(std::ostream& os, kernel_parts& parts) {
79+
os << "args:" << std::endl;
80+
os << parts.args.substr(0, parts.args.size() - 2) << std::endl;
81+
os << "Decl:" << std::endl;
82+
os << parts.declarations << std::endl;
83+
os << "Init:" << std::endl;
84+
os << parts.initialization << std::endl;
85+
os << "body:" << std::endl;
86+
os << parts.body << std::endl;
87+
os << "body_suffix:" << std::endl;
88+
os << parts.body_suffix << std::endl;
89+
os << "reduction_1d:" << std::endl;
90+
os << parts.reduction_1d << std::endl;
91+
os << "reduction_2d:" << std::endl;
92+
os << parts.reduction_2d << std::endl;
93+
return os;
94+
}
95+
7796
/**
7897
* Base for all kernel generator operations.
7998
* @tparam Derived derived type
@@ -201,7 +220,7 @@ class operation_cl : public operation_cl_base {
201220
generated, generated_all, ng, row_index_name, col_index_name, false);
202221
kernel_parts out_parts = result.get_kernel_parts_lhs(
203222
generated, generated_all, ng, row_index_name, col_index_name);
204-
out_parts.body += " = " + derived().var_name_ + ";\n";
223+
out_parts.body += assignment_op<T_result>() + derived().var_name_ + ";\n";
205224
parts += out_parts;
206225
return parts;
207226
}

stan/math/opencl/prim/normal_lccdf.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,12 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> normal_lccdf(
8282
matrix_cl<double> mu_deriv_cl;
8383
matrix_cl<double> sigma_deriv_cl;
8484

85-
results(check_y_not_nan, check_mu_finite, check_sigma_positive, lccdf_cl,
86-
y_deriv_cl, mu_deriv_cl, sigma_deriv_cl)
87-
= expressions(y_not_nan_expr, mu_finite_expr, sigma_positive_expr,
88-
lccdf_expr, calc_if<!is_constant<T_y_cl>::value>(y_deriv),
85+
results(check_y_not_nan, check_mu_finite, check_sigma_positive)
86+
= expressions(y_not_nan_expr, mu_finite_expr, sigma_positive_expr);
87+
results(lccdf_cl, y_deriv_cl, mu_deriv_cl, sigma_deriv_cl)
88+
= expressions(lccdf_expr, calc_if<!is_constant<T_y_cl>::value>(y_deriv),
8989
calc_if<!is_constant<T_loc_cl>::value>(mu_deriv),
9090
calc_if<!is_constant<T_scale_cl>::value>(sigma_deriv));
91-
9291
T_partials_return lccdf = LOG_HALF + sum(from_matrix_cl(lccdf_cl));
9392

9493
auto ops_partials = make_partials_propagator(y_col, mu_col, sigma_col);

stan/math/opencl/rev.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include <stan/math/opencl/rev/fmax.hpp>
5151
#include <stan/math/opencl/rev/fmin.hpp>
5252
#include <stan/math/opencl/rev/fmod.hpp>
53+
#include <stan/math/opencl/rev/grad.hpp>
5354
#include <stan/math/opencl/rev/hypot.hpp>
5455
#include <stan/math/opencl/rev/inv.hpp>
5556
#include <stan/math/opencl/rev/inv_cloglog.hpp>

0 commit comments

Comments
 (0)