Skip to content

Commit 17a6ec9

Browse files
committed
fixes
1 parent 4900882 commit 17a6ec9

5 files changed

Lines changed: 73 additions & 17 deletions

File tree

R/clean_names.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
# Clean names -------------------------------------------------------------
22

3-
43
#' @keywords internal
54
.clean_names_frequentist <- function(means, predict = NULL, info = NULL) {
6-
names(means)[names(means) == "emmean"] <- .guess_estimate_name(predict, info)
7-
names(means)[names(means) == "response"] <- .guess_estimate_name(predict, info)
5+
names(means)[names(means) == "emmean"] <- .guess_estimate_name(predict, info = info)
6+
names(means)[names(means) == "response"] <- .guess_estimate_name(predict, info = info)
87
names(means)[names(means) == "prob"] <- "Probability"
98
names(means)[names(means) == "estimate"] <- "Difference"
109
names(means)[names(means) == "odds.ratio"] <- "Odds_ratio"

R/format.R

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ format.marginaleffects_means <- function(x, model, ci = 0.95, ...) {
154154
}
155155
non_focal <- setdiff(colnames(model_data), attr(x, "focal_terms"))
156156
predict_type <- attributes(x)$predict
157+
transform <- attributes(x)$transform
157158

158159
# special attributes we get from "get_marginalcontrasts()"
159160
comparison <- list(...)$hypothesis
@@ -178,7 +179,7 @@ format.marginaleffects_means <- function(x, model, ci = 0.95, ...) {
178179
# for simple means, we don't want p-values
179180
remove_columns <- c(remove_columns, "p")
180181
# estimate name
181-
estimate_name <- .guess_estimate_name(predict_type, info)
182+
estimate_name <- .guess_estimate_name(predict_type, transform, info)
182183
}
183184

184185
# reshape and format columns
@@ -834,12 +835,10 @@ equivalence_columns <- c(
834835
# based on on which scale predictions were requested
835836

836837
#' @keywords internal
837-
.guess_estimate_name <- function(predict_type, info) {
838+
.guess_estimate_name <- function(predict_type, transform = NULL, info) {
838839
# estimate name
839840
if (is.null(predict_type)) {
840841
estimate_name <- "Mean"
841-
} else if (predict_type == "context") {
842-
estimate_name <- "Estimate"
843842
} else if (tolower(predict_type) %in% .brms_aux_elements()) {
844843
# for Bayesian models with distributional parameter
845844
estimate_name <- tools::toTitleCase(predict_type)
@@ -850,6 +849,19 @@ equivalence_columns <- c(
850849
# here we add all models that model the probability of an outcome, such as
851850
# binomial, multinomial, or Bernoulli models
852851
estimate_name <- "Probability"
852+
} else if (
853+
predict_type %in%
854+
c("none", "link") &&
855+
identical(transform, "exp") &&
856+
(info$is_binomial || info$is_bernoulli || info$is_multinomial)
857+
) {
858+
# here we add all models that have odds ratios as exponentiated coefficients
859+
estimate_name <- "Odds_Ratio"
860+
} else if (
861+
predict_type %in% c("none", "link") && identical(transform, "exp") && (info$is_count)
862+
) {
863+
# here we add all models that have IRRs as exponentiated coefficients
864+
estimate_name <- "IRR"
853865
} else if (predict_type == "survival" && info$is_survival) {
854866
# this is for survival models, where we want to predict the survival probability
855867
estimate_name <- "Probability"

R/get_contexteffects.R

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
# special contrasts: context effects ----------------------------------------
22
# ---------------------------------------------------------------------------
33

4-
.get_contexteffects <- function(model, my_args, model_info, ...) {
4+
.get_contexteffects <- function(
5+
model,
6+
my_args,
7+
predict = NULL,
8+
transform = NULL,
9+
model_info,
10+
...
11+
) {
512
if (model_info$is_linear) {
613
out <- marginaleffects::avg_comparisons(
714
model,
@@ -13,14 +20,17 @@
1320
dots <- list(...)
1421
fun_args <- list(model, variables = my_args$contrast, hypothesis = my_args$comparison)
1522
# set default for "type" argument, if not provided
16-
if (is.null(dots$type)) {
23+
if (is.null(predict)) {
1724
fun_args$type <- "link"
1825
# if "type" was not provided, also change transform argument. we do
1926
# this only when user did not provide "type", else - if user provided
2027
# "type" - we keep the default NULL
21-
if (is.null(dots$transform)) {
28+
if (is.null(transform)) {
2229
fun_args$transform <- "exp"
2330
}
31+
} else {
32+
fun_args$type <- predict
33+
fun_args$transform <- transform
2434
}
2535
out <- do.call(marginaleffects::avg_comparisons, c(fun_args, dots))
2636
}

R/get_marginalcontrasts.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,14 @@ get_marginalcontrasts <- function(
102102
)
103103
predict <- "response"
104104
} else if (isTRUE(my_args$context_effects)) {
105-
out <- .get_contexteffects(model, my_args, model_info, ...)
106-
predict <- "context"
105+
out <- .get_contexteffects(model, my_args, predict, transform, model_info, ...)
106+
# set defaults, for proper printing
107+
if (is.null(predict)) {
108+
predict <- "link"
109+
if (is.null(transform)) {
110+
transform <- "exp"
111+
}
112+
}
107113
} else if (compute_slopes) {
108114
# sanity check - contrast for slopes only makes sense when we have a "by" argument
109115
if (is.null(my_args$by)) {
@@ -159,6 +165,7 @@ get_marginalcontrasts <- function(
159165
info = list(
160166
contrast = my_args$contrast,
161167
predict = predict,
168+
transform = transform,
162169
comparison = my_args$comparison,
163170
estimate = estimate,
164171
p_adjust = p_adjust,

R/table_footer.R

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,23 @@
2222
if (isTRUE(info$joint_test)) {
2323
table_footer <- NULL
2424
} else {
25-
table_footer <- paste0("\nVariable predicted: ", toString(insight::find_response(model)))
25+
table_footer <- paste0(
26+
"\nVariable predicted: ",
27+
toString(insight::find_response(model))
28+
)
2629
}
2730

2831
# modulated predictors (focal terms) ---------------------------------------
2932

3033
if (!is.null(by) && !isTRUE(info$joint_test)) {
3134
modulate_string <- switch(type, inequality = , contrasts = "contrasted", "modulated")
32-
table_footer <- paste0(table_footer, "\nPredictors ", modulate_string, ": ", toString(by))
35+
table_footer <- paste0(
36+
table_footer,
37+
"\nPredictors ",
38+
modulate_string,
39+
": ",
40+
toString(by)
41+
)
3342
}
3443

3544
# predictors controlled (non-focal terms) ----------------------------------
@@ -49,7 +58,11 @@
4958
# over the list, because we may have different types of data
5059
for (av in seq_along(adjusted_values)) {
5160
if (is.numeric(adjusted_values[[av]])) {
52-
adjusted_for[av] <- sprintf("%s (%.2g)", adjusted_for[av], adjusted_values[[av]])
61+
adjusted_for[av] <- sprintf(
62+
"%s (%.2g)",
63+
adjusted_for[av],
64+
adjusted_values[[av]]
65+
)
5366
} else if (identical(type, "predictions")) {
5467
adjusted_for[av] <- sprintf("%s (%s)", adjusted_for[av], adjusted_values[[av]])
5568
}
@@ -87,7 +100,12 @@
87100

88101
# tell user about scale of predictions / contrasts -------------------------
89102

90-
result_type <- switch(type, inequality = "Differences", contrasts = "Contrasts", "Predictions")
103+
result_type <- switch(
104+
type,
105+
inequality = "Differences",
106+
contrasts = "Contrasts",
107+
"Predictions"
108+
)
91109

92110
if (!is.null(predict) && isFALSE(model_info$is_linear)) {
93111
# exceptions
@@ -99,6 +117,12 @@
99117
`invlink(link)` = "response",
100118
predict
101119
)
120+
## TODO: simplification, we just mention it is transformed; we could check
121+
## model info and then handle different cases, like odds ratios or IRRs etc.
122+
## See `.guess_estimate_name()`
123+
if (!is.null(transform)) {
124+
predict <- "transformed"
125+
}
102126
table_footer <- paste0(
103127
table_footer,
104128
"\n",
@@ -141,7 +165,11 @@
141165
hypothesis_labels <- unlist(
142166
lapply(parameter_names, function(i) {
143167
rows <- as.numeric(sub(".", "", i))
144-
paste0(i, " = ", toString(paste0(info$focal_terms, " [", transposed_dg[, rows], "]")))
168+
paste0(
169+
i,
170+
" = ",
171+
toString(paste0(info$focal_terms, " [", transposed_dg[, rows], "]"))
172+
)
145173
}),
146174
use.names = FALSE
147175
)

0 commit comments

Comments
 (0)