Skip to content

Commit

Permalink
support added to party models (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
talegari authored May 30, 2024
1 parent a09aa2e commit 7bae457
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 6 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: tidyrules
Type: Package
Title: Obtain Rules from Rule Based Models as Tidy Dataframe
Version: 0.2.1
Version: 0.2.2
Authors@R: c(
person("Srikanth", "Komala Sheshachala", email = "[email protected]", role = c("aut", "cre")),
person("Amith Kumar", "Ullur Raghavendra", email = "[email protected]", role = c("aut"))
Expand All @@ -18,6 +18,8 @@ Imports:
checkmate (>= 2.3.1),
tidytable (>= 0.11.0),
data.table (>= 1.14.6),
DescTools,
MetricsWeighted
Suggests:
AmesHousing (>= 0.0.3),
dplyr (>= 0.8),
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ S3method(predict,ruleset)
S3method(print,rulelist)
S3method(print,ruleset)
S3method(tidy,C5.0)
S3method(tidy,constparty)
S3method(tidy,cubist)
S3method(tidy,rpart)
export(convert_rule_flavor)
Expand All @@ -14,4 +15,6 @@ importFrom(data.table,":=")
importFrom(generics,tidy)
importFrom(magrittr,"%>%")
importFrom(rlang,"%||%")
importFrom(stats,IQR)
importFrom(stats,weighted.mean)
importFrom(utils,data)
10 changes: 9 additions & 1 deletion R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ utils::globalVariables(c(".",
"rn__",
"row_nbr",
"pref__",
"data"
"data",
"weight",
"response",
"terminal_node_id",
"sum_weight",
"prevalence",
"winning_response",
"average",
"RMSE"
)
)
2 changes: 2 additions & 0 deletions R/package.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#' @importFrom rlang %||%
#' @importFrom data.table :=
#' @importFrom utils data
#' @importFrom stats IQR
#' @importFrom stats weighted.mean
"_PACKAGE"

# Local handle to partykit's `.list.rules.party`, fetched by name from the
# partykit namespace (presumably because it is not exported -- confirm).
# `getFromNamespace` is qualified with `utils::` because only `data` is
# imported from utils by this package.
list.rules.party = utils::getFromNamespace(".list.rules.party", "partykit")
176 changes: 176 additions & 0 deletions R/party.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
################################################################################
# This is the part of the 'tidyrules' R package hosted at
# https://github.com/talegari/tidyrules with GPL-3 license.
################################################################################

#' @name tidy.constparty
#' @title Obtain rules as a ruleset/tidytable from a party model
#' @description Each row corresponds to a rule. A rule can be copied into
#'   `dplyr::filter` to filter the observations corresponding to a rule
#' @param x party model
#' @param ... Other arguments (currently unused)
#' @details These party models are supported: regression (y is numeric),
#'   classification (y is factor, including ordered factors). Any other
#'   response type raises an error. Variable names containing two or more
#'   consecutive spaces are rejected with an error.
#'
#'   Rules are sorted by decreasing confidence (classification) or by
#'   decreasing RMSE (regression) before `rule_nbr` is assigned.
#' @return A tidytable (with additional class `ruleset`) where each row
#'   corresponds to a rule. The columns are: rule_nbr, LHS, RHS, support,
#'   confidence (for classification only), lift (for classification only),
#'   IQR (for regression only), RMSE (for regression only) and
#'   terminal_node_id. The attributes `model_type` and `estimation_type` are
#'   set on the result.
#' @examples
#' model_party_cl = partykit::ctree(species ~ .,data = palmerpenguins::penguins)
#' model_party_cl
#' tidy(model_party_cl)
#'
#' model_party_re = partykit::ctree(bill_length_mm ~ .,
#'                                  data = palmerpenguins::penguins
#'                                  )
#' model_party_re
#' tidy(model_party_re)
#' @export

tidy.constparty = function(x, ...){

  ##### assertions and prep ####################################################

  # column names from the x: This will be used at the end to handle the
  # variables with a space
  col_names =
    attr(x$terms, which = "term.labels") %>%
    stringr::str_remove_all(pattern = "`")

  # throw error if there are consecutive spaces in the column names: the rule
  # strings go through str_squish() below, which collapses such runs and would
  # break the name matching done while finalizing the output
  if (any(stringr::str_detect(col_names, " {2,}"))){
    rlang::abort(
      "Variable names should not have two or more consecutive spaces.")
  }

  # detect method using 'fitted'
  fitted_df = tidytable::as_tidytable(x$fitted)
  colnames(fitted_df) = c("terminal_node_id", "weight", "response")
  fitted_df[["terminal_node_id"]] = as.character(fitted_df[["terminal_node_id"]])

  # class() may return more than one class (e.g. c("ordered", "factor")), so
  # test membership instead of comparing with `==` inside `if`
  y_class = class(fitted_df[["response"]])
  if ("factor" %in% y_class) {
    type = "classification"
  } else if (any(y_class %in% c("numeric", "integer"))) {
    type = "regression"
  } else {
    rlang::inform("tidy supports only classification and regression 'party' models")
    rlang::abort("Unsupported party object")
  }

  #### core extraction work ####################################################

  # extract rules: one string per terminal node, named by the node id
  raw_rules = list.rules.party(x)

  # normalise quoting, drop 'NA' entries, and wrap every condition of a rule
  # in parentheses: "a > 1 & b < 2" becomes "( a > 1 ) & ( b < 2 )"
  rules_df =
    raw_rules %>%
    stringr::str_replace_all(pattern = "\\\"", replacement = "'") %>%
    stringr::str_remove_all(pattern = ", 'NA'") %>%
    stringr::str_remove_all(pattern = "'NA',") %>%
    stringr::str_remove_all(pattern = "'NA'") %>%
    stringr::str_squish() %>%
    stringr::str_split(" & ") %>%
    purrr::map(~ stringr::str_c("( ", .x, " )")) %>%
    purrr::map_chr(~ stringr::str_c(.x, collapse = " & ")) %>%
    tidytable::tidytable(LHS = .) %>%
    tidytable::mutate(terminal_node_id = names(raw_rules))

  # create metrics df
  if (type == "classification"){

    # per terminal node: the response level carrying the largest total weight
    # (a single winner is kept when there are ties)
    terminal_response_df =
      fitted_df %>%
      tidytable::summarise(sum_weight = sum(weight, na.rm = TRUE),
                           .by = c(terminal_node_id, response)
                           ) %>%
      tidytable::slice_max(n = 1,
                           order_by = sum_weight,
                           by = terminal_node_id,
                           with_ties = FALSE
                           ) %>%
      tidytable::select(terminal_node_id,
                        winning_response = response
                        )

    # share of the total weight held by each response level
    prevalence_df =
      fitted_df %>%
      tidytable::summarise(prevalence = sum(weight, na.rm = TRUE),
                           .by = response
                           ) %>%
      tidytable::mutate(prevalence = prevalence / sum(prevalence)) %>%
      tidytable::select(response, prevalence)

    res =
      fitted_df %>%
      # bring 'winning_response' column
      tidytable::left_join(terminal_response_df,
                           by = "terminal_node_id"
                           ) %>%
      # bring 'prevalence' column
      tidytable::left_join(prevalence_df,
                           by = c("winning_response" = "response")
                           ) %>%
      # confidence: weighted share of the node that has the winning response;
      # lift: confidence relative to the prevalence of the winning response
      tidytable::summarise(
        support = sum(weight),
        confidence = weighted.mean(response == winning_response, weight, na.rm = TRUE),
        lift = weighted.mean(response == winning_response, weight, na.rm = TRUE) / prevalence[1],
        RHS = winning_response[1],
        .by = terminal_node_id
        ) %>%
      tidytable::left_join(rules_df, by = "terminal_node_id") %>%
      tidytable::arrange(tidytable::desc(confidence)) %>%
      # seq_len() stays empty for a zero-row result (1:0 would not)
      tidytable::mutate(., rule_nbr = seq_len(nrow(.))) %>%
      tidytable::select(rule_nbr, LHS, RHS,
                        support, confidence, lift,
                        terminal_node_id
                        )

  } else if (type == "regression"){

    res =
      fitted_df %>%
      # weighted mean response of each terminal node (the node's prediction)
      tidytable::mutate(average = weighted.mean(response, weight, na.rm = TRUE),
                        .by = terminal_node_id
                        ) %>%
      tidytable::summarise(
        support = sum(weight),
        IQR = DescTools::IQRw(response, weight, na.rm = TRUE),
        RMSE = MetricsWeighted::rmse(actual = response,
                                     predicted = average,
                                     w = weight,
                                     na.rm = TRUE
                                     ),
        average = mean(average),
        .by = terminal_node_id
        ) %>%
      tidytable::left_join(rules_df, by = "terminal_node_id") %>%
      tidytable::arrange(tidytable::desc(RMSE)) %>%
      # seq_len() stays empty for a zero-row result (1:0 would not)
      tidytable::mutate(., rule_nbr = seq_len(nrow(.))) %>%
      tidytable::select(rule_nbr, LHS, RHS = average,
                        support, IQR, RMSE,
                        terminal_node_id
                        )
  }

  #### finalize output #########################################################

  # replace variable names with spaces within backquotes
  # - seq_along() skips the loop when the model has no term labels
  # - fixed() matches the name literally, so characters such as '.' in a
  #   variable name are not interpreted as regex metacharacters
  for (i in seq_along(col_names)) {
    res[["LHS"]] =
      stringr::str_replace_all(res[["LHS"]],
                               stringr::fixed(col_names[i]),
                               addBackquotes(col_names[i])
                               )
  }

  #### return ##################################################################

  class(res) = c("ruleset", class(res))

  attr(res, "keys") = NULL
  attr(res, "model_type") = "constparty"
  attr(res, "estimation_type") = type

  return(res)
}
23 changes: 19 additions & 4 deletions R/rule_translators.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,34 @@ convert_rule_flavor = function(rule, flavor){
if (flavor == "python"){
res =
rule %>%
stringr::str_replace_all("\\( ", "") %>%
stringr::str_replace_all(" \\)", "") %>%

stringr::str_replace_all("%in%", "in") %>%
stringr::str_replace_all("c\\(", "[") %>%
stringr::str_replace_all("\\)", "]") %>%
stringr::str_replace_all("&", "and")

stringr::str_replace_all("&", " ) and (") %>%

stringr::str_c("( ", ., " )") %>%
stringr::str_squish()

} else if (flavor == "sql"){
res =
rule %>%
stringr::str_replace_all("==", "=") %>%
stringr::str_replace_all("\\( ", "") %>%
stringr::str_replace_all(" \\)", "") %>%

stringr::str_replace_all("%in%", "IN") %>%
stringr::str_replace_all("c\\(", "(") %>%
stringr::str_replace_all("&", "AND")
stringr::str_replace_all("c\\(", "[") %>%
stringr::str_replace_all("\\)", "]") %>%

stringr::str_replace_all("&", " ) AND (") %>%

stringr::str_c("( ", ., " )") %>%
stringr::str_squish()
}

attr(res, "flavor") = flavor
return(res)
}
37 changes: 37 additions & 0 deletions man/tidy.constparty.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions tests/testthat/test-party.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
################################################################################
# This is the part of the 'tidyrules' R package hosted at
# https://github.com/talegari/tidyrules with GPL-3 license.
################################################################################

# label this test file in testthat's output
context("test-party")

# setup some models ----
# load the 'penguins' dataset from the 'palmerpenguins' package; it is used
# both to fit the fixtures and as the data the extracted rules are applied to
data("penguins", package = "palmerpenguins")

# classification fixture: tree with 'species' as the response.
# The bare calls below are evaluated once at file level, outside any
# test_that() block.
model_party_cl = partykit::ctree(species ~ .,data = penguins)
model_party_cl
tidy(model_party_cl)

# regression fixture: tree with 'bill_length_mm' as the response
model_party_re = partykit::ctree(bill_length_mm ~ .,
data = penguins
)
model_party_re
tidy(model_party_re)

# Apply one rule to a dataset.
#
# rule : a single rule given as a string of R code (e.g. an 'LHS' entry)
# data : data frame to be filtered
# Returns the rows of 'data' for which the parsed rule evaluates to TRUE.
ruleFilterable = function(rule, data){
  filtered = dplyr::filter(.data = data, eval(parse(text = rule)))
  return(filtered)
}

# function to check whether all rules are filterable
#
# tr   : object with a character column/element 'LHS' holding one rule per entry
# data : data frame each rule is applied to
# Returns a logical vector with one element per rule (named by the rule):
# TRUE when filtering 'data' with the rule produced a data frame, FALSE when
# the rule could not be parsed or evaluated.
allRulesFilterable = function(tr, data){
  parse_status = vapply(
    tr[["LHS"]],
    function(arule){
      trydf = try(ruleFilterable(arule, data), silent = TRUE)
      filterable = inherits(trydf, "data.frame")
      # print rules that run but match no rows (debugging aid). The check is
      # guarded by 'filterable' because nrow() of a try-error is NULL, and
      # `if (NULL == 0)` would raise an error instead of reporting FALSE.
      if (filterable && nrow(trydf) == 0) print(arule)
      filterable
    },
    logical(1)
  )
  return(parse_status)
}

# test output type ----

# tidy() on both fixtures must return an object carrying the 'ruleset' S3 class
# (expect_s3_class() replaces expect_is(), which testthat's 3rd edition
# deprecates)
test_that("creates ruleset", {
  expect_s3_class(tidy(model_party_cl), "ruleset")
  expect_s3_class(tidy(model_party_re), "ruleset")
})

# test parsable ----
# every extracted rule must be usable as a filter expression on 'penguins'
test_that("rules are parsable", {
  status_cl = allRulesFilterable(tidy(model_party_cl), penguins)
  expect_true(all(status_cl))

  status_re = allRulesFilterable(tidy(model_party_re), penguins)
  expect_true(all(status_re))
})

0 comments on commit 7bae457

Please sign in to comment.