Skip to content

Commit 7825fdf

Browse files
committed
Cleanup numerical stability, constants, compound functions in kernels
1 parent 5d1fd38 commit 7825fdf

8 files changed

Lines changed: 17 additions & 36 deletions

File tree

stan/math/opencl/kernels/device_functions/Phi.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ static const char* phi_device_function
2424
if (x < -37.5) {
2525
return 0;
2626
} else if (x < -5.0) {
27-
return 0.5 * erfc(-1.0 / sqrt(2.0) * x);
27+
return 0.5 * erfc(-M_SQRT1_2 * x);
2828
} else if (x > 8.25) {
2929
return 1;
3030
} else {
31-
return 0.5 * (1.0 + erf(1.0 / sqrt(2.0) * x));
31+
return 0.5 * (1.0 + erf(M_SQRT1_2 * x));
3232
}
3333
}
3434
// \cond

stan/math/opencl/kernels/device_functions/lbeta.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,22 @@ static const char* lbeta_device_function
9595
return lgamma(x) + lgamma(y) - lgamma(x + y);
9696
}
9797
double x_over_xy = x / (x + y);
98+
double log_xpy = log(x + y);
9899
if (x < LGAMMA_STIRLING_DIFF_USEFUL) {
99100
// y large, x small
100101
double stirling_diff
101102
= lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y);
102103
double stirling
103-
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log(x + y));
104+
= (y - 0.5) * log1p(-x_over_xy) + x * (1 - log_xpy);
104105
return stirling + lgamma(x) + stirling_diff;
105106
}
106107

107108
// both large
108109
double stirling_diff = lgamma_stirling_diff(x)
109110
+ lgamma_stirling_diff(y)
110111
- lgamma_stirling_diff(x + y);
111-
double stirling = (x - 0.5) * log(x_over_xy) + y * log1p(-x_over_xy)
112-
+ 0.5 * log(2.0 * M_PI) - 0.5 * log(y);
112+
double stirling = (x - 0.5) * (log(x) - log_xpy) + y * log1p(-x_over_xy)
113+
+ 0.5 * (M_LN2 + log(M_PI)) - 0.5 * log(y);
113114
return stirling + stirling_diff;
114115
}
115116
// \cond

stan/math/opencl/kernels/device_functions/lgamma_stirling.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function
2828
* @return Stirling's approximation to lgamma(x).
2929
*/
3030
double lgamma_stirling(double x) {
31-
return 0.5 * log(2.0 * M_PI) + (x - 0.5) * log(x) - x;
31+
return 0.5 * (M_LN2 + log(M_PI)) + (x - 0.5) * log(x) - x;
3232
}
3333
// \cond
3434
) "\n#endif\n"; // NOLINT

stan/math/opencl/kernels/device_functions/logit.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ static const char* logit_device_function
4949
* @param x argument
5050
* @return log odds of argument
5151
*/
52-
double logit(double x) { return log(x / (1 - x)); }
52+
double logit(double x) { return log(x) - log1m(x); }
5353
// \cond
5454
) "\n#endif\n"; // NOLINT
5555
// \endcond

stan/math/opencl/kernels/neg_binomial_2_log_glm_lpmf.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY(
9292
double log_phi = log(phi);
9393
double logsumexp_theta_logphi;
9494
if (theta > log_phi) {
95-
logsumexp_theta_logphi = theta + log1p(exp(log_phi - theta));
95+
logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta);
9696
} else {
97-
logsumexp_theta_logphi = log_phi + log1p(exp(theta - log_phi));
97+
logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi);
9898
}
9999
double y_plus_phi = y + phi;
100100
if (need_logp1) {

stan/math/opencl/kernels/ordered_logistic_glm_lpmf.hpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY(
8787

8888
if (need_location_derivative || need_cuts_derivative) {
8989
double exp_cuts_diff = exp(cut_y2 - cut_y1);
90-
if (cut2 > 0) {
91-
double exp_m_cut2 = exp(-cut2);
92-
d1 = exp_m_cut2 / (1 + exp_m_cut2);
93-
} else {
94-
d1 = 1 / (1 + exp(cut2));
95-
}
90+
d1 = inv_logit(-cut2);
9691
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
9792
d2 = 1 / (1 - exp_cuts_diff);
98-
if (cut1 > 0) {
99-
double exp_m_cut1 = exp(-cut1);
100-
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
101-
} else {
102-
d2 -= 1 / (1 + exp(cut1));
103-
}
93+
d2 -= inv_logit(-cut1);
10494

10595
if (need_location_derivative) {
10696
location_derivative[gid] = d1 - d2;

stan/math/opencl/kernels/ordered_logistic_lpmf.hpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY(
8383

8484
if (need_lambda_derivative || need_cuts_derivative) {
8585
double exp_cuts_diff = exp(cut_y2 - cut_y1);
86-
if (cut2 > 0) {
87-
double exp_m_cut2 = exp(-cut2);
88-
d1 = exp_m_cut2 / (1 + exp_m_cut2);
89-
} else {
90-
d1 = 1 / (1 + exp(cut2));
91-
}
86+
d1 = inv_logit(-cut2);
9287
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
9388
d2 = 1 / (1 - exp_cuts_diff);
94-
if (cut1 > 0) {
95-
double exp_m_cut1 = exp(-cut1);
96-
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
97-
} else {
98-
d2 -= 1 / (1 + exp(cut1));
99-
}
89+
d2 -= inv_logit(-cut1);
10090

10191
if (need_lambda_derivative) {
10292
lambda_derivative[gid] = d1 - d2;

stan/math/opencl/kernels/tridiagonalization.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY(
8484
q = q_local[0];
8585
alpha = q_local[1];
8686
if (q != 0) {
87-
double multi = sqrt(2.) / q;
87+
double multi = M_SQRT2 / q;
8888
// normalize the Householder vector
8989
for (int i = lid + 1; i < P_span; i += lsize) {
9090
P[P_start + i] *= multi;
9191
}
9292
}
9393
if (gid == 0) {
9494
P[P_rows * (k + j + 1) + k + j]
95-
= P[P_rows * (k + j) + k + j + 1] * q / sqrt(2.) + alpha;
95+
= P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha;
9696
}
9797
}
9898
// \cond
@@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY(
291291
v[i] -= acc * u[i];
292292
}
293293
if (gid == 0) {
294-
P[P_rows * (k + j + 1) + k + j] -= *q / sqrt(2.) * u[0];
294+
P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0];
295295
}
296296
}
297297
// \cond

0 commit comments

Comments
 (0)