Skip to content

Commit bed77b6

Browse files
committed
unfinished work adding tests
1 parent e181bb2 commit bed77b6

5 files changed

Lines changed: 80 additions & 7 deletions

File tree

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@ Imports:
2525
reshape2 (>= 1.4.2),
2626
rmarkdown (>= 1.5)
2727
Suggests:
28-
knitr
28+
knitr,
29+
testthat
2930
VignetteBuilder: knitr
3031
RoxygenNote: 6.1.1
3132
URL: https://github.com/ModelOriented/randomForestExplainer

R/min_depth_interactions.R

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ min_depth_interactions_values <- function(forest, vars){
6868
mean_tree_depth <- dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>%
6969
dplyr::summarize_at(vars, funs(max(., na.rm = TRUE))) %>% as.data.frame()
7070
mean_tree_depth[mean_tree_depth == -Inf] <- NA
71-
mean_tree_depth <- colMeans(mean_tree_depth[, vars], na.rm = TRUE)
71+
mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE)
7272
min_depth_interactions_frame <-
7373
interactions_frame %>% dplyr::group_by(tree, `split var`) %>%
7474
dplyr::summarize_at(vars, funs(min(., na.rm = TRUE))) %>% as.data.frame()
@@ -93,7 +93,7 @@ min_depth_interactions_values_ranger <- function(forest, vars){
9393
mean_tree_depth <- dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>%
9494
dplyr::summarize_at(vars, funs(max(., na.rm = TRUE))) %>% as.data.frame()
9595
mean_tree_depth[mean_tree_depth == -Inf] <- NA
96-
mean_tree_depth <- colMeans(mean_tree_depth[, vars], na.rm = TRUE)
96+
mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE)
9797
min_depth_interactions_frame <-
9898
interactions_frame %>% dplyr::group_by(tree, splitvarName) %>%
9999
dplyr::summarize_at(vars, funs(min(., na.rm = TRUE))) %>% as.data.frame()
@@ -146,15 +146,15 @@ min_depth_interactions.randomForest <- function(forest, vars = important_variabl
146146
non_occurrences[, -1] <- forest$ntree - occurrences[, -1]
147147
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
148148
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
149-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/forest$ntree
149+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/forest$ntree
150150
} else if(mean_sample == "top_trees"){
151151
non_occurrences <- occurrences
152152
non_occurrences[, -1] <- forest$ntree - occurrences[, -1]
153153
minimum_non_occurrences <- min(non_occurrences[, -1])
154154
non_occurrences[, -1] <- non_occurrences[, -1] - minimum_non_occurrences
155155
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
156156
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
157-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/(forest$ntree - minimum_non_occurrences)
157+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/(forest$ntree - minimum_non_occurrences)
158158
}
159159
interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable")
160160
colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth")
@@ -195,15 +195,15 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea
195195
non_occurrences[, -1] <- forest$num.trees - occurrences[, -1]
196196
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
197197
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
198-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/forest$num.trees
198+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/forest$num.trees
199199
} else if(mean_sample == "top_trees"){
200200
non_occurrences <- occurrences
201201
non_occurrences[, -1] <- forest$num.trees - occurrences[, -1]
202202
minimum_non_occurrences <- min(non_occurrences[, -1])
203203
non_occurrences[, -1] <- non_occurrences[, -1] - minimum_non_occurrences
204204
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
205205
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
206-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/(forest$num.trees - minimum_non_occurrences)
206+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/(forest$num.trees - minimum_non_occurrences)
207207
}
208208
interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable")
209209
colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth")

tests/testthat.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
library(testthat)
2+
library(randomForestExplainer)
3+
4+
test_check("randomForestExplainer")

tests/testthat/test_randomForest.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
library(randomForest)
2+
library(dplyr)
3+
set.seed(12345)
4+
5+
context("Test randomForest classification forests")
6+
forest <- randomForest(Species ~ ., data = iris, localImp = TRUE, ntree = 2)
7+
8+
test_that("measure_importance works", {
9+
imp_df <- measure_importance(forest, mean_sample = "all_trees",
10+
measures = c("mean_min_depth","accuracy_decrease",
11+
"gini_decrease", "no_of_nodes", "times_a_root"))
12+
expect_equal(imp_df$variable, c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
13+
})
14+
15+
test_that("important_variables works", {
16+
imp_vars <- important_variables(forest, k = 3,
17+
measures = c("mean_min_depth","accuracy_decrease",
18+
"gini_decrease", "no_of_nodes", "times_a_root"))
19+
expect_equal(imp_vars, c("Petal.Width", "Petal.Length", "Sepal.Length"))
20+
})
21+
22+
test_that("min_depth_distribution works", {
23+
min_depth_dist <- min_depth_distribution(forest)
24+
print(min_depth_dist)
25+
expect_equivalent(arrange(min_depth_dist, tree, minimal_depth, variable),
26+
data.frame("tree" = c(1, 1, 1, 1, 2, 2, 2),
27+
"variable"=c("Petal.Width", "Sepal.Length", "Petal.Length", "Sepal.Width", "Petal.Width", "Sepal.Length", "Petal.Length"),
28+
"minimal_depth"=c(0, 1, 2, 4, 0, 1, 3)))
29+
})
30+
31+
test_that("min_depth_interactions works", {
32+
min_depth_int <- min_depth_interactions(forest, vars = c("Petal.Width"))
33+
expect_equal(as.character(min_depth_int$variable), c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
34+
})

tests/testthat/test_ranger.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
library(ranger)
2+
library(dplyr)
3+
set.seed(12345)
4+
5+
context("Test ranger classification forests")
6+
forest <- ranger(Species ~ ., data = iris, probability = TRUE, num.trees = 2, importance = "impurity")
7+
8+
test_that("measure_importance works", {
9+
imp_df <- measure_importance(forest, mean_sample = "all_trees",
10+
measures = c("mean_min_depth", "impurity",
11+
"no_of_nodes", "times_a_root"))
12+
expect_equal(imp_df$variable, c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
13+
})
14+
15+
test_that("important_variables works", {
16+
imp_vars <- important_variables(forest, k = 3,
17+
measures = c("mean_min_depth", "impurity",
18+
"no_of_nodes", "times_a_root"))
19+
expect_equal(imp_vars, c("Petal.Width", "Sepal.Length", "Petal.Length"))
20+
})
21+
22+
test_that("min_depth_distribution works", {
23+
min_depth_dist <- min_depth_distribution(forest)
24+
print(min_depth_dist)
25+
expect_equivalent(arrange(min_depth_dist, tree, minimal_depth, variable),
26+
data.frame("tree" = c(1, 1, 1, 2, 2, 2),
27+
"variable"=c("Petal.Width", "Sepal.Length", "Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"),
28+
"minimal_depth"=c(0, 2, 3, 0, 3, 3), stringsAsFactors = FALSE))
29+
})
30+
31+
test_that("min_depth_interactions works", {
32+
min_depth_int <- min_depth_interactions(forest, vars = c("Petal.Width"))
33+
expect_equal(as.character(min_depth_int$variable), c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
34+
})

0 commit comments

Comments (0)