Skip to content

Commit ad303aa

Browse files
committed
Merge branch 'develop' into fvar-support
2 parents f8d886d + 598dba7 commit ad303aa

18 files changed

Lines changed: 451 additions & 49 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ lib/tbb
7171

7272
# local make include
7373
/make/local
74+
/make/ucrt
7475

7576
# python byte code
7677
*.pyc

make/compiler_flags

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,21 @@ ifeq ($(OS),Windows_NT)
161161
CXXFLAGS_OS ?= -m64
162162
endif
163163

164+
make/ucrt:
165+
pound := \#
166+
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | findstr _UCRT)
167+
ifneq (,$(UCRT_STRING))
168+
IS_UCRT ?= true
169+
else
170+
IS_UCRT ?= false
171+
endif
172+
$(shell echo "IS_UCRT ?= $(IS_UCRT)" > $(MATH)make/ucrt)
173+
174+
include make/ucrt
175+
ifeq ($(IS_UCRT),true)
176+
CXXFLAGS_OS += -D_UCRT
177+
endif
178+
164179
ifneq (gcc,$(CXX_TYPE))
165180
LDLIBS_OS ?= -static-libgcc
166181
else

make/libraries

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ ifeq (Linux, $(OS))
139139
SHELL = /usr/bin/env bash
140140
endif
141141

142+
ifeq (Windows_NT, $(OS))
143+
ifeq ($(IS_UCRT),true)
144+
TBB_CXXFLAGS += -D_UCRT
145+
endif
146+
endif
147+
142148
# If brackets or spaces are found in MAKE on Windows
143149
# we error, as those characters cause issues when building.
144150
ifeq (Windows_NT, $(OS))

make/tests

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,15 @@ HEADER_TESTS := $(addsuffix -test,$(call findfiles,stan,*.hpp))
101101

102102
ifeq ($(OS),Windows_NT)
103103
DEV_NULL = nul
104+
ifeq ($(IS_UCRT),true)
105+
UCRT_NULL_FLAG = -S
106+
endif
104107
else
105108
DEV_NULL = /dev/null
106109
endif
107110

108111
%.hpp-test : %.hpp test/dummy.cpp
109-
$(COMPILE.cpp) $(CXXFLAGS) -O0 -include $^ -o $(DEV_NULL) -Wunused-local-typedefs
112+
$(COMPILE.cpp) $(CXXFLAGS) -O0 -include $^ $(UCRT_NULL_FLAG) -o $(DEV_NULL) -Wunused-local-typedefs
110113

111114
test/dummy.cpp:
112115
@mkdir -p test

makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ clean-deps:
125125
@$(RM) $(call findfiles,test,*.d.*)
126126
@$(RM) $(call findfiles,lib,*.d.*)
127127
@$(RM) $(call findfiles,stan,*.dSYM)
128+
@$(RM) $(call findfiles,make,ucrt)
128129

129130
clean-all: clean clean-doxygen clean-deps clean-libraries
130131

stan/math/opencl/kernel_generator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
109109
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
110110
#include <stan/math/opencl/kernel_generator/type_str.hpp>
111-
111+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
112112
#include <stan/math/opencl/kernel_generator/as_column_vector_or_scalar.hpp>
113113
#include <stan/math/opencl/kernel_generator/load.hpp>
114114
#include <stan/math/opencl/kernel_generator/scalar.hpp>

stan/math/opencl/kernel_generator/as_operation_cl.hpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
33
#ifdef STAN_OPENCL
44

5+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
56
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
67
#include <stan/math/opencl/kernel_generator/load.hpp>
78
#include <stan/math/opencl/kernel_generator/scalar.hpp>
@@ -19,11 +20,12 @@ namespace math {
1920
/**
2021
* Converts any valid kernel generator expression into an operation. This is an
2122
* overload for operations - a no-op
23+
* @tparam AssignOp ignored
2224
* @tparam T_operation type of the input operation
2325
* @param a an operation
2426
* @return operation
2527
*/
26-
template <typename T_operation,
28+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_operation,
2729
typename = std::enable_if_t<std::is_base_of<
2830
operation_cl_base, std::remove_reference_t<T_operation>>::value>>
2931
inline T_operation&& as_operation_cl(T_operation&& a) {
@@ -33,11 +35,13 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
3335
/**
3436
* Converts any valid kernel generator expression into an operation. This is an
3537
* overload for scalars (arithmetic types). It wraps them into \c scalar_.
38+
* @tparam AssignOp ignored
3639
* @tparam T_scalar type of the input scalar
3740
* @param a scalar
3841
* @return \c scalar_ wrapping the input
3942
*/
40-
template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
43+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_scalar,
44+
typename = require_arithmetic_t<T_scalar>,
4145
require_not_same_t<T_scalar, bool>* = nullptr>
4246
inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
4347
return scalar_<T_scalar>(a);
@@ -47,23 +51,29 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
4751
* Converts any valid kernel generator expression into an operation. This is an
4852
* overload for bool scalars. It wraps them into \c scalar_<char> as \c bool can
4953
* not be used as a type of a kernel argument.
54+
* @tparam AssignOp ignored
5055
* @param a scalar
5156
* @return \c scalar_<char> wrapping the input
5257
*/
53-
inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
58+
template <assign_op_cl AssignOp = assign_op_cl::equals>
59+
inline scalar_<char> as_operation_cl(const bool a) {
60+
return scalar_<char>(a);
61+
}
5462

5563
/**
5664
* Converts any valid kernel generator expression into an operation. This is an
5765
* overload for \c matrix_cl. It wraps them into \c load_.
66+
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
67+
* is assigned using standard or compound assign.
5868
* @tparam T_matrix_cl \c matrix_cl
5969
* @param a \c matrix_cl
6070
* @return \c load_ wrapping the input
6171
*/
62-
template <typename T_matrix_cl,
72+
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_matrix_cl,
6373
typename = require_any_t<is_matrix_cl<T_matrix_cl>,
6474
is_arena_matrix_cl<T_matrix_cl>>>
65-
inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
66-
return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
75+
inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
76+
return load_<T_matrix_cl, AssignOp>(std::forward<T_matrix_cl>(a));
6777
}
6878

6979
/**
@@ -73,12 +83,16 @@ inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
7383
* as_operation_cl_t<T>. If the return value of \c as_operation_cl() would be a
7484
* rvalue reference, the reference is removed, so that a variable of this type
7585
* actually stores the value.
86+
* @tparam T a `matrix_cl` or `Scalar` type
87+
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
88+
* is assigned using standard or compound assign.
7689
*/
77-
template <typename T>
78-
using as_operation_cl_t = std::conditional_t<
79-
std::is_lvalue_reference<T>::value,
80-
decltype(as_operation_cl(std::declval<T>())),
81-
std::remove_reference_t<decltype(as_operation_cl(std::declval<T>()))>>;
90+
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
91+
using as_operation_cl_t
92+
= std::conditional_t<std::is_lvalue_reference<T>::value,
93+
decltype(as_operation_cl<AssignOp>(std::declval<T>())),
94+
std::remove_reference_t<decltype(
95+
as_operation_cl<AssignOp>(std::declval<T>()))>>;
8296

8397
/** @}*/
8498
} // namespace math
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
2+
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
3+
#ifdef STAN_OPENCL
4+
#include <stan/math/prim/meta/is_detected.hpp>
5+
6+
namespace stan {
7+
namespace math {
8+
9+
/**
10+
* Ops that decide the type of assignment for LHS operations
11+
*/
12+
enum class assign_op_cl {
13+
equals,
14+
plus_equals,
15+
minus_equals,
16+
divide_equals,
17+
multiply_equals
18+
};
19+
20+
namespace internal {
21+
/**
22+
* @param value A static constexpr const char* member for printing assignment
23+
* ops
24+
*/
25+
template <assign_op_cl assign_op>
26+
struct assignment_op_str_impl;
27+
28+
template <>
29+
struct assignment_op_str_impl<assign_op_cl::equals> {
30+
static constexpr const char* value = " = ";
31+
};
32+
33+
template <>
34+
struct assignment_op_str_impl<assign_op_cl::plus_equals> {
35+
static constexpr const char* value = " += ";
36+
};
37+
38+
template <>
39+
struct assignment_op_str_impl<assign_op_cl::minus_equals> {
40+
static constexpr const char* value = " -= ";
41+
};
42+
43+
template <>
44+
struct assignment_op_str_impl<assign_op_cl::divide_equals> {
45+
static constexpr const char* value = " /= ";
46+
};
47+
48+
template <>
49+
struct assignment_op_str_impl<assign_op_cl::multiply_equals> {
50+
static constexpr const char* value = " *= ";
51+
};
52+
53+
template <typename, typename = void>
54+
struct assignment_op_str : assignment_op_str_impl<assign_op_cl::equals> {};
55+
56+
template <typename T>
57+
struct assignment_op_str<T, void_t<decltype(T::assignment_op)>>
58+
: assignment_op_str_impl<T::assignment_op> {};
59+
60+
} // namespace internal
61+
62+
/**
63+
* @tparam T A type that has an `assignment_op` static constexpr data member
64+
* @return The type's assignment op as a constexpr const char*
65+
*/
66+
template <typename T>
67+
inline constexpr const char* assignment_op() noexcept {
68+
return internal::assignment_op_str<std::decay_t<T>>::value;
69+
}
70+
71+
} // namespace math
72+
} // namespace stan
73+
#endif
74+
#endif

stan/math/opencl/kernel_generator/load.hpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include <stan/math/opencl/matrix_cl.hpp>
66
#include <stan/math/opencl/matrix_cl_view.hpp>
7+
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
8+
79
#include <stan/math/opencl/kernel_generator/type_str.hpp>
810
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
911
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
@@ -23,17 +25,20 @@ namespace math {
2325
/**
2426
* Represents an access to a \c matrix_cl in kernel generator expressions
2527
* @tparam T \c matrix_cl
28+
* @tparam AssignOp tells higher level operations whether the final operation
29+
* should be an assignment or a type of compound assignment.
2630
*/
27-
template <typename T>
31+
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
2832
class load_
29-
: public operation_cl_lhs<load_<T>,
33+
: public operation_cl_lhs<load_<T, AssignOp>,
3034
typename std::remove_reference_t<T>::type> {
3135
protected:
3236
T a_;
3337

3438
public:
39+
static constexpr assign_op_cl assignment_op = AssignOp;
3540
using Scalar = typename std::remove_reference_t<T>::type;
36-
using base = operation_cl<load_<T>, Scalar>;
41+
using base = operation_cl<load_<T, AssignOp>, Scalar>;
3742
using base::var_name_;
3843
static_assert(disjunction<is_matrix_cl<T>, is_arena_matrix_cl<T>>::value,
3944
"load_: argument a must be a matrix_cl<T>!");
@@ -51,9 +56,13 @@ class load_
5156
* Creates a deep copy of this expression.
5257
* @return copy of \c *this
5358
*/
54-
inline load_<T&> deep_copy() & { return load_<T&>(a_); }
55-
inline load_<const T&> deep_copy() const& { return load_<const T&>(a_); }
56-
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }
59+
inline load_<T&, AssignOp> deep_copy() & { return load_<T&, AssignOp>(a_); }
60+
inline load_<const T&, AssignOp> deep_copy() const& {
61+
return load_<const T&, AssignOp>(a_);
62+
}
63+
inline load_<T, AssignOp> deep_copy() && {
64+
return load_<T, AssignOp>(std::forward<T>(a_));
65+
}
5766

5867
/**
5968
* Generates kernel code for this expression.
@@ -327,6 +336,7 @@ class load_
327336
}
328337
}
329338
};
339+
330340
/** @}*/
331341
} // namespace math
332342
} // namespace stan

0 commit comments

Comments
 (0)