Skip to content

Commit 119ca9b

Browse files
committed
Merge commit 'eb3b5d769f60e93e79be51a2c723ce368e8437f9' into HEAD
2 parents b4d9d30 + eb3b5d7 commit 119ca9b

18 files changed

Lines changed: 169 additions & 161 deletions

stan/math/opencl/kernel_generator/elt_function_cl.hpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_square,
307307
opencl_kernels::inv_square_device_function)
308308
ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_logit,
309309
opencl_kernels::inv_logit_device_function)
310-
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::logit_device_function)
310+
ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function,
311+
opencl_kernels::logit_device_function)
311312
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function)
312313
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx,
313314
opencl_kernels::inv_logit_device_function,

stan/math/opencl/kernels/device_functions/Phi.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,11 +24,11 @@ static const char* phi_device_function
2424
if (x < -37.5) {
2525
return 0;
2626
} else if (x < -5.0) {
27-
return 0.5 * erfc(-1.0 / sqrt(2.0) * x);
27+
return 0.5 * erfc(-M_SQRT1_2 * x);
2828
} else if (x > 8.25) {
2929
return 1;
3030
} else {
31-
return 0.5 * (1.0 + erf(1.0 / sqrt(2.0) * x));
31+
return 0.5 * (1.0 + erf(M_SQRT1_2 * x));
3232
}
3333
}
3434
// \cond

stan/math/opencl/kernels/device_functions/inv_logit.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ static const char* inv_logit_device_function
5656
*/
5757
double inv_logit(double x) {
5858
if (x < 0) {
59-
if (x < log(2.2204460492503131E-16)) {
59+
if (x < log(DBL_EPSILON)) {
6060
return exp(x);
6161
}
6262
return exp(x) / (1 + exp(x));

stan/math/opencl/kernels/device_functions/lbeta.hpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -95,21 +95,23 @@ static const char* lbeta_device_function
9595
return lgamma(x) + lgamma(y) - lgamma(x + y);
9696
}
9797
double x_over_xy = x / (x + y);
98+
double log_xpy = log(x + y);
9899
if (x < LGAMMA_STIRLING_DIFF_USEFUL) {
99100
// y large, x small
100101
double stirling_diff
101102
= lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y);
102103
double stirling
103-
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log(x + y));
104+
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log_xpy);
104105
return stirling + lgamma(x) + stirling_diff;
105106
}
106107

107108
// both large
108109
double stirling_diff = lgamma_stirling_diff(x)
109110
+ lgamma_stirling_diff(y)
110111
- lgamma_stirling_diff(x + y);
111-
double stirling = (x - 0.5) * log(x_over_xy) + y * log1p(-x_over_xy)
112-
+ 0.5 * log(2.0 * M_PI) - 0.5 * log(y);
112+
double stirling = (x - 0.5) * (log(x) - log_xpy)
113+
+ y * log1p(-x_over_xy)
114+
+ 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y);
113115
return stirling + stirling_diff;
114116
}
115117
// \cond

stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function
2828
* @return Stirling's approximation to lgamma(x).
2929
*/
3030
double lgamma_stirling(double x) {
31-
return 0.5 * log(2.0 * M_PI) + (x - 0.5) * log(x) - x;
31+
return 0.5 * (M_LN2 + log(M_PI)) + (x - 0.5) * log(x) - x;
3232
}
3333
// \cond
3434
) "\n#endif\n"; // NOLINT

stan/math/opencl/kernels/device_functions/logit.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ static const char* logit_device_function
4949
* @param x argument
5050
* @return log odds of argument
5151
*/
52-
double logit(double x) { return log(x / (1 - x)); }
52+
double logit(double x) { return log(x) - log1m(x); }
5353
// \cond
5454
) "\n#endif\n"; // NOLINT
5555
// \endcond

stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
#include <stan/math/opencl/kernel_cl.hpp>
66
#include <stan/math/opencl/kernels/device_functions/digamma.hpp>
7+
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
78

89
namespace stan {
910
namespace math {
@@ -92,9 +93,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY(
9293
double log_phi = log(phi);
9394
double logsumexp_theta_logphi;
9495
if (theta > log_phi) {
95-
logsumexp_theta_logphi = theta + log1p(exp(log_phi - theta));
96+
logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta);
9697
} else {
97-
logsumexp_theta_logphi = log_phi + log1p(exp(theta - log_phi));
98+
logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi);
9899
}
99100
double y_plus_phi = y + phi;
100101
if (need_logp1) {
@@ -196,7 +197,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
196197
in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int,
197198
int, int, int, int, int, int, int, int, int>
198199
neg_binomial_2_log_glm("neg_binomial_2_log_glm",
199-
{digamma_device_function,
200+
{digamma_device_function, log1p_exp_device_function,
200201
neg_binomial_2_log_glm_kernel_code},
201202
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
202203

stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/opencl/kernel_cl.hpp>
66
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
77
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
8+
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>
89

910
namespace stan {
1011
namespace math {
@@ -87,20 +88,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY(
8788

8889
if (need_location_derivative || need_cuts_derivative) {
8990
double exp_cuts_diff = exp(cut_y2 - cut_y1);
90-
if (cut2 > 0) {
91-
double exp_m_cut2 = exp(-cut2);
92-
d1 = exp_m_cut2 / (1 + exp_m_cut2);
93-
} else {
94-
d1 = 1 / (1 + exp(cut2));
95-
}
91+
d1 = inv_logit(-cut2);
9692
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
9793
d2 = 1 / (1 - exp_cuts_diff);
98-
if (cut1 > 0) {
99-
double exp_m_cut1 = exp(-cut1);
100-
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
101-
} else {
102-
d2 -= 1 / (1 + exp(cut1));
103-
}
94+
d2 -= inv_logit(-cut1);
10495

10596
if (need_location_derivative) {
10697
location_derivative[gid] = d1 - d2;
@@ -181,6 +172,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
181172
in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
182173
ordered_logistic_glm("ordered_logistic_glm",
183174
{log1p_exp_device_function, log1m_exp_device_function,
175+
inv_logit_device_function,
184176
ordered_logistic_glm_kernel_code},
185177
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
186178

stan/math/opencl/kernels/ordered_logistic_lpmf.hpp

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/opencl/kernel_cl.hpp>
66
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
77
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
8+
#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp>
89

910
namespace stan {
1011
namespace math {
@@ -83,20 +84,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY(
8384

8485
if (need_lambda_derivative || need_cuts_derivative) {
8586
double exp_cuts_diff = exp(cut_y2 - cut_y1);
86-
if (cut2 > 0) {
87-
double exp_m_cut2 = exp(-cut2);
88-
d1 = exp_m_cut2 / (1 + exp_m_cut2);
89-
} else {
90-
d1 = 1 / (1 + exp(cut2));
91-
}
87+
d1 = inv_logit(-cut2);
9288
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
9389
d2 = 1 / (1 - exp_cuts_diff);
94-
if (cut1 > 0) {
95-
double exp_m_cut1 = exp(-cut1);
96-
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
97-
} else {
98-
d2 -= 1 / (1 + exp(cut1));
99-
}
90+
d2 -= inv_logit(-cut1);
10091

10192
if (need_lambda_derivative) {
10293
lambda_derivative[gid] = d1 - d2;
@@ -175,7 +166,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, in_buffer, in_buffer,
175166
in_buffer, int, int, int, int, int, int>
176167
ordered_logistic("ordered_logistic",
177168
{log1p_exp_device_function, log1m_exp_device_function,
178-
ordered_logistic_kernel_code},
169+
inv_logit_device_function, ordered_logistic_kernel_code},
179170
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
180171

181172
} // namespace opencl_kernels

stan/math/opencl/kernels/tridiagonalization.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -84,15 +84,15 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY(
8484
q = q_local[0];
8585
alpha = q_local[1];
8686
if (q != 0) {
87-
double multi = sqrt(2.) / q;
87+
double multi = M_SQRT2 / q;
8888
// normalize the Householder vector
8989
for (int i = lid + 1; i < P_span; i += lsize) {
9090
P[P_start + i] *= multi;
9191
}
9292
}
9393
if (gid == 0) {
9494
P[P_rows * (k + j + 1) + k + j]
95-
= P[P_rows * (k + j) + k + j + 1] * q / sqrt(2.) + alpha;
95+
= P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha;
9696
}
9797
}
9898
// \cond
@@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY(
291291
v[i] -= acc * u[i];
292292
}
293293
if (gid == 0) {
294-
P[P_rows * (k + j + 1) + k + j] -= *q / sqrt(2.) * u[0];
294+
P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0];
295295
}
296296
}
297297
// \cond

0 commit comments

Comments
 (0)