@@ -45,12 +45,12 @@ namespace internal {
4545 * @return Three-element tuple containing gradients w.r.t. a1, a2, and b1,
4646 * as indicated by the calc_a1, calc_a2, and calc_b1 booleans
4747 */
48- template <bool calc_a1, bool calc_a2, bool calc_b1,
49- typename T1, typename T2, typename T3, typename T_z,
48+ template <bool calc_a1, bool calc_a2, bool calc_b1, typename T1, typename T2,
49+ typename T3, typename T_z,
5050 typename ScalarT = return_type_t <T1, T2, T3, T_z>,
5151 typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT>>
5252TupleT grad_2F1_impl_ab (const T1& a1, const T2& a2, const T3& b1, const T_z& z,
53- double precision = 1e-14 , int max_steps = 1e6 ) {
53+ double precision = 1e-14 , int max_steps = 1e6 ) {
5454 TupleT grad_tuple = TupleT (0 , 0 , 0 );
5555
5656 if (z == 0 ) {
@@ -131,7 +131,6 @@ TupleT grad_2F1_impl_ab(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
131131 return grad_tuple;
132132}
133133
134-
135134/* *
136135 * Implementation function to calculate the gradients of the hypergeometric
137136 * function, 2F1.
@@ -164,20 +163,20 @@ TupleT grad_2F1_impl_ab(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
164163 * @return Four-element tuple containing gradients w.r.t. to each parameter,
165164 * as indicated by the calc_* booleans
166165 */
167- template <bool calc_a1, bool calc_a2, bool calc_b1, bool calc_z,
168- typename T1, typename T2, typename T3, typename T_z,
166+ template <bool calc_a1, bool calc_a2, bool calc_b1, bool calc_z, typename T1,
167+ typename T2, typename T3, typename T_z,
169168 typename ScalarT = return_type_t <T1, T2, T3, T_z>,
170169 typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT, ScalarT>>
171170TupleT grad_2F1_impl (const T1& a1, const T2& a2, const T3& b1, const T_z& z,
172- double precision = 1e-14 , int max_steps = 1e6 ) {
171+ double precision = 1e-14 , int max_steps = 1e6 ) {
173172 bool euler_transform = false ;
174173 try {
175174 check_2F1_converges (" hypergeometric_2F1" , a1, a2, b1, z);
176175 } catch (const std::exception& e) {
177176 // Apply Euler's hypergeometric transformation if function
178177 // will not converge with current arguments
179- check_2F1_converges (" hypergeometric_2F1 (euler transform)" ,
180- b1 - a1, a2, b1, z / (z - 1 ));
178+ check_2F1_converges (" hypergeometric_2F1 (euler transform)" , b1 - a1, a2, b1,
179+ z / (z - 1 ));
181180 euler_transform = true ;
182181 }
183182
@@ -191,10 +190,11 @@ TupleT grad_2F1_impl(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
191190 auto hyper1 = hypergeometric_2F1 (a1_euler, a2_euler, b1, z_euler);
192191 auto hyper2 = hypergeometric_2F1 (1 + a2, 1 - a1 + b1, 1 + b1, z_euler);
193192 auto pre_mult = a2 * pow (1 - z, -1 - a2);
194- std::get<3 >(grad_tuple_rtn) = a2 * pow (1 - z, -1 - a2) * hyper1
195- + (a2 * (b1 - a1) * pow (1 - z, -a2)
196- * (inv (z - 1 ) - z / square (z - 1 )) * hyper2)
197- / b1;
193+ std::get<3 >(grad_tuple_rtn)
194+ = a2 * pow (1 - z, -1 - a2) * hyper1
195+ + (a2 * (b1 - a1) * pow (1 - z, -a2)
196+ * (inv (z - 1 ) - z / square (z - 1 )) * hyper2)
197+ / b1;
198198 }
199199 if (calc_a1 || calc_a2 || calc_b1) {
200200 // 'a' gradients under Euler transform are constructed using the gradients
@@ -203,22 +203,22 @@ TupleT grad_2F1_impl(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
203203 // 'b' gradients under Euler transform require gradients from 'a2'
204204 constexpr bool calc_a2_euler = calc_a1 || calc_a2 || calc_b1;
205205 grad_tuple_ab = grad_2F1_impl_ab<calc_a1_euler, calc_a2_euler, calc_b1>(
206- a1_euler, a2_euler, b1, z_euler);
206+ a1_euler, a2_euler, b1, z_euler);
207207
208208 auto pre_mult_ab = inv (pow (1.0 - z, a2));
209209 if (calc_a1) {
210210 std::get<0 >(grad_tuple_rtn) = -pre_mult_ab * std::get<1 >(grad_tuple_ab);
211211 }
212212 if (calc_a2) {
213- auto hyper_da2
214- = hypergeometric_2F1 (a1_euler, a2, b1, z_euler);
213+ auto hyper_da2 = hypergeometric_2F1 (a1_euler, a2, b1, z_euler);
215214 std::get<1 >(grad_tuple_rtn)
216215 = -pre_mult_ab * hyper_da2 * log1m (z)
217216 + pre_mult_ab * std::get<0 >(grad_tuple_ab);
218217 }
219218 if (calc_b1) {
220- std::get<2 >(grad_tuple_rtn) = pre_mult_ab
221- * (std::get<1 >(grad_tuple_ab) + std::get<2 >(grad_tuple_ab));
219+ std::get<2 >(grad_tuple_rtn)
220+ = pre_mult_ab
221+ * (std::get<1 >(grad_tuple_ab) + std::get<2 >(grad_tuple_ab));
222222 }
223223 }
224224 } else {
@@ -228,7 +228,7 @@ TupleT grad_2F1_impl(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
228228 }
229229 if (calc_a1 || calc_a2 || calc_b1) {
230230 grad_tuple_ab
231- = grad_2F1_impl_ab<calc_a1, calc_a2, calc_b1>(a1, a2, b1, z);
231+ = grad_2F1_impl_ab<calc_a1, calc_a2, calc_b1>(a1, a2, b1, z);
232232 if (calc_a1) {
233233 std::get<0 >(grad_tuple_rtn) = std::get<0 >(grad_tuple_ab);
234234 }
@@ -273,9 +273,8 @@ auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
273273 double precision = 1e-14 , int max_steps = 1e6 ) {
274274 return internal::grad_2F1_impl<
275275 !is_constant<T1>::value, !is_constant<T2>::value, !is_constant<T3>::value,
276- !is_constant<T_z>::value>(value_of (a1), value_of (a2),
277- value_of (b1), value_of (z), precision,
278- max_steps);
276+ !is_constant<T_z>::value>(value_of (a1), value_of (a2), value_of (b1),
277+ value_of (z), precision, max_steps);
279278}
280279
281280/* *
@@ -305,8 +304,8 @@ template <bool ReturnSameT, typename T1, typename T2, typename T3, typename T_z,
305304 require_t <std::integral_constant<bool , ReturnSameT>>* = nullptr >
306305auto grad_2F1 (const T1& a1, const T2& a2, const T3& b1, const T_z& z,
307306 double precision = 1e-14 , int max_steps = 1e6 ) {
308- return internal::grad_2F1_impl<true , true , true , true >(
309- a1, a2, b1, z, precision, max_steps);
307+ return internal::grad_2F1_impl<true , true , true , true >(a1, a2, b1, z,
308+ precision, max_steps);
310309}
311310
312311/* *
0 commit comments