-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict_qrf_fun.R
117 lines (104 loc) · 4.79 KB
/
predict_qrf_fun.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Code from Laura Poggio ISRIC from public release of updated Soil Grids paper
# DOI: https://doi.org/10.5194/soil-2020-65
# https://git.wur.nl/isric/soilgrids/soilgrids/-/blob/master/models/ranger/predict_qrf_fun.R
# Modified ranger func. so that it can compute the mean for quantile regression
# forests (QRF) as well (as opposed to just the median i.e. 50th quantile)
# We further modify the tweaked function so that it also works for OOB
# predictions, where num.trees differs from ncol(object$random.node.values.oob)
# (see the type = "treepred" branch below).
#' Predict with a ranger forest, optionally returning per-tree QRF values
#'
#' Variant of ranger's `predict.ranger()` that adds `type = "treepred"`: it
#' returns the quantile-regression-forest (QRF) node value of every sample in
#' every tree, so callers can derive statistics other than quantiles (e.g. the
#' mean) from the per-tree values. All other types are delegated to the saved
#' forest's `predict()` method exactly as ranger does.
#'
#' @param object A fitted ranger object with a saved forest. `"quantiles"` and
#'   `"treepred"` require `object$random.node.values` (set `quantreg = TRUE`);
#'   out-of-bag use additionally requires `object$random.node.values.oob`
#'   (set `keep.inbag = TRUE`).
#' @param data New data. If `NULL`, `"quantiles"`/`"treepred"` use the stored
#'   out-of-bag values; other types require `data`.
#' @param num.trees Use the first `num.trees` trees. For out-of-bag
#'   `"quantiles"`/`"treepred"` it is only recorded in the result, because the
#'   number of columns is fixed by `object$random.node.values.oob`.
#' @param type `"quantiles"`, `"treepred"`, or any other value, which is
#'   forwarded to the forest's `predict()` method.
#' @param quantiles Probabilities used when `type = "quantiles"`.
#' @param num.threads Number of threads, forwarded to ranger.
#' @param predict.all,se.method,seed,verbose,... Forwarded to the forest's
#'   `predict()` method for non-QRF types.
#' @return An object of class `"ranger.prediction"`. For `"quantiles"`,
#'   `predictions` is a samples x quantiles matrix with columns
#'   `"quantile= <q>"`; for `"treepred"` it is a samples x trees matrix with
#'   columns `tree_1`, `tree_2`, ...
predict.ranger.tree <- function(object, data = NULL, predict.all = FALSE,
                                num.trees = object$num.trees,
                                type = "response", se.method = "infjack",
                                quantiles = c(0.1, 0.5, 0.9),
                                seed = NULL, num.threads = 1,
                                verbose = TRUE, ...) {
  forest <- object$forest
  if (is.null(forest)) {
    stop("Error: No saved forest in ranger object. Please set write.forest to TRUE when calling ranger.")
  }
  # any() keeps this check from erroring when importance.mode is missing
  # (if (logical(0)) is an error in R).
  if (any(object$importance.mode %in% c("impurity_corrected", "impurity_unbiased"))) {
    warning("Forest was grown with 'impurity_corrected' variable importance. For prediction it is advised to grow another forest without this importance setting.")
  }

  if (type == "quantiles") {
    node.values <- .qrf_node_values(object, data, num.trees, num.threads)
    result <- .qrf_prediction(object, node.values, num.trees)
    ## Compute quantiles of each sample's distribution of node values
    result$predictions <- t(apply(node.values, 1, quantile, quantiles, na.rm = TRUE))
    if (nrow(result$predictions) != result$num.samples) {
      ## With a single quantile apply() returns a vector, so t() gave 1 x n
      result$predictions <- t(result$predictions)
    }
    colnames(result$predictions) <- paste("quantile=", quantiles)
    result
  } else if (type == "treepred") {
    node.values <- .qrf_node_values(object, data, num.trees, num.threads)
    result <- .qrf_prediction(object, node.values, num.trees)
    ## Single tree predictions. Name columns from the matrix itself: for OOB
    ## prediction its width is ncol(object$random.node.values.oob), which
    ## differs from num.trees.
    result$predictions <- node.values
    colnames(result$predictions) <- paste0("tree_", seq_len(ncol(node.values)))
    result
  } else {
    ## Non-quantile prediction
    if (is.null(data)) {
      stop("Error: Argument 'data' is required for non-quantile prediction.")
    }
    predict(forest, data, predict.all, num.trees, type, se.method, seed, num.threads, verbose, object$inbag.counts, ...)
  }
}

# Validate a QRF-capable ranger object and return the samples x trees matrix of
# random node values: the stored OOB values when `data` is NULL, otherwise the
# values of the terminal node each new sample reaches in the first
# `num.trees` trees.
.qrf_node_values <- function(object, data, num.trees, num.threads) {
  if (object$treetype != "Regression") {
    stop("Error: Quantile prediction implemented only for regression outcomes.")
  }
  if (is.null(object$random.node.values)) {
    stop("Error: Set quantreg=TRUE in ranger(...) for quantile prediction.")
  }
  if (is.null(data)) {
    ## OOB prediction
    if (is.null(object$random.node.values.oob)) {
      stop("Error: Set keep.inbag=TRUE in ranger(...) for out-of-bag quantile prediction or provide new data in predict(...).")
    }
    return(object$random.node.values.oob)
  }
  ## New data prediction
  if (num.trees < 1 || num.trees > object$num.trees) {
    stop("Error: num.trees must be between 1 and ", object$num.trees, ".")
  }
  # + 1 turns the terminal node IDs into row indices of random.node.values.
  terminal.nodes <- predict(object, data, type = "terminalNodes",
                            num.threads = num.threads)$predictions + 1
  # Allocate exactly num.trees columns. terminal.nodes covers every tree in
  # the forest, so copying its shape would leave unused trees as zero columns
  # that then leak into the quantiles.
  node.values <- matrix(NA_real_, nrow = nrow(terminal.nodes), ncol = num.trees)
  for (tree in seq_len(num.trees)) {
    node.values[, tree] <- object$random.node.values[terminal.nodes[, tree], tree]
  }
  node.values
}

# Build the "ranger.prediction" skeleton shared by the quantile and treepred
# types; the caller fills in `predictions`.
.qrf_prediction <- function(object, node.values, num.trees) {
  result <- list(num.samples = nrow(node.values),
                 treetype = object$treetype,
                 num.independent.variables = object$num.independent.variables,
                 num.trees = num.trees)
  class(result) <- "ranger.prediction"
  result
}