Skip to content

Commit 23dafa2

Browse files
nsicchaclaude
andcommitted
Add tuple support to apply_scalar_unary for autodiff
Add tuple specialization to apply_scalar_unary so that vectorized math functions (exp, sin, cos, etc.) automatically work on tuples. Also extend require_ad_container_t to accept tuples containing autodiff types, and add is_autodiff support for tuples. Fixes #3041 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8ee3d96 commit 23dafa2

3 files changed

Lines changed: 51 additions & 3 deletions

File tree

stan/math/prim/functor/apply_scalar_unary.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta/is_eigen.hpp>
66
#include <stan/math/prim/meta/is_complex.hpp>
7+
#include <stan/math/prim/meta/is_tuple.hpp>
78
#include <stan/math/prim/meta/holder.hpp>
89
#include <stan/math/prim/meta/require_generics.hpp>
910
#include <stan/math/prim/meta/is_vector.hpp>
1011
#include <stan/math/prim/meta/is_vector_like.hpp>
1112
#include <stan/math/prim/meta/plain_type.hpp>
13+
#include <tuple>
1214
#include <utility>
1315
#include <vector>
1416

@@ -206,6 +208,36 @@ struct apply_scalar_unary<F, T, require_std_vector_t<T>> {
206208
}
207209
};
208210

211+
/**
212+
* Template specialization for vectorized functions applying to
213+
* tuple arguments. Each element of the tuple is processed
214+
* recursively through apply_scalar_unary, allowing heterogeneous
215+
* element types.
216+
*
217+
* @tparam F Type of function defining static apply function.
218+
* @tparam T Tuple type.
219+
*/
220+
template <typename F, typename T>
221+
struct apply_scalar_unary<F, T, require_tuple_t<T>> {
222+
template <typename TT, size_t... Is>
223+
static inline auto apply_impl(TT&& x, std::index_sequence<Is...>) {
224+
return std::make_tuple(
225+
apply_scalar_unary<
226+
F, std::tuple_element_t<Is, std::decay_t<T>>>::
227+
apply(std::get<Is>(std::forward<TT>(x)))...);
228+
}
229+
230+
template <typename TT>
231+
static inline auto apply(TT&& x) {
232+
return apply_impl(
233+
std::forward<TT>(x),
234+
std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>());
235+
}
236+
237+
using return_t = std::decay_t<decltype(
238+
apply_scalar_unary<F, T>::apply(std::declval<T>()))>;
239+
};
240+
209241
} // namespace math
210242
} // namespace stan
211243
#endif

stan/math/prim/meta/is_autodiff.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
#include <stan/math/prim/meta/is_fvar.hpp>
99
#include <stan/math/prim/meta/is_vector.hpp>
1010
#include <stan/math/prim/meta/is_var.hpp>
11+
#include <stan/math/prim/meta/is_tuple.hpp>
1112
#include <stan/math/prim/meta/require_helpers.hpp>
1213
#include <stan/math/prim/meta/scalar_type.hpp>
1314
#include <stan/math/prim/meta/value_type.hpp>
1415
#include <complex>
16+
#include <tuple>
1517
#include <type_traits>
1618

1719
namespace stan {
@@ -46,6 +48,17 @@ template <typename T>
4648
struct is_autodiff<T, require_eigen_t<T>>
4749
: bool_constant<is_autodiff<typename std::decay_t<T>::Scalar>::value> {};
4850

51+
template <typename Tuple, typename = void>
52+
struct is_tuple_autodiff : std::false_type {};
53+
54+
template <typename... Ts>
55+
struct is_tuple_autodiff<std::tuple<Ts...>, void>
56+
: bool_constant<(is_autodiff<Ts>::value || ...)> {};
57+
58+
template <typename T>
59+
struct is_autodiff<T, math::require_tuple_t<T>>
60+
: is_tuple_autodiff<std::decay_t<T>> {};
61+
4962
} // namespace internal
5063

5164
/**

stan/math/prim/meta/is_container.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/prim/meta/disjunction.hpp>
66
#include <stan/math/prim/meta/is_autodiff.hpp>
77
#include <stan/math/prim/meta/is_eigen.hpp>
8+
#include <stan/math/prim/meta/is_tuple.hpp>
89
#include <stan/math/prim/meta/is_vector.hpp>
910
#include <stan/math/prim/meta/is_var_matrix.hpp>
1011
#include <stan/math/prim/meta/base_type.hpp>
@@ -99,9 +100,11 @@ using require_not_container_st
99100
/*! and holds a base type that satisfies @ref is_autodiff_scalar */
100101
/*! @tparam T the type to check */
101102
template <typename T>
102-
using require_ad_container_t
103-
= require_all_t<stan::math::disjunction<is_eigen<T>, is_std_vector<T>>,
104-
is_autodiff_scalar<base_type_t<T>>>;
103+
using require_ad_container_t = require_t<math::disjunction<
104+
math::conjunction<math::disjunction<is_eigen<T>, is_std_vector<T>>,
105+
is_autodiff_scalar<base_type_t<T>>>,
106+
math::conjunction<math::is_tuple<T>,
107+
internal::is_autodiff<std::decay_t<T>>>>>;
105108
/*! @} */
106109

107110
} // namespace stan

0 commit comments

Comments
 (0)