File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -24,11 +24,11 @@ static const char* phi_device_function
2424 if (x < -37.5 ) {
2525 return 0 ;
2626 } else if (x < -5.0 ) {
27- return 0.5 * erfc (-1.0 / sqrt ( 2.0 ) * x);
27+ return 0.5 * erfc (-M_SQRT1_2 * x);
2828 } else if (x > 8.25 ) {
2929 return 1 ;
3030 } else {
31- return 0.5 * (1.0 + erf (1.0 / sqrt ( 2.0 ) * x));
31+ return 0.5 * (1.0 + erf (M_SQRT1_2 * x));
3232 }
3333 }
3434 // \cond
Original file line number Diff line number Diff line change @@ -95,21 +95,22 @@ static const char* lbeta_device_function
9595 return lgamma (x) + lgamma (y) - lgamma (x + y);
9696 }
9797 double x_over_xy = x / (x + y);
98+ double log_xpy = log (x + y);
9899 if (x < LGAMMA_STIRLING_DIFF_USEFUL) {
99100 // y large, x small
100101 double stirling_diff
101102 = lgamma_stirling_diff (y) - lgamma_stirling_diff (x + y);
102103 double stirling
103- = (y - 0.5 ) * log1p (-x_over_xy) + x * (1 - log (x + y) );
104+ = (y - 0.5 ) * log1p (-x_over_xy) + x * (1 - log_xpy );
104105 return stirling + lgamma (x) + stirling_diff;
105106 }
106107
107108 // both large
108109 double stirling_diff = lgamma_stirling_diff (x)
109110 + lgamma_stirling_diff (y)
110111 - lgamma_stirling_diff (x + y);
111- double stirling = (x - 0.5 ) * log (x_over_xy ) + y * log1p (-x_over_xy)
112- + 0.5 * log ( 2.0 * M_PI) - 0.5 * log (y);
112+ double stirling = (x - 0.5 ) * ( log (x) - log_xpy ) + y * log1p (-x_over_xy)
113+ + 0.5 * (M_LN2 + log ( M_PI) ) - 0.5 * log (y);
113114 return stirling + stirling_diff;
114115 }
115116 // \cond
Original file line number Diff line number Diff line change @@ -28,7 +28,7 @@ static const char* lgamma_stirling_device_function
2828 * @return Stirling's approximation to lgamma(x).
2929 */
3030 double lgamma_stirling (double x) {
31- return 0.5 * log ( 2.0 * M_PI) + (x - 0.5 ) * log (x) - x;
31+ return 0.5 * (M_LN2 + log ( M_PI) ) + (x - 0.5 ) * log (x) - x;
3232 }
3333 // \cond
3434 ) " \n #endif\n " ; // NOLINT
Original file line number Diff line number Diff line change @@ -49,7 +49,7 @@ static const char* logit_device_function
4949 * @param x argument
5050 * @return log odds of argument
5151 */
52- double logit (double x) { return log (x / ( 1 - x) ); }
52+ double logit (double x) { return log (x) - log1m (x ); }
5353 // \cond
5454 ) " \n #endif\n " ; // NOLINT
5555// \endcond
Original file line number Diff line number Diff line change @@ -92,9 +92,9 @@ static const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY(
9292 double log_phi = log (phi);
9393 double logsumexp_theta_logphi;
9494 if (theta > log_phi) {
95- logsumexp_theta_logphi = theta + log1p ( exp ( log_phi - theta) );
95+ logsumexp_theta_logphi = theta + log1p_exp ( log_phi - theta);
9696 } else {
97- logsumexp_theta_logphi = log_phi + log1p ( exp ( theta - log_phi) );
97+ logsumexp_theta_logphi = log_phi + log1p_exp ( theta - log_phi);
9898 }
9999 double y_plus_phi = y + phi;
100100 if (need_logp1) {
Original file line number Diff line number Diff line change @@ -87,20 +87,10 @@ static const char* ordered_logistic_glm_kernel_code = STRINGIFY(
8787
8888 if (need_location_derivative || need_cuts_derivative) {
8989 double exp_cuts_diff = exp (cut_y2 - cut_y1);
90- if (cut2 > 0 ) {
91- double exp_m_cut2 = exp (-cut2);
92- d1 = exp_m_cut2 / (1 + exp_m_cut2);
93- } else {
94- d1 = 1 / (1 + exp (cut2));
95- }
90+ d1 = inv_logit (-cut2);
9691 d1 -= exp_cuts_diff / (exp_cuts_diff - 1 );
9792 d2 = 1 / (1 - exp_cuts_diff);
98- if (cut1 > 0 ) {
99- double exp_m_cut1 = exp (-cut1);
100- d2 -= exp_m_cut1 / (1 + exp_m_cut1);
101- } else {
102- d2 -= 1 / (1 + exp (cut1));
103- }
93+ d2 -= inv_logit (-cut1);
10494
10595 if (need_location_derivative) {
10696 location_derivative[gid] = d1 - d2;
Original file line number Diff line number Diff line change @@ -83,20 +83,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY(
8383
8484 if (need_lambda_derivative || need_cuts_derivative) {
8585 double exp_cuts_diff = exp (cut_y2 - cut_y1);
86- if (cut2 > 0 ) {
87- double exp_m_cut2 = exp (-cut2);
88- d1 = exp_m_cut2 / (1 + exp_m_cut2);
89- } else {
90- d1 = 1 / (1 + exp (cut2));
91- }
86+ d1 = inv_logit (-cut2);
9287 d1 -= exp_cuts_diff / (exp_cuts_diff - 1 );
9388 d2 = 1 / (1 - exp_cuts_diff);
94- if (cut1 > 0 ) {
95- double exp_m_cut1 = exp (-cut1);
96- d2 -= exp_m_cut1 / (1 + exp_m_cut1);
97- } else {
98- d2 -= 1 / (1 + exp (cut1));
99- }
89+ d2 -= inv_logit (-cut1);
10090
10191 if (need_lambda_derivative) {
10292 lambda_derivative[gid] = d1 - d2;
Original file line number Diff line number Diff line change @@ -84,15 +84,15 @@ static const char* tridiagonalization_householder_kernel_code = STRINGIFY(
8484 q = q_local[0 ];
8585 alpha = q_local[1 ];
8686 if (q != 0 ) {
87- double multi = sqrt ( 2 .) / q;
87+ double multi = M_SQRT2 / q;
8888 // normalize the Householder vector
8989 for (int i = lid + 1 ; i < P_span; i += lsize) {
9090 P[P_start + i] *= multi;
9191 }
9292 }
9393 if (gid == 0 ) {
9494 P[P_rows * (k + j + 1 ) + k + j]
95- = P[P_rows * (k + j) + k + j + 1 ] * q / sqrt ( 2 .) + alpha;
95+ = P[P_rows * (k + j) + k + j + 1 ] * q / M_SQRT2 + alpha;
9696 }
9797 }
9898 // \cond
@@ -291,7 +291,7 @@ static const char* tridiagonalization_v_step_3_kernel_code = STRINGIFY(
291291 v[i] -= acc * u[i];
292292 }
293293 if (gid == 0 ) {
294- P[P_rows * (k + j + 1 ) + k + j] -= *q / sqrt ( 2 .) * u[0 ];
294+ P[P_rows * (k + j + 1 ) + k + j] -= *q / M_SQRT2 * u[0 ];
295295 }
296296 }
297297 // \cond
You can’t perform that action at this time.
0 commit comments