Skip to content

Commit c90b3bf

Browse files
authored
Merge pull request #2946 from stan-dev/binomial-logit-glm
Add binomial_logit_glm distribution
2 parents e5879f6 + b891472 commit c90b3bf

4 files changed

Lines changed: 288 additions & 0 deletions

File tree

stan/math/prim/prob.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include <stan/math/prim/prob/binomial_log.hpp>
4747
#include <stan/math/prim/prob/binomial_logit_log.hpp>
4848
#include <stan/math/prim/prob/binomial_logit_lpmf.hpp>
49+
#include <stan/math/prim/prob/binomial_logit_glm_lpmf.hpp>
4950
#include <stan/math/prim/prob/binomial_lpmf.hpp>
5051
#include <stan/math/prim/prob/binomial_rng.hpp>
5152
#include <stan/math/prim/prob/categorical_log.hpp>
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
#ifndef STAN_MATH_PRIM_PROB_BINOMIAL_LOGIT_GLM_LPMF_HPP
2+
#define STAN_MATH_PRIM_PROB_BINOMIAL_LOGIT_GLM_LPMF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
#include <stan/math/prim/fun/isfinite.hpp>
8+
#include <stan/math/prim/fun/size.hpp>
9+
#include <stan/math/prim/fun/size_zero.hpp>
10+
#include <stan/math/prim/fun/max_size.hpp>
11+
#include <stan/math/prim/fun/max.hpp>
12+
#include <stan/math/prim/fun/to_ref.hpp>
13+
#include <stan/math/prim/fun/value_of.hpp>
14+
#include <stan/math/prim/fun/as_value_column_array_or_scalar.hpp>
15+
#include <stan/math/prim/fun/as_value_column_vector_or_scalar.hpp>
16+
#include <stan/math/prim/functor/partials_propagator.hpp>
17+
18+
namespace stan {
19+
namespace math {
20+
21+
/** \ingroup prob_dists
 * Returns the log PMF of the Generalized Linear Model (GLM)
 * with Binomial distribution and logit link function.
 * The idea is that binomial_logit_glm_lpmf(n | N, x, alpha, beta) should
 * compute a more efficient version of
 * binomial_logit_lpmf(n | N, alpha + x * beta) by using analytically
 * simplified gradients.
 * If containers are supplied, returns the log sum of the probabilities.
 *
 * @tparam T_n type of the vector of success counts;
 * this can also be a single integer value;
 * @tparam T_N type of the vector of population sizes;
 * this can also be a single integer value;
 * @tparam T_x type of the matrix of independent variables (features)
 * @tparam T_alpha type of the intercept(s);
 * this can be a vector (of the same length as n) of intercepts or a single
 * value (for models with constant intercept);
 * @tparam T_beta type of the weight vector
 *
 * @param n integer scalar or vector of success counts. If it is a scalar it
 * will be broadcast - used for all instances.
 * @param N integer scalar or vector of population sizes. If it is a scalar it
 * will be broadcast - used for all instances.
 * @param x design matrix or row vector. If it is a row vector it will be
 * broadcast - used for all instances.
 * @param alpha intercept
 * @param beta weight vector
 * @return log probability or log sum of probabilities
 * @throw std::domain_error if x, beta or alpha is infinite.
 * @throw std::domain_error if n is negative or greater than N
 * @throw std::domain_error if N is negative
 * @throw std::invalid_argument if container sizes mismatch.
 */
template <bool propto, typename T_n, typename T_N, typename T_x,
          typename T_alpha, typename T_beta, require_matrix_t<T_x>* = nullptr>
return_type_t<T_x, T_alpha, T_beta> binomial_logit_glm_lpmf(
    const T_n& n, const T_N& N, const T_x& x, const T_alpha& alpha,
    const T_beta& beta) {
  // Compile-time row count of x; 1 means a single row that is broadcast
  // across all instances.
  constexpr int T_x_rows = T_x::RowsAtCompileTime;
  using T_xbeta_partials = partials_return_t<T_x, T_beta>;
  using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>;
  // When x is a single row, x * beta collapses to a scalar; otherwise it is
  // a column of one value per instance.
  using T_xbeta_tmp =
      typename std::conditional_t<T_x_rows == 1, T_xbeta_partials,
                                  Eigen::Array<T_xbeta_partials, -1, 1>>;
  // Materialize expressions into concrete references only for arguments that
  // carry autodiff information (their partials are written back below).
  using T_n_ref = ref_type_if_t<!is_constant<T_n>::value, T_n>;
  using T_N_ref = ref_type_if_t<!is_constant<T_N>::value, T_N>;
  using T_x_ref = ref_type_if_t<!is_constant<T_x>::value, T_x>;
  using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>;
  using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>;

  T_n_ref n_ref = n;
  T_N_ref N_ref = N;
  T_x_ref x_ref = x;
  T_alpha_ref alpha_ref = alpha;
  T_beta_ref beta_ref = beta;

  // Empty containers contribute nothing to the log density.
  if (size_zero(n, N, alpha, beta, x)) {
    return 0;
  }

  // Dropping constants (propto) with no autodiff argument among x, alpha,
  // beta leaves no term to compute.
  if (!include_summand<propto, T_x, T_alpha, T_beta>::value) {
    return 0;
  }

  const size_t N_instances = max_size(n, N, x.col(0), alpha);
  const size_t N_attributes = x.cols();

  static const char* function = "binomial_logit_glm_lpmf";
  check_consistent_sizes(function, "Successes variable", n,
                         "Population size parameter", N);
  check_consistent_size(function, "Successes variable", n, N_instances);
  check_consistent_size(function, "Population size parameter", N, N_instances);
  check_consistent_size(function, "Weight vector", beta, N_attributes);
  check_consistent_size(function, "Vector of intercepts", alpha, N_instances);

  // Plain (value_of) counts as column arrays or scalars.
  auto&& n_val = as_value_column_array_or_scalar(n_ref);
  auto&& N_val = as_value_column_array_or_scalar(N_ref);

  // Validate counts: 0 <= n <= N and N >= 0.
  check_bounded(function, "Successes variable", n_val, 0, N_val);
  check_nonnegative(function, "Population size parameter", N_val);

  auto&& alpha_val = as_value_column_array_or_scalar(alpha_ref);
  auto&& beta_val = as_value_column_vector_or_scalar(beta_ref);
  auto&& x_val = value_of(x_ref);
  // Linear predictor theta = alpha + x * beta, one entry per instance.
  Eigen::Array<T_partials_return, -1, 1> theta(N_instances);
  if (T_x_rows == 1) {
    // Broadcast row: x * beta is a 1x1 product added to every alpha entry.
    theta = forward_as<T_xbeta_tmp>((x_val * beta_val)(0, 0)) + alpha_val;
  } else {
    theta = (x_val * beta_val).array() + alpha_val;
  }

  // log_inv_logit(theta) is reused in the gradient below, so force it into a
  // concrete temporary only when gradients will actually be computed.
  constexpr bool gradients_calc = !is_constant_all<T_beta, T_x, T_alpha>::value;
  auto&& log_inv_logit_theta = to_ref_if<gradients_calc>(log_inv_logit(theta));

  // logp = sum_i [ n_i * log(p_i) + (N_i - n_i) * log(1 - p_i) ],
  // p_i = inv_logit(theta_i).
  T_partials_return logp = sum(n_val * log_inv_logit_theta
                               + (N_val - n_val) * log1m_inv_logit(theta));

  // Run the (relatively expensive) per-argument finiteness checks only when
  // the accumulated result is already known to be non-finite.
  using std::isfinite;
  if (!isfinite(logp)) {
    check_finite(function, "Weight vector", beta);
    check_finite(function, "Intercept", alpha);
    check_finite(function, "Matrix of independent variables", x);
  }

  // log(N choose n) is constant in x, alpha, beta, so it is only added when
  // normalization constants are not being dropped.
  if (include_summand<propto, T_n, T_N>::value) {
    // When both n and N are scalars broadcast over N_instances rows, the
    // single binomial coefficient must be counted once per instance.
    size_t broadcast_n = max_size(N, n) == N_instances ? 1 : N_instances;
    logp += sum(binomial_coefficient_log(N_val, n_val)) * broadcast_n;
  }

  auto ops_partials = make_partials_propagator(x_ref, alpha_ref, beta_ref);
  if (gradients_calc) {
    // d logp / d theta_i = n_i - N_i * inv_logit(theta_i).
    Eigen::Matrix<T_partials_return, -1, 1> theta_derivative
        = n_val - N_val * exp(log_inv_logit_theta);

    if (!is_constant_all<T_beta>::value) {
      // d theta / d beta = x, so d logp / d beta = x^T * theta_derivative
      // (the broadcast-row case collapses the sum over instances first).
      if (T_x_rows == 1) {
        edge<2>(ops_partials).partials_
            = forward_as<Eigen::Matrix<T_partials_return, 1, -1>>(
                theta_derivative.sum() * x_val);
      } else {
        partials<2>(ops_partials) = x_val.transpose() * theta_derivative;
      }
    }

    if (!is_constant_all<T_x>::value) {
      // d theta / d x = beta, so the x partials are the outer product of
      // theta_derivative with beta (collapsed when x is a broadcast row).
      if (T_x_rows == 1) {
        edge<0>(ops_partials).partials_
            = forward_as<Eigen::Array<T_partials_return, -1, T_x_rows>>(
                beta_val * theta_derivative.sum());
      } else {
        edge<0>(ops_partials).partials_
            = (beta_val * theta_derivative.transpose()).transpose();
      }
    }
    if (!is_constant_all<T_alpha>::value) {
      // d theta / d alpha = 1, so the alpha gradient is theta_derivative.
      partials<1>(ops_partials) = theta_derivative;
    }
  }
  return ops_partials.build(logp);
}
161+
162+
/** \ingroup prob_dists
 * Convenience overload that delegates to the propto = false instantiation,
 * i.e. evaluates the fully normalized log PMF (no constant terms dropped).
 */
template <typename T_n, typename T_N, typename T_x, typename T_alpha,
          typename T_beta>
inline return_type_t<T_x, T_beta, T_alpha> binomial_logit_glm_lpmf(
    const T_n& n, const T_N& N, const T_x& x, const T_alpha& alpha,
    const T_beta& beta) {
  return binomial_logit_glm_lpmf<false>(n, N, x, alpha, beta);
}
169+
} // namespace math
170+
} // namespace stan
171+
#endif
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include <stan/math/mix.hpp>
2+
#include <test/unit/math/test_ad.hpp>
3+
4+
TEST(mathMixScalFun, binomial_logit_glm_lpmf) {
  // Bind the integer arguments so expect_ad only differentiates with respect
  // to x, alpha and beta.
  auto f = [](const auto n, const auto N) {
    return [=](const auto& x, const auto& alpha, const auto& beta) {
      return stan::math::binomial_logit_glm_lpmf(n, N, x, alpha, beta);
    };
  };

  std::vector<int> n_arr{1, 4};
  std::vector<int> N_arr{10, 45};
  Eigen::MatrixXd x = Eigen::MatrixXd::Random(2, 2);
  Eigen::RowVectorXd x_rowvec = x.row(0);
  Eigen::VectorXd alpha = Eigen::VectorXd::Random(2);
  Eigen::VectorXd beta = Eigen::VectorXd::Random(2);

  // Run every scalar/vector combination of n and N for one choice of
  // design matrix (full or broadcast row) and intercept (vector or scalar).
  auto expect_all_nN = [&](const auto& design, const auto& intercept) {
    stan::test::expect_ad(f(n_arr[0], N_arr[0]), design, intercept, beta);
    stan::test::expect_ad(f(n_arr[0], N_arr), design, intercept, beta);
    stan::test::expect_ad(f(n_arr, N_arr[0]), design, intercept, beta);
    stan::test::expect_ad(f(n_arr, N_arr), design, intercept, beta);
  };

  expect_all_nN(x, alpha);
  expect_all_nN(x, alpha[0]);
  expect_all_nN(x_rowvec, alpha);
  expect_all_nN(x_rowvec, alpha[0]);
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include <stan/math/prim.hpp>
2+
#include <gtest/gtest.h>
3+
4+
TEST(ProbBinomialLogitGLM, matchesNonGLM) {
  using stan::math::binomial_logit_glm_lpmf;
  using stan::math::binomial_logit_lpmf;

  std::vector<int> n{1, 2};
  std::vector<int> N{5, 4};
  Eigen::MatrixXd x = Eigen::MatrixXd::Random(2, 2);
  Eigen::RowVectorXd x_row = x.row(0);
  Eigen::VectorXd alpha = Eigen::VectorXd::Random(2);
  Eigen::VectorXd beta = Eigen::VectorXd::Random(2);

  // Check the GLM against the plain lpmf evaluated at the equivalent linear
  // predictor theta, for every scalar/vector combination of n and N.
  auto expect_match = [&](const auto& design, const auto& intercept,
                          const Eigen::VectorXd& theta) {
    EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N, theta),
                    binomial_logit_glm_lpmf(n, N, design, intercept, beta));
    EXPECT_FLOAT_EQ(binomial_logit_lpmf(n[0], N, theta),
                    binomial_logit_glm_lpmf(n[0], N, design, intercept, beta));
    EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N[0], theta),
                    binomial_logit_glm_lpmf(n, N[0], design, intercept, beta));
    EXPECT_FLOAT_EQ(
        binomial_logit_lpmf(n[0], N[0], theta),
        binomial_logit_glm_lpmf(n[0], N[0], design, intercept, beta));
  };

  // Vector intercept with full design matrix.
  expect_match(x, alpha, alpha + x * beta);
  // Scalar intercept broadcast over all instances.
  expect_match(x, alpha[0], (alpha[0] + (x * beta).array()).matrix());
  // Broadcast row-vector design with vector intercept.
  expect_match(x_row, alpha, (alpha.array() + (x_row * beta)(0, 0)).matrix());
}
48+
49+
TEST(ProbBinomialLogitGLM, throwsCorrectly) {
  using stan::math::binomial_logit_glm_lpmf;
  using stan::math::INFTY;

  std::vector<int> n{1, 2};
  std::vector<int> N{5, 4};
  Eigen::MatrixXd x = Eigen::MatrixXd::Random(2, 2);
  Eigen::VectorXd alpha = Eigen::VectorXd::Random(2);
  Eigen::VectorXd beta = Eigen::VectorXd::Random(2);

  // Size and bounds violations.
  std::vector<int> N_mismatch_size{5, 4, 10};
  EXPECT_THROW(binomial_logit_glm_lpmf(n, N_mismatch_size, x, alpha, beta),
               std::invalid_argument);
  EXPECT_THROW(binomial_logit_glm_lpmf(500, 1, x, alpha, beta),
               std::domain_error);
  EXPECT_THROW(binomial_logit_glm_lpmf(-10, N, x, alpha, beta),
               std::domain_error);
  EXPECT_THROW(binomial_logit_glm_lpmf(n, -10, x, alpha, beta),
               std::domain_error);

  // Non-finite parameter checks. The infinity must be written into the *_inf
  // copies, NOT the originals; the previous code mutated alpha, beta and x
  // themselves, leaving the *_inf copies finite, so every EXPECT_THROW below
  // fired because of a different (also-infinite) argument and the check under
  // test was never the one that threw.
  Eigen::VectorXd alpha_inf = alpha;
  alpha_inf[0] = INFTY;
  Eigen::VectorXd beta_inf = beta;
  beta_inf[0] = INFTY;
  Eigen::MatrixXd x_inf = x;
  x_inf(0, 0) = INFTY;

  EXPECT_THROW(binomial_logit_glm_lpmf(n, N, x_inf, alpha, beta),
               std::domain_error);
  EXPECT_THROW(binomial_logit_glm_lpmf(n, N, x, alpha_inf, beta),
               std::domain_error);
  EXPECT_THROW(binomial_logit_glm_lpmf(n, N, x, alpha, beta_inf),
               std::domain_error);
}

0 commit comments

Comments
 (0)