Skip to content

Commit dcc5240

Browse files
committed
Add implementation and tests
1 parent e5879f6 commit dcc5240

4 files changed

Lines changed: 292 additions & 0 deletions

File tree

stan/math/prim/prob.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include <stan/math/prim/prob/binomial_log.hpp>
4747
#include <stan/math/prim/prob/binomial_logit_log.hpp>
4848
#include <stan/math/prim/prob/binomial_logit_lpmf.hpp>
49+
#include <stan/math/prim/prob/binomial_logit_glm_lpmf.hpp>
4950
#include <stan/math/prim/prob/binomial_lpmf.hpp>
5051
#include <stan/math/prim/prob/binomial_rng.hpp>
5152
#include <stan/math/prim/prob/categorical_log.hpp>
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#ifndef STAN_MATH_PRIM_PROB_BINOMIAL_LOGIT_GLM_LPMF_HPP
2+
#define STAN_MATH_PRIM_PROB_BINOMIAL_LOGIT_GLM_LPMF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
#include <stan/math/prim/fun/isfinite.hpp>
8+
#include <stan/math/prim/fun/size.hpp>
9+
#include <stan/math/prim/fun/size_zero.hpp>
10+
#include <stan/math/prim/fun/max_size.hpp>
11+
#include <stan/math/prim/fun/max.hpp>
12+
#include <stan/math/prim/fun/to_ref.hpp>
13+
#include <stan/math/prim/fun/value_of.hpp>
14+
#include <stan/math/prim/fun/as_value_column_array_or_scalar.hpp>
15+
#include <stan/math/prim/fun/as_value_column_vector_or_scalar.hpp>
16+
17+
namespace stan {
18+
namespace math {
19+
20+
/** \ingroup multivar_dists
21+
* Returns the log PMF of the Generalized Linear Model (GLM)
22+
* with Binomial distribution and logit link function.
23+
* The idea is that binomial_logit_glm_lpmf(n | N, x, alpha, beta) should
24+
* compute a more efficient version of
25+
* binomial_logit_lpmf(y | N, alpha + x * beta) by using analytically
26+
* simplified gradients.
27+
* If containers are supplied, returns the log sum of the probabilities.
28+
*
29+
* @tparam T_y type of binary vector of successes variables;
30+
* this can also be a single binary value;
31+
* @tparam T_y type of binary vector of population size variables;
32+
* this can also be a single binary value;
33+
* @tparam T_x type of the matrix of independent variables (features)
34+
* @tparam T_alpha type of the intercept(s);
35+
* this can be a vector (of the same length as y) of intercepts or a single
36+
* value (for models with constant intercept);
37+
* @tparam T_beta type of the weight vector
38+
*
39+
* @param n binary scalar or vector parameter. If it is a scalar it will be
40+
* broadcast - used for all instances.
41+
* @param N binary scalar or vector parameter. If it is a scalar it will be
42+
* broadcast - used for all instances.
43+
* @param x design matrix or row vector. If it is a row vector it will be
44+
* broadcast - used for all instances.
45+
* @param alpha intercept (in log odds)
46+
* @param beta weight vector
47+
* @return log probability or log sum of probabilities
48+
* @throw std::domain_error if x, beta or alpha is infinite.
49+
* @throw std::domain_error if n is negative or greater than N
50+
* @throw std::domain_error if N is negative
51+
* @throw std::invalid_argument if container sizes mismatch.
52+
*/
53+
template <bool propto, typename T_n, typename T_N, typename T_x,
54+
typename T_alpha, typename T_beta, require_matrix_t<T_x>* = nullptr>
55+
return_type_t<T_x, T_alpha, T_beta> binomial_logit_glm_lpmf(
56+
const T_n& n, const T_N& N, const T_x& x,
57+
const T_alpha& alpha, const T_beta& beta) {
58+
constexpr int T_x_rows = T_x::RowsAtCompileTime;
59+
using T_xbeta_partials = partials_return_t<T_x, T_beta>;
60+
using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>;
61+
using T_theta_tmp =
62+
typename std::conditional_t<T_x_rows == 1, T_partials_return,
63+
Eigen::Array<T_partials_return, -1, 1>>;
64+
using T_xbeta_tmp =
65+
typename std::conditional_t<T_x_rows == 1, T_xbeta_partials,
66+
Eigen::Array<T_xbeta_partials, -1, 1>>;
67+
using T_n_ref = ref_type_if_t<!is_constant<T_n>::value, T_n>;
68+
using T_N_ref = ref_type_if_t<!is_constant<T_N>::value, T_N>;
69+
using T_x_ref = ref_type_if_t<!is_constant<T_x>::value, T_x>;
70+
using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>;
71+
using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>;
72+
73+
T_n_ref n_ref = n;
74+
T_N_ref N_ref = N;
75+
T_x_ref x_ref = x;
76+
T_alpha_ref alpha_ref = alpha;
77+
T_beta_ref beta_ref = beta;
78+
79+
if (size_zero(n, N, alpha, beta, x)) {
80+
return 0;
81+
}
82+
83+
if (!include_summand<propto, T_x, T_alpha, T_beta>::value) {
84+
return 0;
85+
}
86+
87+
const size_t N_instances = max_size(n, N, x.col(0), alpha);
88+
const size_t N_attributes = x.cols();
89+
90+
static const char* function = "binomial_logit_glm_lpmf";
91+
check_consistent_sizes(function, "Successes variable", n,
92+
"Population size parameter", N);
93+
check_consistent_size(function, "Successes variable", n, N_instances);
94+
check_consistent_size(function, "Population size parameter", N, N_instances);
95+
check_consistent_size(function, "Weight vector", beta, N_attributes);
96+
check_consistent_size(function, "Vector of intercepts", alpha, N_instances);
97+
98+
const auto& n_val = as_value_column_array_or_scalar(n_ref);
99+
const auto& N_val = as_value_column_array_or_scalar(N_ref);
100+
101+
check_bounded(function, "Successes variable", n_val, 0, N_val);
102+
check_nonnegative(function, "Population size parameter", N_val);
103+
104+
const auto& alpha_val = as_value_column_array_or_scalar(alpha_ref);
105+
const auto& beta_val = as_value_column_vector_or_scalar(beta_ref);
106+
const auto& x_val = value_of(x_ref);
107+
Eigen::Array<T_partials_return, -1, 1> theta(N_instances);
108+
if (T_x_rows == 1) {
109+
theta = forward_as<T_xbeta_tmp>((x_val * beta_val)(0, 0)) + alpha_val;
110+
} else {
111+
theta = (x_val * beta_val).array() + alpha_val;
112+
}
113+
114+
const auto& log_inv_logit_theta = log_inv_logit(theta);
115+
const auto& log1m_inv_logit_theta = log1m_inv_logit(theta);
116+
117+
T_partials_return logp = sum(n_val * log_inv_logit_theta
118+
+ (N_val - n_val) * log1m_inv_logit_theta);
119+
120+
using std::isfinite;
121+
if (!isfinite(logp)) {
122+
check_finite(function, "Weight vector", beta);
123+
check_finite(function, "Intercept", alpha);
124+
check_finite(function, "Matrix of independent variables", x);
125+
}
126+
127+
if (include_summand<propto, T_n, T_N>::value) {
128+
size_t broadcast_n = max_size(N, n) == N_instances ? 1 : N_instances;
129+
logp += sum(binomial_coefficient_log(N_val, n_val)) * broadcast_n;
130+
}
131+
132+
auto ops_partials = make_partials_propagator(x_ref, alpha_ref, beta_ref);
133+
134+
if (!is_constant_all<T_beta, T_x, T_alpha>::value) {
135+
Eigen::Matrix<T_partials_return, -1, 1> theta_derivative
136+
= n_val - N_val * exp(log_inv_logit_theta);
137+
138+
if (!is_constant_all<T_beta>::value) {
139+
if (T_x_rows == 1) {
140+
edge<2>(ops_partials).partials_
141+
= forward_as<Eigen::Matrix<T_partials_return, 1, -1>>(
142+
theta_derivative.sum() * x_val);
143+
} else {
144+
partials<2>(ops_partials) = x_val.transpose() * theta_derivative;
145+
}
146+
}
147+
148+
if (!is_constant_all<T_x>::value) {
149+
if (T_x_rows == 1) {
150+
edge<0>(ops_partials).partials_
151+
= forward_as<Eigen::Array<T_partials_return, -1, T_x_rows>>(
152+
beta_val * theta_derivative.sum());
153+
} else {
154+
edge<0>(ops_partials).partials_
155+
= (beta_val * theta_derivative.transpose()).transpose();
156+
}
157+
}
158+
if (!is_constant_all<T_alpha>::value) {
159+
partials<1>(ops_partials) = theta_derivative;
160+
}
161+
}
162+
return ops_partials.build(logp);
163+
}
164+
165+
template <typename T_n, typename T_N, typename T_x, typename T_alpha,
166+
typename T_beta>
167+
inline return_type_t<T_x, T_beta, T_alpha> binomial_logit_glm_lpmf(
168+
const T_n& n, const T_N& N, const T_x& x,
169+
const T_alpha& alpha, const T_beta& beta) {
170+
return binomial_logit_glm_lpmf<false>(n, N, x, alpha, beta);
171+
}
172+
} // namespace math
173+
} // namespace stan
174+
#endif
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include <stan/math/mix.hpp>
2+
#include <test/unit/math/test_ad.hpp>
3+
4+
TEST(mathMixScalFun, binomial_logit_glm_lpmf) {
5+
auto f = [](const auto n, const auto N) {
6+
return [=](const auto& x, const auto& alpha, const auto& beta) {
7+
return stan::math::binomial_logit_glm_lpmf(n, N, x, alpha, beta);
8+
};
9+
};
10+
11+
std::vector<int> n_arr{1, 4};
12+
std::vector<int> N_arr{10, 45};
13+
Eigen::MatrixXd x = Eigen::MatrixXd::Random(2, 2);
14+
Eigen::RowVectorXd x_rowvec = x.row(0);
15+
Eigen::VectorXd alpha = Eigen::VectorXd::Random(2);
16+
Eigen::VectorXd beta = Eigen::VectorXd::Random(2);
17+
18+
stan::test::expect_ad(f(n_arr[0], N_arr[0]), x, alpha, beta);
19+
stan::test::expect_ad(f(n_arr[0], N_arr), x, alpha, beta);
20+
stan::test::expect_ad(f(n_arr, N_arr[0]), x, alpha, beta);
21+
stan::test::expect_ad(f(n_arr, N_arr), x, alpha, beta);
22+
stan::test::expect_ad(f(n_arr[0], N_arr[0]), x, alpha[0], beta);
23+
stan::test::expect_ad(f(n_arr[0], N_arr), x, alpha[0], beta);
24+
stan::test::expect_ad(f(n_arr, N_arr[0]), x, alpha[0], beta);
25+
stan::test::expect_ad(f(n_arr, N_arr), x, alpha[0], beta);
26+
stan::test::expect_ad(f(n_arr[0], N_arr[0]), x_rowvec, alpha, beta);
27+
stan::test::expect_ad(f(n_arr[0], N_arr), x_rowvec, alpha, beta);
28+
stan::test::expect_ad(f(n_arr, N_arr[0]), x_rowvec, alpha, beta);
29+
stan::test::expect_ad(f(n_arr, N_arr), x_rowvec, alpha, beta);
30+
stan::test::expect_ad(f(n_arr[0], N_arr[0]), x_rowvec, alpha[0], beta);
31+
stan::test::expect_ad(f(n_arr[0], N_arr), x_rowvec, alpha[0], beta);
32+
stan::test::expect_ad(f(n_arr, N_arr[0]), x_rowvec, alpha[0], beta);
33+
stan::test::expect_ad(f(n_arr, N_arr), x_rowvec, alpha[0], beta);
34+
35+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include <stan/math/prim.hpp>
2+
#include <gtest/gtest.h>
3+
4+
TEST(ProbBinomialLogitGLM, matchesNonGLM) {
5+
using stan::math::binomial_logit_lpmf;
6+
using stan::math::binomial_logit_glm_lpmf;
7+
8+
std::vector<int> n{1, 2};
9+
std::vector<int> N{5, 4};
10+
Eigen::MatrixXd x = Eigen::MatrixXd::Random(2, 2);
11+
Eigen::RowVectorXd x_row = x.row(0);
12+
Eigen::VectorXd alpha = Eigen::VectorXd::Random(2);
13+
Eigen::VectorXd beta = Eigen::VectorXd::Random(2);
14+
15+
Eigen::VectorXd theta = alpha + x * beta;
16+
17+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N, theta),
18+
binomial_logit_glm_lpmf(n, N, x, alpha, beta));
19+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n[0], N, theta),
20+
binomial_logit_glm_lpmf(n[0], N, x, alpha, beta));
21+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N[0], theta),
22+
binomial_logit_glm_lpmf(n, N[0], x, alpha, beta));
23+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n[0], N[0], theta),
24+
binomial_logit_glm_lpmf(n[0], N[0], x, alpha, beta));
25+
26+
theta = (alpha[0] + (x * beta).array()).matrix();
27+
28+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N, theta),
29+
binomial_logit_glm_lpmf(n, N, x, alpha[0], beta));
30+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n[0], N, theta),
31+
binomial_logit_glm_lpmf(n[0], N, x, alpha[0], beta));
32+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N[0], theta),
33+
binomial_logit_glm_lpmf(n, N[0], x, alpha[0], beta));
34+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n[0], N[0], theta),
35+
binomial_logit_glm_lpmf(n[0], N[0], x, alpha[0], beta));
36+
37+
theta = (alpha.array() + (x_row * beta)(0, 0)).matrix();
38+
39+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N, theta),
40+
binomial_logit_glm_lpmf(n, N, x_row, alpha, beta));
41+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n[0], N, theta),
42+
binomial_logit_glm_lpmf(n[0], N, x_row, alpha, beta));
43+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n, N[0], theta),
44+
binomial_logit_glm_lpmf(n, N[0], x_row, alpha, beta));
45+
EXPECT_FLOAT_EQ(binomial_logit_lpmf(n[0], N[0], theta),
46+
binomial_logit_glm_lpmf(n[0], N[0], x_row, alpha, beta));
47+
}
48+
49+
TEST(ProbBinomialLogitGLM, throwsCorrectly) {
50+
using stan::math::binomial_logit_glm_lpmf;
51+
using stan::math::INFTY;
52+
53+
std::vector<int> n{1, 2};
54+
std::vector<int> N{5, 4};
55+
Eigen::MatrixXd x = Eigen::MatrixXd::Random(2, 2);
56+
Eigen::VectorXd alpha = Eigen::VectorXd::Random(2);
57+
Eigen::VectorXd beta = Eigen::VectorXd::Random(2);
58+
59+
std::vector<int> N_mismatch_size{5, 4, 10};
60+
EXPECT_THROW(binomial_logit_glm_lpmf(n, N_mismatch_size, x, alpha, beta),
61+
std::invalid_argument);
62+
EXPECT_THROW(binomial_logit_glm_lpmf(500, 1, x, alpha, beta),
63+
std::domain_error);
64+
EXPECT_THROW(binomial_logit_glm_lpmf(-10, N, x, alpha, beta),
65+
std::domain_error);
66+
EXPECT_THROW(binomial_logit_glm_lpmf(n, -10, x, alpha, beta),
67+
std::domain_error);
68+
69+
Eigen::VectorXd alpha_inf = alpha;
70+
alpha[0] = INFTY;
71+
Eigen::VectorXd beta_inf = beta;
72+
beta[0] = INFTY;
73+
Eigen::MatrixXd x_inf = x;
74+
x(0, 0) = INFTY;
75+
76+
EXPECT_THROW(binomial_logit_glm_lpmf(n, N, x_inf, alpha, beta),
77+
std::domain_error);
78+
EXPECT_THROW(binomial_logit_glm_lpmf(n, N, x, alpha_inf, beta),
79+
std::domain_error);
80+
EXPECT_THROW(binomial_logit_glm_lpmf(n, N, x, alpha, beta_inf),
81+
std::domain_error);
82+
}

0 commit comments

Comments
 (0)