|
5 | 5 | #include <stan/math/opencl/kernel_cl.hpp> |
6 | 6 | #include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp> |
7 | 7 | #include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp> |
| 8 | +#include <stan/math/opencl/kernels/device_functions/inv_logit.hpp> |
8 | 9 |
|
9 | 10 | namespace stan { |
10 | 11 | namespace math { |
@@ -83,20 +84,10 @@ static const char* ordered_logistic_kernel_code = STRINGIFY( |
83 | 84 |
|
84 | 85 | if (need_lambda_derivative || need_cuts_derivative) { |
85 | 86 | double exp_cuts_diff = exp(cut_y2 - cut_y1); |
86 | | - if (cut2 > 0) { |
87 | | - double exp_m_cut2 = exp(-cut2); |
88 | | - d1 = exp_m_cut2 / (1 + exp_m_cut2); |
89 | | - } else { |
90 | | - d1 = 1 / (1 + exp(cut2)); |
91 | | - } |
| 87 | + d1 = inv_logit(-cut2); |
92 | 88 | d1 -= exp_cuts_diff / (exp_cuts_diff - 1); |
93 | 89 | d2 = 1 / (1 - exp_cuts_diff); |
94 | | - if (cut1 > 0) { |
95 | | - double exp_m_cut1 = exp(-cut1); |
96 | | - d2 -= exp_m_cut1 / (1 + exp_m_cut1); |
97 | | - } else { |
98 | | - d2 -= 1 / (1 + exp(cut1)); |
99 | | - } |
| 90 | + d2 -= inv_logit(-cut1); |
100 | 91 |
|
101 | 92 | if (need_lambda_derivative) { |
102 | 93 | lambda_derivative[gid] = d1 - d2; |
@@ -175,7 +166,7 @@ const kernel_cl<out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, |
175 | 166 | in_buffer, int, int, int, int, int, int> |
176 | 167 | ordered_logistic("ordered_logistic", |
177 | 168 | {log1p_exp_device_function, log1m_exp_device_function, |
178 | | - ordered_logistic_kernel_code}, |
| 169 | + inv_logit_device_function, ordered_logistic_kernel_code}, |
179 | 170 | {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}}); |
180 | 171 |
|
181 | 172 | } // namespace opencl_kernels |
|
0 commit comments