|
| 1 | +#ifndef STAN_MATH_PRIM_PROB_BINOMIAL_LOGIT_GLM_LPMF_HPP |
| 2 | +#define STAN_MATH_PRIM_PROB_BINOMIAL_LOGIT_GLM_LPMF_HPP |
| 3 | + |
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/as_value_column_array_or_scalar.hpp>
#include <stan/math/prim/fun/as_value_column_vector_or_scalar.hpp>
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/isfinite.hpp>
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
#include <stan/math/prim/fun/log_inv_logit.hpp>
#include <stan/math/prim/fun/max.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>
| 17 | + |
| 18 | +namespace stan { |
| 19 | +namespace math { |
| 20 | + |
| 21 | +/** \ingroup prob_dists |
| 22 | + * Returns the log PMF of the Generalized Linear Model (GLM) |
| 23 | + * with Binomial distribution and logit link function. |
| 24 | + * The idea is that binomial_logit_glm_lpmf(n | N, x, alpha, beta) should |
| 25 | + * compute a more efficient version of |
| 26 | + * binomial_logit_lpmf(y | N, alpha + x * beta) by using analytically |
| 27 | + * simplified gradients. |
| 28 | + * If containers are supplied, returns the log sum of the probabilities. |
| 29 | + * |
| 30 | + * @tparam T_n type of binary vector of successes variables; |
| 31 | + * this can also be a single binary value; |
| 32 | + * @tparam T_N type of binary vector of population size variables; |
| 33 | + * this can also be a single binary value; |
| 34 | + * @tparam T_x type of the matrix of independent variables (features) |
| 35 | + * @tparam T_alpha type of the intercept(s); |
| 36 | + * this can be a vector (of the same length as y) of intercepts or a single |
| 37 | + * value (for models with constant intercept); |
| 38 | + * @tparam T_beta type of the weight vector |
| 39 | + * |
| 40 | + * @param n binary scalar or vector parameter. If it is a scalar it will be |
| 41 | + * broadcast - used for all instances. |
| 42 | + * @param N binary scalar or vector parameter. If it is a scalar it will be |
| 43 | + * broadcast - used for all instances. |
| 44 | + * @param x design matrix or row vector. If it is a row vector it will be |
| 45 | + * broadcast - used for all instances. |
| 46 | + * @param alpha intercept |
| 47 | + * @param beta weight vector |
| 48 | + * @return log probability or log sum of probabilities |
| 49 | + * @throw std::domain_error if x, beta or alpha is infinite. |
| 50 | + * @throw std::domain_error if n is negative or greater than N |
| 51 | + * @throw std::domain_error if N is negative |
| 52 | + * @throw std::invalid_argument if container sizes mismatch. |
| 53 | + */ |
| 54 | +template <bool propto, typename T_n, typename T_N, typename T_x, |
| 55 | + typename T_alpha, typename T_beta, require_matrix_t<T_x>* = nullptr> |
| 56 | +return_type_t<T_x, T_alpha, T_beta> binomial_logit_glm_lpmf( |
| 57 | + const T_n& n, const T_N& N, const T_x& x, const T_alpha& alpha, |
| 58 | + const T_beta& beta) { |
| 59 | + constexpr int T_x_rows = T_x::RowsAtCompileTime; |
| 60 | + using T_xbeta_partials = partials_return_t<T_x, T_beta>; |
| 61 | + using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>; |
| 62 | + using T_xbeta_tmp = |
| 63 | + typename std::conditional_t<T_x_rows == 1, T_xbeta_partials, |
| 64 | + Eigen::Array<T_xbeta_partials, -1, 1>>; |
| 65 | + using T_n_ref = ref_type_if_t<!is_constant<T_n>::value, T_n>; |
| 66 | + using T_N_ref = ref_type_if_t<!is_constant<T_N>::value, T_N>; |
| 67 | + using T_x_ref = ref_type_if_t<!is_constant<T_x>::value, T_x>; |
| 68 | + using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>; |
| 69 | + using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>; |
| 70 | + |
| 71 | + T_n_ref n_ref = n; |
| 72 | + T_N_ref N_ref = N; |
| 73 | + T_x_ref x_ref = x; |
| 74 | + T_alpha_ref alpha_ref = alpha; |
| 75 | + T_beta_ref beta_ref = beta; |
| 76 | + |
| 77 | + if (size_zero(n, N, alpha, beta, x)) { |
| 78 | + return 0; |
| 79 | + } |
| 80 | + |
| 81 | + if (!include_summand<propto, T_x, T_alpha, T_beta>::value) { |
| 82 | + return 0; |
| 83 | + } |
| 84 | + |
| 85 | + const size_t N_instances = max_size(n, N, x.col(0), alpha); |
| 86 | + const size_t N_attributes = x.cols(); |
| 87 | + |
| 88 | + static const char* function = "binomial_logit_glm_lpmf"; |
| 89 | + check_consistent_sizes(function, "Successes variable", n, |
| 90 | + "Population size parameter", N); |
| 91 | + check_consistent_size(function, "Successes variable", n, N_instances); |
| 92 | + check_consistent_size(function, "Population size parameter", N, N_instances); |
| 93 | + check_consistent_size(function, "Weight vector", beta, N_attributes); |
| 94 | + check_consistent_size(function, "Vector of intercepts", alpha, N_instances); |
| 95 | + |
| 96 | + auto&& n_val = as_value_column_array_or_scalar(n_ref); |
| 97 | + auto&& N_val = as_value_column_array_or_scalar(N_ref); |
| 98 | + |
| 99 | + check_bounded(function, "Successes variable", n_val, 0, N_val); |
| 100 | + check_nonnegative(function, "Population size parameter", N_val); |
| 101 | + |
| 102 | + auto&& alpha_val = as_value_column_array_or_scalar(alpha_ref); |
| 103 | + auto&& beta_val = as_value_column_vector_or_scalar(beta_ref); |
| 104 | + auto&& x_val = value_of(x_ref); |
| 105 | + Eigen::Array<T_partials_return, -1, 1> theta(N_instances); |
| 106 | + if (T_x_rows == 1) { |
| 107 | + theta = forward_as<T_xbeta_tmp>((x_val * beta_val)(0, 0)) + alpha_val; |
| 108 | + } else { |
| 109 | + theta = (x_val * beta_val).array() + alpha_val; |
| 110 | + } |
| 111 | + |
| 112 | + constexpr bool gradients_calc = !is_constant_all<T_beta, T_x, T_alpha>::value; |
| 113 | + auto&& log_inv_logit_theta = to_ref_if<gradients_calc>(log_inv_logit(theta)); |
| 114 | + |
| 115 | + T_partials_return logp = sum(n_val * log_inv_logit_theta |
| 116 | + + (N_val - n_val) * log1m_inv_logit(theta)); |
| 117 | + |
| 118 | + using std::isfinite; |
| 119 | + if (!isfinite(logp)) { |
| 120 | + check_finite(function, "Weight vector", beta); |
| 121 | + check_finite(function, "Intercept", alpha); |
| 122 | + check_finite(function, "Matrix of independent variables", x); |
| 123 | + } |
| 124 | + |
| 125 | + if (include_summand<propto, T_n, T_N>::value) { |
| 126 | + size_t broadcast_n = max_size(N, n) == N_instances ? 1 : N_instances; |
| 127 | + logp += sum(binomial_coefficient_log(N_val, n_val)) * broadcast_n; |
| 128 | + } |
| 129 | + |
| 130 | + auto ops_partials = make_partials_propagator(x_ref, alpha_ref, beta_ref); |
| 131 | + if (gradients_calc) { |
| 132 | + Eigen::Matrix<T_partials_return, -1, 1> theta_derivative |
| 133 | + = n_val - N_val * exp(log_inv_logit_theta); |
| 134 | + |
| 135 | + if (!is_constant_all<T_beta>::value) { |
| 136 | + if (T_x_rows == 1) { |
| 137 | + edge<2>(ops_partials).partials_ |
| 138 | + = forward_as<Eigen::Matrix<T_partials_return, 1, -1>>( |
| 139 | + theta_derivative.sum() * x_val); |
| 140 | + } else { |
| 141 | + partials<2>(ops_partials) = x_val.transpose() * theta_derivative; |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + if (!is_constant_all<T_x>::value) { |
| 146 | + if (T_x_rows == 1) { |
| 147 | + edge<0>(ops_partials).partials_ |
| 148 | + = forward_as<Eigen::Array<T_partials_return, -1, T_x_rows>>( |
| 149 | + beta_val * theta_derivative.sum()); |
| 150 | + } else { |
| 151 | + edge<0>(ops_partials).partials_ |
| 152 | + = (beta_val * theta_derivative.transpose()).transpose(); |
| 153 | + } |
| 154 | + } |
| 155 | + if (!is_constant_all<T_alpha>::value) { |
| 156 | + partials<1>(ops_partials) = theta_derivative; |
| 157 | + } |
| 158 | + } |
| 159 | + return ops_partials.build(logp); |
| 160 | +} |
| 161 | + |
| 162 | +template <typename T_n, typename T_N, typename T_x, typename T_alpha, |
| 163 | + typename T_beta> |
| 164 | +inline return_type_t<T_x, T_beta, T_alpha> binomial_logit_glm_lpmf( |
| 165 | + const T_n& n, const T_N& N, const T_x& x, const T_alpha& alpha, |
| 166 | + const T_beta& beta) { |
| 167 | + return binomial_logit_glm_lpmf<false>(n, N, x, alpha, beta); |
| 168 | +} |
| 169 | +} // namespace math |
| 170 | +} // namespace stan |
| 171 | +#endif |
0 commit comments