From b52adcb56462d459a34a35f649b662de23611402 Mon Sep 17 00:00:00 2001 From: Srikanth K S Date: Thu, 23 May 2024 13:02:37 +0530 Subject: [PATCH] tidy changes (#21) --- DESCRIPTION | 14 +- NAMESPACE | 13 +- R/c5.R | 284 ++++++++++++++++------------------ R/cubist.R | 269 +++++++++++++++----------------- R/generic.R | 23 +-- R/globals.R | 17 ++ R/package.R | 7 +- R/rpart.R | 291 +++++++++++------------------------ R/ruleclasses.R | 34 ++++ R/utils.R | 94 +++++------ R/varSpec.R | 63 ++++---- man/package_tidyrules.Rd | 1 + man/reexports.Rd | 16 ++ man/tidy.C5.0.Rd | 35 +++++ man/tidy.cubist.Rd | 35 +++++ man/tidy.rpart.Rd | 36 +++++ man/tidyRules.C5.0.Rd | 45 ------ man/tidyRules.Rd | 31 ---- man/tidyRules.cubist.Rd | 48 ------ man/tidyRules.rpart.Rd | 42 ----- man/varSpec.Rd | 17 +- tests/testthat/test-c5.R | 56 +++---- tests/testthat/test-cubist.R | 51 +++--- tests/testthat/test-rpart.R | 91 +++++------ 24 files changed, 720 insertions(+), 893 deletions(-) create mode 100644 R/globals.R create mode 100644 R/ruleclasses.R create mode 100644 man/reexports.Rd create mode 100644 man/tidy.C5.0.Rd create mode 100644 man/tidy.cubist.Rd create mode 100644 man/tidy.rpart.Rd delete mode 100644 man/tidyRules.C5.0.Rd delete mode 100644 man/tidyRules.Rd delete mode 100644 man/tidyRules.cubist.Rd delete mode 100644 man/tidyRules.rpart.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 8ef24e5..c0a7329 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: tidyrules Type: Package Title: Obtain Rules from Rule Based Models as Tidy Dataframe -Version: 0.1.5 +Version: 0.2.0 Authors@R: c( person("Srikanth", "Komala Sheshachala", email = "sri.teach@gmail.com", role = c("aut", "cre")), person("Amith Kumar", "Ullur Raghavendra", email = "amith54@gmail.com", role = c("aut")) @@ -9,12 +9,15 @@ Authors@R: c( Maintainer: Srikanth Komala Sheshachala Depends: R (>= 3.6.0), Imports: - tibble (>= 2.0.1), stringr (>= 1.3.1), magrittr (>= 1.5), purrr (>= 0.3.2), - assertthat (>= 0.2.0), partykit (>= 1.2.2), + rlang (>= 1.1.3), + generics (>= 0.1.3), + checkmate (>= 2.3.1), + tidytable (>= 0.11.0), + data.table (>= 1.14.6) Suggests: AmesHousing (>= 0.0.3), dplyr (>= 0.8), @@ -28,12 +31,11 @@ Suggests: mlbench (>= 2.1.1), knitr (>= 1.23), rmarkdown (>= 1.13), - pander (>= 0.6.3), -Description: Utility to convert text based summary of rule based models to a tidy dataframe (where each row represents a rule) with related metrics such as support, confidence and lift. Rule based models from these packages are supported: 'C5.0', 'rpart' and 'Cubist'. +Description: Utility to convert text based summary of rule based models to a rulelist or ruleset dataframe (where each row represents a rule) with related metrics such as support, confidence and lift. Rule based models from these packages are supported: 'C5.0', 'rpart' and 'Cubist'. URL: https://github.com/talegari/tidyrules BugReports: https://github.com/talegari/tidyrules/issues License: GPL-3 Encoding: UTF-8 LazyData: true -RoxygenNote: 7.1.0 +RoxygenNote: 7.3.1 VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index ad4a78e..253eed9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,8 +1,13 @@ # Generated by roxygen2: do not edit by hand -S3method(tidyRules,C5.0) -S3method(tidyRules,cubist) -S3method(tidyRules,rpart) -export(tidyRules) +S3method(print,rulelist) +S3method(print,ruleset) +S3method(tidy,C5.0) +S3method(tidy,cubist) +S3method(tidy,rpart) +export(tidy) export(varSpec) +importFrom(data.table,":=") +importFrom(generics,tidy) importFrom(magrittr,"%>%") +importFrom(rlang,"%||%") diff --git a/R/c5.R b/R/c5.R index 4888a5e..cb19ef5 100644 --- a/R/c5.R +++ b/R/c5.R @@ -3,15 +3,13 @@ # https://github.com/talegari/tidyrules with GPL-3 license. ################################################################################ -#' @name tidyRules.C5.0 -#' @title Obtain rules as a tidy tibble from a C5.0 model -#' @description Each row corresponds to a rule. A rule can be copied into -#' `dplyr::filter` to filter the observations corresponding to a rule -#' @author Srikanth KS, \email{sri.teach@@gmail.com} -#' @param object Fitted model object with rules +#' @name tidy.C5.0 +#' @title Obtain rules as rulelist/tiydtable from a C5.0 model +#' @description Each row corresponds to a rule per trial_nbr +#' @param x C5 model fitted with `rules = TRUE` #' @param ... Other arguments (See details) -#' @return A tibble where each row corresponds to a rule. The columns are: -#' support, confidence, lift, lhs, rhs, n_conditions +#' @return A rulelist/tidytable where each row corresponds to a rule. +#' The columns are: rule_nbr, trial_nbr, LHS, RHS, support, confidence, lift #' @details #' #' Optional named arguments: @@ -21,137 +19,106 @@ #' \item laplace(flag, default: TRUE) is supported. This computes confidence #' with laplace correction as documented under 'Rulesets' here: [C5 #' doc](https://www.rulequest.com/see5-unix.html). -#' -#' \item language (string, default: "r"): language where the rules are parsable. -#' The allowed options is one among: r, python, sql -#' #' } #' #' @examples -#' data("attrition", package = "modeldata") -#' attrition <- tibble::as_tibble(attrition) -#' c5_model <- C50::C5.0(Attrition ~., data = attrition, rules = TRUE) +#' c5_model = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) #' summary(c5_model) -#' tidyRules(c5_model) +#' tidy(c5_model) #' @export -tidyRules.C5.0 <- function(object, ...){ - - # evaluate ... - arguments = list(...) - if(is.null(arguments[["laplace"]])){ - arguments[["laplace"]] = TRUE - } - - # asserts for 'language' - if(is.null(arguments[["language"]])){ +tidy.C5.0 = function(x, ...){ - arguments[["language"]] = "r" + #### checks ################################################################# - } else { + arguments = list(...) + arguments[["laplace"]] = arguments[["laplace"]] %||% TRUE - assertthat::assert_that(assertthat::is.string(arguments[["language"]])) - arguments[["language"]] = stringr::str_to_lower(arguments[["language"]]) - assertthat::assert_that(arguments[["language"]] %in% c("r" - , "python" - , "sql" - ) - ) - } + # for magrittr dot + . = NULL - # for magrittr dot ---- - . <- NULL - - # check for a rule based model ---- - stopifnot(inherits(object, "C5.0")) - if(!object[["rbm"]]){ - stop( - stringr:: str_c( - "Unable to find rules in the C5.0 model." - , "Model should be built with rules using `rules = TRUE` argument." - ) - ) + if (!x[["rbm"]]){ + rlang::abort("Model should be built using `rules = TRUE` argument.") } - # output of the model ---- - output <- object[["output"]] + # output of the model + output = x[["output"]] - # get variable specification ---- - var_spec <- varSpec(object) - variable_names <- var_spec[["variable"]] - col_classes <- var_spec[["type"]] - names(col_classes) <- variable_names + # get variable specification + var_spec = varSpec(x) + variable_names = var_spec[["variable"]] + col_classes = var_spec[["type"]] + names(col_classes) = variable_names - # throw error if there is consecutive spaces ---- + # throw error if there is consecutive spaces # output from the model squishes the spaces - if(any(stringr::str_count(variable_names, " ") > 0)){ - stop("Variable names should not two or more consecutive spaces.") + if (any(stringr::str_count(variable_names, " ") > 0)){ + rlang::abort("Variable names should not two or more consecutive spaces.") } - # extract rules part ---- - spl <- output %>% - stringr::str_replace_all("\t", "") %>% # remove tab spaces + #### core logic ############################################################## + # extract rules part + spl = + output %>% + stringr::str_replace_all("\t", "") %>% # remove tab spaces stringr::str_replace_all("\n ", "") %>% # handle multiline lineitems - strSplitSingle("\n") # split along newlines + strSplitSingle("\n") # split along newlines # detect where the rules start - start_rules_position <- - which(stringr::str_detect(spl, "^Rule ")) %>% - min() + start_rules_position = min(which(stringr::str_detect(spl, "^Rule "))) # detect where the rules end - end_rules_position <- - which(stringr::str_detect(spl, "^Evaluation on training data")) %>% + end_rules_position = + stringr::str_detect(spl, "^Evaluation on training data") %>% + which() %>% magrittr::subtract(1) %>% min() # get the rules part - spl <- spl[start_rules_position:end_rules_position] %>% + spl = spl[start_rules_position:end_rules_position] %>% stringr::str_squish() %>% removeEmptyLines() - # get raw rules by splitting ---- + ## get raw rules by splitting # every rule starts with 'Rule' - cuts <- which(stringr::str_detect(spl, "^Rule ")) - + cuts = which(stringr::str_detect(spl, "^Rule ")) # end of rule is a line before the start of next rule - #cuts2 <- c(utils::tail(cuts, -1) - 1, length(spl)) - cuts2 <- which(stringr::str_detect(spl, "^\\-\\> ")) + cuts2 = which(stringr::str_detect(spl, "^\\-\\> ")) # split rules - rules_raw <- purrr::map2(cuts, cuts2, function(x, y) spl[x:y]) + rules_raw = purrr::map2(cuts, cuts2, function(x, y) spl[x:y]) - # function to get a parsable rule from a raw rule ---- - getRules <- function(single_raw_rule){ + ## function to get a parsable rule from a raw rule + getRules = function(single_raw_rule){ # empty list container - rule <- list() + rule = list() # get stats from first line ---- - first_line <- single_raw_rule[1] + first_line = single_raw_rule[1] # A typical first line looks like: #************************************************** # "Rule 0/1: (521/30, lift 1.1)", ":" #************************************************** - index <- strSplitSingle(first_line, ":") %>% + index = strSplitSingle(first_line, ":") %>% magrittr::extract(1) %>% strSplitSingle("\\s") %>% magrittr::extract(2) %>% strSplitSingle("/") - if(length(index) == 2){ - rule[["rule_number"]] <- as.integer(index[2]) - rule[["trial_number"]] <- as.integer(index[1]) + 1L + if (length(index) == 2){ + rule[["rule_number"]] = as.integer(index[2]) + rule[["trial_number"]] = as.integer(index[1]) + 1L } else { - rule[["rule_number"]] <- as.integer(index) - rule[["trial_number"]] <- 1L + rule[["rule_number"]] = as.integer(index) + rule[["trial_number"]] = 1L } - stats <- strSplitSingle(first_line, ":") %>% - #stringr::str_split("Rule 0/1: (521/30, lift 1.1)", ":") %>% + stats = + strSplitSingle(first_line, ":") %>% magrittr::extract(2) %>% strSplitSingle("\\(") %>% magrittr::extract(2) %>% @@ -160,20 +127,21 @@ tidyRules.C5.0 <- function(object, ...){ strSplitSingle(",") %>% stringr::str_squish() - support_confidence <- strSplitSingle(stats[1], "/") - if(length(support_confidence) > 1){ + support_confidence = strSplitSingle(stats[1], "/") + if (length(support_confidence) > 1){ # extract support - rule[["support"]] <- as.integer(support_confidence[1]) + rule[["support"]] = as.integer(support_confidence[1]) # compute confidence (not extract) - if(arguments[["laplace"]]){ + if (arguments[["laplace"]]){ # C5 doc computes confidence using laplace correction # (n-m+1)/(n+2) # n: number of obs in leaf # m: number of musclassifications among n - rule[["confidence"]] <- rule[["support"]] %>% + rule[["confidence"]] = + rule[["support"]] %>% magrittr::subtract(as.integer(support_confidence[2])) %>% magrittr::add(1) %>% magrittr::divide_by(rule[["support"]] + 2) %>% @@ -183,7 +151,8 @@ tidyRules.C5.0 <- function(object, ...){ # without laplace correction # simply: (n-m)/n - rule[["confidence"]] <- rule[["support"]] %>% + rule[["confidence"]] = + rule[["support"]] %>% magrittr::subtract(as.integer(support_confidence[2])) %>% magrittr::divide_by(rule[["support"]]) %>% round(4) @@ -191,37 +160,39 @@ tidyRules.C5.0 <- function(object, ...){ } else { - rule[["support"]] <- as.integer(support_confidence) + rule[["support"]] = as.integer(support_confidence) # see comments for laplace above - if(arguments[["laplace"]]){ + if (arguments[["laplace"]]){ rule[["confidence"]] = (rule[["support"]] + 1)/(rule[["support"]] + 2) } else{ rule[["confidence"]] = 1 } } - rule[["lift"]] <- strSplitSingle(stats[2], "\\s") %>% + rule[["lift"]] = + strSplitSingle(stats[2], "\\s") %>% magrittr::extract(2) %>% as.numeric() # curate a single line item of the rule ---- - line_item_curator <- function(line_item){ + line_item_curator = function(line_item){ # in unforeseen cases just return the rule string # let the parsing test catch it - line_item_rule <- line_item + line_item_rule = line_item # 'in' separator for a single line item of rule # ex1: JobInvolvement in [Low-Medium] for ordered factors # ex2: JobRole in {Laboratory_Technician, Sales_Representative} - if(stringr::str_detect(line_item, "\\sin\\s")){ - split_line_item <- strSplitSingle(line_item, "\\sin\\s") - lhs_line_item <- split_line_item[1] - rhs_line_item <- split_line_item[2] + if (stringr::str_detect(line_item, "\\sin\\s")){ + split_line_item = strSplitSingle(line_item, "\\sin\\s") + lhs_line_item = split_line_item[1] + rhs_line_item = split_line_item[2] # unordered factor case - if(stringr::str_detect(line_item, "\\{")){ - rhs_line_item <- rhs_line_item %>% + if (stringr::str_detect(line_item, "\\{")){ + rhs_line_item = + rhs_line_item %>% strHead(-1) %>% # remove quotes strTail(-1) %>% strSplitSingle(",") %>% # split the list by comma @@ -231,81 +202,88 @@ tidyRules.C5.0 <- function(object, ...){ stringr::str_c(collapse = ", ") %>% # bind with comma stringr::str_c("c(", ., ")") # create 'c' structure - line_item_rule <- stringr::str_c(lhs_line_item - , " %in% " - , rhs_line_item - ) + line_item_rule = stringr::str_c(lhs_line_item, + " %in% ", + rhs_line_item + ) } # unordered factor case - if(stringr::str_detect(line_item, "\\[")){ - rhs_line_item <- rhs_line_item %>% + if (stringr::str_detect(line_item, "\\[")){ + rhs_line_item = + rhs_line_item %>% strHead(-1) %>% strTail(-1) # more than one hyphen means some factor level has hyphen - if(stringr::str_count(rhs_line_item, "-") > 1){ - stop("factor levels cannot have '-'.") + if (stringr::str_count(rhs_line_item, "-") > 1){ + rlang::abort("factor levels cannot have '-'.") } - rhs_line_item <- rhs_line_item %>% + rhs_line_item = rhs_line_item %>% strSplitSingle("-") %>% stringr::str_squish() # in case there is space # get the levels of the variable - levels <- var_spec[var_spec[["variable"]] == lhs_line_item, ] %>% + levels = + var_spec[var_spec[["variable"]] == lhs_line_item, ] %>% as.list() %>% magrittr::extract2("levels") %>% magrittr::extract2(1) # get all levels between start and end level - start_level <- which(levels == rhs_line_item[1]) - end_level <- which(levels == rhs_line_item[2]) + start_level = which(levels == rhs_line_item[1]) + end_level = which(levels == rhs_line_item[2]) # construct RHS of the line item - rhs_line_item <- levels[start_level:end_level] %>% + rhs_line_item = + levels[start_level:end_level] %>% stringr::str_c("'", ., "'") %>% stringr::str_c(collapse = ", ") %>% stringr::str_c("c(", ., ")") # complete line rule - line_item_rule <- stringr::str_c(lhs_line_item - , " %in% " - , rhs_line_item - ) + line_item_rule = stringr::str_c(lhs_line_item, + " %in% ", + rhs_line_item + ) } } # handle '=' case # ex: MaritalStatus = Single - contains_equals <- stringr::str_detect(line_item, " = ") - if(contains_equals){ + contains_equals = stringr::str_detect(line_item, " = ") + if (contains_equals){ - sub_rule <- strSplitSingle(line_item, "=") %>% + sub_rule = + strSplitSingle(line_item, "=") %>% stringr::str_trim() - the_class <- col_classes[[ sub_rule[1] ]] + the_class = col_classes[[ sub_rule[1] ]] # quote if non-numeric - if(!(the_class %in% c("numeric", "integer"))){ - sub_rule[2] <- stringr::str_c("'", sub_rule[2], "'") + if (!(the_class %in% c("numeric", "integer"))){ + sub_rule[2] = stringr::str_c("'", sub_rule[2], "'") } - line_item_rule <- stringr::str_c(sub_rule, collapse = " == ") + line_item_rule = stringr::str_c(sub_rule, collapse = " == ") } + line_item_rule = paste0("( ", line_item_rule, " )") return(line_item_rule) } # create LHS and RHS ---- - rule[["LHS"]] <- single_raw_rule %>% + rule[["LHS"]] = + single_raw_rule %>% utils::tail(-1) %>% # remove first stats line utils::head(-1) %>% # remove the RHS line purrr::map(line_item_curator) %>% # get clean rule lines stringr::str_c(collapse = " & ") # concat them with '&' - rule[["RHS"]] <- single_raw_rule %>% + rule[["RHS"]] = + single_raw_rule %>% utils::tail(1) %>% # get the RHS line stringr::str_squish() %>% # remove multispaces strSplitSingle("\\s") %>% # split by space @@ -315,35 +293,37 @@ tidyRules.C5.0 <- function(object, ...){ return(rule) } - # apply rule tidying for each rule and return tibble ---- - res <- purrr::map(rules_raw, getRules) %>% + # apply rule tidying for each rule and return tibble + res = + purrr::map(rules_raw, getRules) %>% purrr::transpose() %>% purrr::simplify_all() %>% - tibble::as_tibble() - - # replace variable names with spaces within backquotes ---- - for(i in 1:length(variable_names)){ - res[["LHS"]] <- stringr::str_replace_all( - res[["LHS"]] - , variable_names[i] - , addBackquotes(variable_names[i]) - ) + tidytable::as_tidytable() + + #### finalize output ######################################################### + # replace variable names with spaces within backquotes + for (i in 1:length(variable_names)){ + res[["LHS"]] = + stringr::str_replace_all(res[["LHS"]], + variable_names[i], + addBackquotes(variable_names[i]) + ) } - # handle the rule parsable language - lang = arguments[["language"]] + #### return ################################################################## + res = + res %>% + tidytable::select(rule_nbr = rule_number, trial_nbr = trial_number, + LHS, RHS, + support, confidence, lift + ) - if (lang == "python"){ - res[["LHS"]] = ruleRToPython(res[["LHS"]]) - } else if (lang == "sql"){ - res[["LHS"]] = ruleRToSQL(res[["LHS"]]) - } + class(res) = c("rulelist", class(res)) + + attr(res, "keys") = "trial_nbr" + attr(res, "model_type") = "C5" + attr(res, "estimation_type") = "classification" - # return ---- - res <- tibble::rowid_to_column(res, "id") - res <- res[, c("id", "LHS", "RHS", "support", "confidence" - , "lift", "rule_number", "trial_number") - ] return(res) } diff --git a/R/cubist.R b/R/cubist.R index ff7fba9..30d3528 100644 --- a/R/cubist.R +++ b/R/cubist.R @@ -3,91 +3,67 @@ # https://github.com/talegari/tidyrules with GPL-3 license. ################################################################################ -#' @name tidyRules.cubist -#' @title Obtain rules as a tidy tibble from a cubist model -#' @description Each row corresponds to a rule. A rule can be copied into -#' `dplyr::filter` to filter the observations corresponding to a rule -#' @author Srikanth KS, \email{sri.teach@@gmail.com} -#' @param object Fitted model object with rules +#' @name tidy.cubist +#' @title Obtain rules as a ruleset/tidytable from a cubist model +#' @description Each row corresponds to a rule per committee. +#' @param x Cubist model #' @param ... Other arguments (currently unused) -#' @return A tibble where each row corresponds to a rule. The columns are: -#' support, mean, min, max, error, lhs, rhs and committee +#' @return A ruleset/tidytable where each row corresponds to a rule. The columns +#' are: rule_nbr, committee, LHS, RHS, support, mean, min, max, error #' @details When col_classes argument is missing, an educated guess is made #' about class by parsing the RHS of sub-rule. This might sometimes not lead #' to a parsable rule. -#' -#' Optional named arguments: -#' -#' \itemize{ -#' -#' \item language (string, default: "r"): language where the rules are -#' parsable. The allowed options is one among: r, python, sql -#' -#' } -#' #' @examples #' data("attrition", package = "modeldata") -#' attrition <- tibble::as_tibble(attrition) -#' cols_att <- setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) +#' cols_att = setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) #' -#' cb_att <- -#' Cubist::cubist(x = attrition[, cols_att],y = attrition[["MonthlyIncome"]]) -#' tr_att <- tidyRules(cb_att) -#' tr_att +#' cb_att = Cubist::cubist(x = attrition[, cols_att], +#' y = attrition[["MonthlyIncome"]] +#' ) +#' summary(cb_att) +#' tidy(cb_att) #' @export -tidyRules.cubist <- function(object, ...){ - - # asserts for 'language' - arguments = list(...) - if(is.null(arguments[["language"]])){ - - arguments[["language"]] = "r" - - } else { - - assertthat::assert_that(assertthat::is.string(arguments[["language"]])) - arguments[["language"]] = stringr::str_to_lower(arguments[["language"]]) - assertthat::assert_that(arguments[["language"]] %in% c("r", "python", "sql")) - } +tidy.cubist = function(x, ...){ - # output from the model ---- - output <- object[["output"]] + #### core rule extraction #################################################### + # output from the model + output = x[["output"]] - # get variable specification ---- - var_spec <- varSpec(object) - variable_names <- var_spec[["variable"]] - col_classes <- var_spec[["type"]] - names(col_classes) <- variable_names + # get variable specification + var_spec = varSpec(x) + variable_names = var_spec[["variable"]] + col_classes = var_spec[["type"]] + names(col_classes) = variable_names - # throw error if there is consecutive spaces ---- + # throw error if there is consecutive spaces # output from the model squishes the spaces if(any(stringr::str_count(variable_names, " ") > 0)){ - stop("Variable names should not two or more consecutive spaces.") + rlang::abort("Variable names should not two or more consecutive spaces.") } - variable_names_with_ <- stringr::str_replace_all(variable_names - , "\\s" - , "_" - ) + variable_names_with_ = + stringr::str_replace_all(variable_names, "\\s", "_") - # split by newline and remove emptylines ---- - lev_1 <- object[["output"]] %>% + # split by newline and remove emptylines + lev_1 = + x[["output"]] %>% strSplitSingle("\\n") %>% removeEmptyLines() - # remove everything from 'Evaluation on training data' onwards ---- - evalLine <- stringr::str_which(lev_1, "^Evaluation on training data") - lev_2 <- lev_1[-(evalLine:length(lev_1))] %>% + # remove everything from 'Evaluation on training data' onwards + evalLine = stringr::str_which(lev_1, "^Evaluation on training data") + lev_2 = + lev_1[-(evalLine:length(lev_1))] %>% stringr::str_subset("^(?!Model).*$") - # detect starts and ends of rules ---- - rule_starts <- stringr::str_which(stringr::str_trim(lev_2), "^Rule\\s") + # detect starts and ends of rules + rule_starts = stringr::str_which(stringr::str_trim(lev_2), "^Rule\\s") # end of a rule is a line before the next rule start - rule_ends <- c(utils::tail(rule_starts, -1) - 1, length(lev_2)) + rule_ends = c(utils::tail(rule_starts, -1) - 1, length(lev_2)) - # create a rule list for cubist ---- - get_rules_cubist <- function(single_raw_rule){ + # create a rule list for cubist + get_rules_cubist = function(single_raw_rule){ # a raw rule looks like this: # @@ -115,14 +91,14 @@ tidyRules.cubist <- function(object, ...){ # + 0.003 Garage_Cars + 0.003 Fireplaces + 0.07 Longitude # + 0.001 TotRms_AbvGrd - res <- list() + res = list() # locate the position of square bracket and collect stats - firstLine <- stringr::str_squish(single_raw_rule[1]) - openingSquareBracketPosition <- stringr::str_locate(firstLine, "\\[")[1, 1] + firstLine = stringr::str_squish(single_raw_rule[1]) + openingSquareBracketPosition = stringr::str_locate(firstLine, "\\[")[1, 1] # All stats are at the begining of the rule - stat <- + stat = # between square brackets stringr::str_sub(firstLine , openingSquareBracketPosition + 1 @@ -131,27 +107,27 @@ tidyRules.cubist <- function(object, ...){ strSplitSingle("\\,") %>% stringr::str_trim() - res[["support"]] <- stat[1] %>% + res[["support"]] = stat[1] %>% strSplitSingle("\\s") %>% magrittr::extract(1) %>% as.integer() - res[["mean"]] <- stat[2] %>% + res[["mean"]] = stat[2] %>% strSplitSingle(" ") %>% magrittr::extract(2) %>% as.numeric() - res[["min"]] <- stat[3] %>% + res[["min"]] = stat[3] %>% strSplitSingle(" ") %>% magrittr::extract(2) %>% as.numeric() - res[["max"]] <- stat[3] %>% + res[["max"]] = stat[3] %>% strSplitSingle(" ") %>% magrittr::extract(4) %>% as.numeric() - res[["error"]] <- stat[4] %>% + res[["error"]] = stat[4] %>% strSplitSingle(" ") %>% magrittr::extract(3) %>% as.numeric() @@ -159,15 +135,16 @@ tidyRules.cubist <- function(object, ...){ # is if-then missing (only outcome is there) if_exists = any(stringr::str_trim(single_raw_rule) == "if") - if(if_exists){ + if (if_exists){ # get LHS - btw_if_then <- seq( - which(stringr::str_trim(single_raw_rule) == "if") + 1 - , which(stringr::str_trim(single_raw_rule) == "then") - 1 - ) + btw_if_then = + seq(which(stringr::str_trim(single_raw_rule) == "if") + 1, + which(stringr::str_trim(single_raw_rule) == "then") - 1 + ) # unclean LHS strings, one condition per string - lhsStrings <- single_raw_rule[btw_if_then] %>% + lhsStrings = + single_raw_rule[btw_if_then] %>% stringr::str_replace_all("\\t", "\\\\n") %>% stringr::str_trim() %>% stringr::str_c(collapse = " ") %>% @@ -176,19 +153,20 @@ tidyRules.cubist <- function(object, ...){ stringr::str_trim() # function to get the one clean rule string - getRuleString <- function(string){ + getRuleString = function(string){ # to avoid CRAN notes - . <- NULL + . = NULL # if there is ' in {' in the string if(stringr::str_detect(string, "\\sin\\s\\{")){ # split with ' in {' - var_lvls <- strSplitSingle(string, "\\sin\\s\\{") + var_lvls = strSplitSingle(string, "\\sin\\s\\{") # get the contents inside curly braces - lvls <- var_lvls[2] %>% + lvls = + var_lvls[2] %>% # omit the closing curly bracket strHead(-1) %>% strSplitSingle(",") %>% @@ -198,31 +176,27 @@ tidyRules.cubist <- function(object, ...){ stringr::str_c("c(", ., ")") # get the variable - var <- var_lvls[1] %>% - stringr::str_trim() - - rs <- stringr::str_c(var, " %in% ", lvls) + var = stringr::str_trim(var_lvls[1]) + rs = stringr::str_c(var, " %in% ", lvls) } else { # handle '=' case - contains_equals <- stringr::str_detect(string, " = ") - if(contains_equals){ + contains_equals = stringr::str_detect(string, " = ") + + if (contains_equals){ - sub_rule <- strSplitSingle(string, "=") %>% + sub_rule = strSplitSingle(string, "=") %>% stringr::str_trim() if(!(col_classes[sub_rule[1]] == "numeric")){ - sub_rule[2] <- stringr::str_c("'", sub_rule[2], "'") + sub_rule[2] = stringr::str_c("'", sub_rule[2], "'") } - rs <- stringr::str_c(sub_rule, collapse = " == ") - + rs = stringr::str_c(sub_rule, collapse = " == ") } else { - # nothing to do - rs <- string - + rs = string } } # end of handle '=' case @@ -232,7 +206,9 @@ tidyRules.cubist <- function(object, ...){ } # clean up LHS as string - res[["LHS"]] <- purrr::map_chr(lhsStrings, getRuleString) %>% + res[["LHS"]] = + purrr::map_chr(lhsStrings, getRuleString) %>% + stringr::str_c("( ", ., " )") %>% stringr::str_c(collapse = " & ") # note spaces next to AND } else { @@ -242,22 +218,23 @@ tidyRules.cubist <- function(object, ...){ # get RHS # then might not exist: still retaining old name 'afterThen' - if(if_exists){ - afterThen <- seq(which(stringr::str_trim(single_raw_rule) == "then") + 1 - , length(single_raw_rule) - ) + if (if_exists){ + afterThen = seq(which(stringr::str_trim(single_raw_rule) == "then") + 1, + length(single_raw_rule) + ) } else { - afterThen <- seq( - which(stringr::str_detect(stringr::str_trim(single_raw_rule) - , "^outcome" + afterThen = seq( + which(stringr::str_detect(stringr::str_trim(single_raw_rule), + "^outcome" ) - ) - , length(single_raw_rule) + ), + length(single_raw_rule) ) } # handle brackets around signs - res[["RHS"]] <- single_raw_rule[afterThen] %>% + res[["RHS"]] = + single_raw_rule[afterThen] %>% stringr::str_replace_all("\\t", "") %>% stringr::str_trim() %>% stringr::str_c(collapse = " ") %>% @@ -272,15 +249,16 @@ tidyRules.cubist <- function(object, ...){ stringr::str_replace_all("\\-\\-", ") - (") # quotes aroud each addenum - res[["RHS"]] <- stringr::str_c("(", res[["RHS"]], ")") %>% + res[["RHS"]] = + stringr::str_c("(", res[["RHS"]], ")") %>% # honour negative intercept stringr::str_replace("\\(\\)\\s\\-\\s\\(", "(-") return(res) } - # see if rules have commitees and create commitees vector ---- - rule_number_splits <- + # see if rules have commitees and create commitees vector + rule_number_splits = stringr::str_split(stringr::str_trim(lev_2)[rule_starts], ":") %>% purrr::map_chr(function(x) x[[1]]) %>% stringr::str_split("\\s") %>% @@ -289,56 +267,55 @@ tidyRules.cubist <- function(object, ...){ simplify2array() %>% as.integer() - if(length(rule_number_splits) > length(rule_starts)){ - committees <- rule_number_splits[seq(1 - , by = 2 - , length.out = length(rule_starts) - )] + if (length(rule_number_splits) > length(rule_starts)){ + committees = + rule_number_splits[seq(1, by = 2, length.out = length(rule_starts))] } else { - committees <- rep(1L, length(rule_starts)) + committees = rep(1L, length(rule_starts)) } - # create parsable rules from raw rules ---- - res <- - purrr::map(1:length(rule_starts) - , function(i) lev_2[rule_starts[i]:rule_ends[i]] + # create parsable rules from raw rules + res = + purrr::map(1:length(rule_starts), + function(i) lev_2[rule_starts[i]:rule_ends[i]] ) %>% purrr::map(get_rules_cubist) %>% purrr::transpose() %>% purrr::map(unlist) %>% - tibble::as_tibble() - - # replace variable names with spaces within backquotes ---- - for(i in 1:length(variable_names)){ - res[["LHS"]] <- stringr::str_replace_all( - res[["LHS"]] - , variable_names[i] - , addBackquotes(variable_names[i]) - ) - - res[["RHS"]] <- stringr::str_replace_all( - res[["RHS"]] - , variable_names_with_[i] - , addBackquotes(variable_names[i]) - ) + tidytable::as_tidytable() + + #### prepare and return ###################################################### + # replace variable names with spaces within backquotes + for (i in 1:length(variable_names)){ + res[["LHS"]] = + stringr::str_replace_all(res[["LHS"]], + variable_names[i], + addBackquotes(variable_names[i]) + ) + + res[["RHS"]] = stringr::str_replace_all(res[["RHS"]], + variable_names_with_[i], + addBackquotes(variable_names[i]) + ) } - # handle the rule parsable language - lang = arguments[["language"]] - - if (lang == "python"){ - res[["LHS"]] = ruleRToPython(res[["LHS"]]) - } else if (lang == "sql"){ - res[["LHS"]] = ruleRToSQL(res[["LHS"]]) - } + res = + res %>% + tidytable::mutate(committee = local(committees)) %>% + tidytable::arrange(desc(support), .by = committee) %>% + tidytable::mutate(rule_nbr = tidytable::row_number(), .by = committee) - # prepare and return ---- - res[["committee"]] <- committees - res <- tibble::rowid_to_column(res, "id") - res <- res[, c("id", "LHS", "RHS", "support" - , "mean", "min", "max", "error", "committee" - ) + res = res[, c("rule_nbr", "committee", + "LHS", "RHS", + "support", "mean", "min", "max", "error" + ) ] + class(res) = c("ruleset", class(res)) + + attr(res, "keys") = "committee" + attr(res, "model_type") = "cubist" + attr(res, "estimation_type") = "regression" + return(res) } diff --git a/R/generic.R b/R/generic.R index 5a7dadd..08b6add 100644 --- a/R/generic.R +++ b/R/generic.R @@ -3,22 +3,9 @@ # https://github.com/talegari/tidyrules with GPL-3 license. ################################################################################ -#' @name tidyRules -#' @title Obtain rules as a tidy tibble -#' @description Each row corresponds to a rule. A rule can be copied into -#' `dplyr::filter` to filter the observations corresponding to a rule -#' @details tidyRule supports these rule based models: C5, Cubist and rpart. -#' @author Srikanth KS, \email{sri.teach@@gmail.com} -#' @param object Fitted model object with rules -#' @param ... Other arguments (currently unused) -#' @param col_classes Named list or a named character vector of column classes. -#' Column names of the data used for modeling form the names and the -#' respective classes for the value. One way of obtaining this is by running -#' `lapply(data, class)`. -#' @return A tibble where each row corresponds to a rule -#' @export -tidyRules <- function(object, col_classes = NULL, ...){ - - UseMethod("tidyRules", object) +# dev: generic 'tidy' is now imported from 'generics' package +# 'tidyRules' generic is no longer supported. -} +#' @importFrom generics tidy +#' @export +generics::tidy diff --git a/R/globals.R b/R/globals.R new file mode 100644 index 0000000..f367917 --- /dev/null +++ b/R/globals.R @@ -0,0 +1,17 @@ +utils::globalVariables(c(".", + "LHS", + "RHS", + "committee", + "desc", + "dev", + "lift", + "n", + "predict_class", + "rule_nbr", + "rule_number", + "support", + "trial_number", + "yval", + "confidence" + ) + ) \ No newline at end of file diff --git a/R/package.R b/R/package.R index b1db90d..14e139a 100644 --- a/R/package.R +++ b/R/package.R @@ -6,10 +6,9 @@ #' @name package_tidyrules #' @title About 'tidyrules' package #' @description Obtain rules as tidy dataframes - #' @importFrom magrittr %>% +#' @importFrom rlang %||% +#' @importFrom data.table := "_PACKAGE" -is.integerish <- getFromNamespace("is.integerish", "assertthat") - -list.rules.party <- getFromNamespace(".list.rules.party", "partykit") +list.rules.party = getFromNamespace(".list.rules.party", "partykit") diff --git a/R/rpart.R b/R/rpart.R index 873ed48..463323f 100644 --- a/R/rpart.R +++ b/R/rpart.R @@ -3,238 +3,129 @@ # https://github.com/talegari/tidyrules with GPL-3 license. ################################################################################ -#' @name tidyRules.rpart -#' @title Obtain rules as a tidy tibble from a rpart model +#' @name tidy.rpart +#' @title Obtain rules as a ruleset/tidytable from a rpart model #' @description Each row corresponds to a rule. A rule can be copied into #' `dplyr::filter` to filter the observations corresponding to a rule -#' @author Amith Kumar U R, \email{amith54@@gmail.com} -#' @param object Fitted model object with rules +#' @param x rpart model #' @param ... Other arguments (currently unused) #' @details NOTE: For rpart rules, one should build the model without #' \bold{ordered factor} variable. We recommend you to convert \bold{ordered #' factor} to \bold{factor} or \bold{integer} class. -#' -#' Optional named arguments: -#' -#' \itemize{ -#' -#' \item language (string, default: "r"): language where the rules are parsable. -#' The allowed options is one among: r, python, sql -#' -#' } -#' -#' @return A tibble where each row corresponds to a rule. The columns are: -#' support, confidence, lift, LHS, RHS +#' @return A tidytable where each row corresponds to a rule. The columns are: +#' rule_nbr, LHS, RHS, support, confidence (for classification only), lift +#' (for classification only) #' @examples -#' iris_rpart <- rpart::rpart(Species ~ .,data = iris) -#' tidyRules(iris_rpart) +#' rpart_class = rpart::rpart(Species ~ .,data = iris) +#' rpart_class +#' tidy(rpart_class) +#' +#' rpart_regr = rpart::rpart(Sepal.Length ~ .,data = iris) +#' rpart_regr +#' tidy(rpart_regr) #' @export -tidyRules.rpart <- function(object, ...){ - # asserts for 'language' - arguments = list(...) - if(is.null(arguments[["language"]])){ +tidy.rpart = function(x, ...){ - arguments[["language"]] = "r" - - } else { - - assertthat::assert_that(assertthat::is.string(arguments[["language"]])) - arguments[["language"]] = stringr::str_to_lower(arguments[["language"]]) - assertthat::assert_that(arguments[["language"]] %in% c("r", "python", "sql")) - } - - # check for rpart object - stopifnot(inherits(object, "rpart")) + ##### assertions and prep #################################################### + arguments = list(...) - if(object$method == "class"){ - # Stop if only root node is present in the object - if(nrow(object$frame) == 1){ - stop(paste0("Only root is present in the rpart object" + # supported 'rpart' classes + method_rpart = x$method + # classification: class, regression: anova + checkmate::assert_choice(method_rpart, c("class", "anova")) - ) - ) + # build with y = TRUE + if (is.null(x$y)){ + rlang::abort("rpart model should be built using argument `y = TRUE`.") } - # Stop if any ordered factor is present: - # partykit doesn't handle the ordered factors while processing rules. - if(sum(object$ordered) > 0){ - stop(paste0("Ordered variables detected!!" - , "convert ordered variables" - , " to factor or numberic before model fit")) - } + # column names from the x: This will be used at the end to handle the + # variables with a space + col_names = + attr(x$terms, which = "term.labels") %>% + stringr::str_remove_all(pattern = "`") - if(is.null(object$y)){ - stop( - stringr::str_c( - "Unable to find target variable in the model object!! " - , "Model should be built using argument `y = TRUE`." - ) - ) + # throw error if there are consecutive spaces in the column names + if (any(stringr::str_count(col_names, " ") > 0)){ + rlang::abort( + "Variable names should not have two or more consecutive spaces.") } - # column names from the object: This will be used at the end to handle the - # variables with a space. - col_names <- stringr::str_remove_all(attr(object$terms,which = "term.labels") - , pattern = "`") - - # throw error if there are consecutive spaces in the column names ---- - if(any(stringr::str_count(col_names, " ") > 0)){ - stop("Variable names should not have two or more consecutive spaces.") - } + #### core extraction work #################################################### # convert to class "party" - party_obj <- partykit::as.party(object) + party_obj = partykit::as.party(x) - # extracting rules - rules <- list.rules.party(party_obj) %>% + # extract rules + rules = + list.rules.party(party_obj) %>% stringr::str_replace_all(pattern = "\\\"","'") %>% stringr::str_remove_all(pattern = ", 'NA'") %>% stringr::str_remove_all(pattern = "'NA',") %>% stringr::str_remove_all(pattern = "'NA'") %>% - stringr::str_squish() - - # terminal nodes from party object - terminal_nodes <- partykit::nodeids(party_obj, terminal = T) - - # extract metrics from rpart object - metrics <- object$frame[terminal_nodes,c("n","dev","yval")] - metrics$confidence <- (metrics$n + 1 - metrics$dev) / (metrics$n + 2) - - metrics <- metrics[,c("n","yval","confidence")] %>% - magrittr::set_colnames(c("support","predict_class","confidence")) - - # prevelance for lift calculation - prevelance <- object$y %>% - table() %>% - prop.table() %>% - as.numeric() - - # Actual labels for RHS - metrics$RHS <- attr(object, "ylevels")[metrics$predict_class] + stringr::str_squish() %>% + stringr::str_split(" & ") %>% + purrr::map(~ stringr::str_c("( ", .x, " )")) %>% + purrr::map_chr(~ stringr::str_c(.x, collapse = " & ")) + + terminal_nodes = partykit::nodeids(party_obj, terminal = TRUE) + + # create metrics df + if (method_rpart == "class"){ + prevalence = as.numeric(prop.table(table(x$y))) + + res = + x$frame[terminal_nodes, c("n", "dev", "yval")] %>% + tidytable::mutate(confidence = (n + 1 - dev) / (n + 2)) %>% + tidytable::rename(support = n, predict_class = yval) %>% + tidytable::mutate(RHS = attr(x, "ylevels")[predict_class]) %>% + tidytable::mutate(prevalence = prevalence[predict_class]) %>% + tidytable::mutate(lift = confidence / prevalence) %>% + tidytable::mutate(LHS = rules) + + } else if (method_rpart == "anova"){ + res = + x$frame[terminal_nodes, c("n","yval")] %>% + tidytable::rename(support = n, RHS = yval) %>% + tidytable::mutate(LHS = rules) + } - metrics$prevelance <- prevelance[metrics$predict_class] + #### finalize output ######################################################### - metrics$lift <- metrics$confidence / metrics$prevelance + # replace variable names with spaces within backquotes + for (i in 1:length(col_names)) { + res[["LHS"]] = + stringr::str_replace_all(res[["LHS"]], + col_names[i], + addBackquotes(col_names[i]) + ) + } - metrics$LHS <- rules + #### return ################################################################## + res[["rule_nbr"]] = 1:nrow(res) - tidy_rules <- metrics + if (method_rpart == "class"){ + res = + res %>% + tidytable::select(rule_nbr, LHS, RHS, support, confidence, lift) - # replace variable names with spaces within backquotes ---- - for(i in 1:length(col_names)){ - tidy_rules[["LHS"]] <- stringr::str_replace_all( - tidy_rules[["LHS"]] - , col_names[i] - , addBackquotes(col_names[i]) - ) + } else if (method_rpart == "anova") { + res = + res %>% + tidytable::select(rule_nbr, LHS, RHS, support) } - # return ---- - tidy_rules <- tibble::rowid_to_column(tidy_rules, "id") - tidy_rules <- tidy_rules[, c("id" - , "LHS" - , "RHS" - , "support" - , "confidence" - , "lift") - ] %>% - tibble::as_tibble() - - # handle the rule parsable language - lang = arguments[["language"]] - - if (lang == "python"){ - res[["LHS"]] = ruleRToPython(res[["LHS"]]) - } else if (lang == "sql"){ - res[["LHS"]] = ruleRToSQL(res[["LHS"]]) - } + class(res) = c("ruleset", class(res)) - return(tidy_rules) - - } else { - # Stop if only root node is present in the object - if(nrow(object$frame) == 1){ - stop(paste0("Only root is present in the rpart object" - - ) - ) - } - - # Stop if any ordered factor is present: - # partykit doesn't handle the ordered factors while processing rules. - if(sum(object$ordered) > 0){ - stop(paste0("Ordered variables detected!!" - , "convert ordered variables" - , " to factor or numberic before model fit")) - } - - # column names from the object: This will be used at the end to handle the - # variables with a space. - col_names <- stringr::str_remove_all(attr(object$terms,which = "term.labels") - , pattern = "`") - - # throw error if there are consecutive spaces in the column names ---- - if(any(stringr::str_count(col_names, " ") > 0)){ - stop("Variable names should not have two or more consecutive spaces.") - } - - # convert to class "party" - party_obj <- partykit::as.party(object) - - # extracting rules - rules <- list.rules.party(party_obj) %>% - stringr::str_replace_all(pattern = "\\\"","'") %>% - stringr::str_remove_all(pattern = ", 'NA'") %>% - stringr::str_remove_all(pattern = "'NA',") %>% - stringr::str_remove_all(pattern = "'NA'") %>% - stringr::str_squish() - - # terminal nodes from party object - terminal_nodes <- partykit::nodeids(party_obj, terminal = T) - - # extract metrics from rpart object - metrics <- object$frame[terminal_nodes,c("n","dev","yval")] - # metrics$confidence <- (metrics$n + 1 - metrics$dev) / (metrics$n + 2) - - metrics <- metrics[,c("n","yval")] %>% - magrittr::set_colnames(c("support","RHS")) - - metrics$LHS <- rules - - tidy_rules <- metrics - - # replace variable names with spaces within backquotes ---- - for(i in 1:length(col_names)){ - tidy_rules[["LHS"]] <- stringr::str_replace_all( - tidy_rules[["LHS"]] - , col_names[i] - , addBackquotes(col_names[i]) - ) - } - - # return ---- - tidy_rules <- tibble::rowid_to_column(tidy_rules, "id") - tidy_rules <- tidy_rules[, c("id" - , "LHS" - , "RHS" - , "support" - ) - ] %>% - tibble::as_tibble() - - # handle the rule parsable language - lang = arguments[["language"]] - - if (lang == "python"){ - res[["LHS"]] = ruleRToPython(res[["LHS"]]) - } else if (lang == "sql"){ - res[["LHS"]] = ruleRToSQL(res[["LHS"]]) - } - - return(tidy_rules) - } + attr(res, "keys") = NULL + attr(res, "model_type") = "rpart" + if (method_rpart == "class"){ + attr(res, "estimation_type") = "classification" + } else if (method_rpart == "anova"){ + attr(res, "estimation_type") = "regression" + } + return(res) } - diff --git a/R/ruleclasses.R b/R/ruleclasses.R new file mode 100644 index 0000000..6ab977b --- /dev/null +++ b/R/ruleclasses.R @@ -0,0 +1,34 @@ +################################################################################ +# This is the part of the 'tidyrules' R package hosted at +# https://github.com/talegari/tidyrules with GPL-3 license. +################################################################################ + +#' @export +print.ruleset = function(x, ...){ + + rlang::inform(paste0("# A ruleset/tidytable with keys: ", + paste(attr(x, "keys"), collapse = ", ") + ) + ) + + class(x) = setdiff(class(x), "ruleset") + print(x, ...) + class(x) = c("ruleset", class(x)) + + return(invisible(x)) +} + +#' @export +print.rulelist = function(x, ...){ + + rlang::inform(paste0("# A rulelist/tidytable with keys: ", + paste(attr(x, "keys"), collapse = ", ") + ) + ) + + class(x) = setdiff(class(x), "rulelist") + print(x, ...) + class(x) = c("rulelist", class(x)) + + return(invisible(x)) +} diff --git a/R/utils.R b/R/utils.R index a9daf1c..9901f5f 100644 --- a/R/utils.R +++ b/R/utils.R @@ -15,19 +15,19 @@ #' tidyrules:::positionSpaceOutsideSinglequotes(c("hello", "hel' 'o ")) #' } #' -positionSpaceOutsideSinglequotes <- Vectorize( +positionSpaceOutsideSinglequotes = Vectorize( function(string){ - assertthat::assert_that(assertthat::is.string(string)) + checkmate::assert_string(string) - fullsplit <- strsplit(string, "")[[1]] - is_singlequote <- (fullsplit == "'") - parity_singlequote_left <- cumsum(fullsplit == "'") %% 2 - space_position <- which(fullsplit == " ") + fullsplit = strsplit(string, "")[[1]] + is_singlequote = (fullsplit == "'") + parity_singlequote_left = cumsum(fullsplit == "'") %% 2 + space_position = which(fullsplit == " ") space_position[parity_singlequote_left[space_position] == 0] - } - , USE.NAMES = FALSE + }, + USE.NAMES = FALSE ) #' @name removeEmptyLines @@ -40,7 +40,7 @@ positionSpaceOutsideSinglequotes <- Vectorize( #' tidyrules:::removeEmptyLines(c("abc", "", "d")) #' } #' -removeEmptyLines <- function(strings){ +removeEmptyLines = function(strings){ strings[!(strings == "")] } @@ -55,9 +55,9 @@ removeEmptyLines <- function(strings){ #' tidyrules:::strSplitSingle("abc,d", ",") #' } #' -strSplitSingle <- function(string, pattern){ +strSplitSingle = function(string, pattern){ - assertthat::assert_that(assertthat::is.string(string)) + checkmate::assert_string(string) stringr::str_split(string, pattern)[[1]] } @@ -75,22 +75,22 @@ strSplitSingle <- function(string, pattern){ #' tidyrules:::strHead(c("string", "string2"), -1) #' } #' -strHead <- Vectorize( +strHead = Vectorize( function(string, n){ - assertthat::assert_that(assertthat::is.string(string)) - len <- stringr::str_length(string) - assertthat::assert_that(is.integerish(n) && length(n) == 1 && n != 0 - , msg = "'n' should be an integer" - ) + checkmate::assert_string(string) + + len = stringr::str_length(string) + checkmate::assert_integerish(n, len = 1) + checkmate::assert(n != 0) if(n < 0){ - n <- len + n + n = len + n } return( stringr::str_sub(string, 1, n) ) - } - , vectorize.args = "string" - , USE.NAMES = FALSE + }, + vectorize.args = "string", + USE.NAMES = FALSE ) @@ -108,22 +108,22 @@ strHead <- Vectorize( #' tidyrules:::strTail(c("string", "string2"), -1) #' } #' -strTail <- Vectorize( +strTail = Vectorize( function(string, n){ - assertthat::assert_that(assertthat::is.string(string)) - len <- stringr::str_length(string) - assertthat::assert_that(is.integerish(n) && length(n) == 1 && n != 0 - , msg = "'n' should be an integer" - ) - if(n < 0){ - n <- len + n + checkmate::assert_string(string) + + len = stringr::str_length(string) + checkmate::assert_integerish(n, len = 1) + checkmate::assert(n != 0) + if (n < 0){ + n = len + n } return( stringr::str_sub(string, len - n + 1, len) ) - } - , vectorize.args = "string" - , USE.NAMES = FALSE + }, + vectorize.args = "string", + USE.NAMES = FALSE ) #' @name addBackquotes #' @title Add backquotes @@ -135,17 +135,19 @@ strTail <- Vectorize( #' tidyrules:::addBackquotes(c("ab", "a b")) #' } #' -addBackquotes <- Vectorize( +addBackquotes = Vectorize( function(string){ - res <- string - if(stringr::str_count(string, "\\s") > 0){ + checkmate::assert_string(string) + + res = string + if (stringr::str_count(string, "\\s") > 0){ if(strHead(string, 1) != "`" && strTail(string, 1) != "`"){ - res <- stringr::str_c("`", string, "`") + res = stringr::str_c("`", string, "`") } } return(res) - } - , USE.NAMES = FALSE + }, + USE.NAMES = FALSE ) #' @name strReplaceReduce @@ -160,18 +162,16 @@ addBackquotes <- Vectorize( #' tidyrules:::strReplaceReduce("abcd", c("ab", "dc"), c("cd", "ab")) #' } #' -strReplaceReduce <- function(string, pattern, replacement){ +strReplaceReduce = function(string, pattern, replacement){ stopifnot(length(pattern) == length(replacement)) - io <- list(pattern, replacement) %>% + io = + list(pattern, replacement) %>% purrr::transpose() %>% purrr::map(unlist) - purrr::reduce(io - , function(x, y) stringr::str_replace_all(x - , y[[1]] - , y[[2]] - ) - , .init = string - ) + purrr::reduce(io, + function(x, y) stringr::str_replace_all(x, y[[1]], y[[2]]), + .init = string + ) } diff --git a/R/varSpec.R b/R/varSpec.R index 7709462..b77b8d4 100644 --- a/R/varSpec.R +++ b/R/varSpec.R @@ -6,85 +6,94 @@ #' @name varSpec #' @title Get variable specification for a Cubist/C5 object #' @description Obtain variable names, type (numeric, ordered, factor) and -#' levels as a tibble -#' @author Srikanth KS, \email{sri.teach@@gmail.com} +#' levels as a tidytable #' @param object Cubist/C5 object -#' @return A tibble with three columns: variable(character), type(character) and -#' levels(a list-column). For numeric variables, levels are set to NA. +#' @return A tidytable with three columns: variable(character), type(character) +#' and levels(a list-column). For numeric variables, levels are set to NA. #' @examples #' data("attrition", package = "modeldata") -#' attrition <- tibble::as_tibble(attrition) -#' cols_att <- setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) +#' cols_att = setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) #' -#' cb_att <- -#' Cubist::cubist(x = attrition[, cols_att],y = attrition[["MonthlyIncome"]]) +#' cb_att = Cubist::cubist(x = attrition[, cols_att], +#' y = attrition[["MonthlyIncome"]] +#' ) #' varSpec(cb_att) #' @export -varSpec <- function(object){ +varSpec = function(object){ # 1. split ny newline # 2. remove a few header lines # 3. get variables and details - lines_raw <- object[["names"]] %>% + lines_raw = + object[["names"]] %>% strSplitSingle("\\n") - outcome_line_number <- stringr::str_which(lines_raw, "^outcome:") + outcome_line_number = stringr::str_which(lines_raw, "^outcome:") - lines <- lines_raw[-(1:outcome_line_number)] %>% + lines = + lines_raw[-(1:outcome_line_number)] %>% removeEmptyLines() - split_lines <- lines %>% + split_lines = + lines %>% stringr::str_split(":") %>% purrr::transpose() - variables <- split_lines %>% + variables = + split_lines %>% magrittr::extract2(1) %>% unlist() %>% stringr::str_replace_all("\\\\", "") # clean up variable names - details <- split_lines %>% + details = + split_lines %>% magrittr::extract2(2) %>% unlist() %>% stringr::str_trim() # handle a detail depending on its type - handleDetail <- function(adetail){ + handleDetail = function(adetail){ - if(adetail == "continuous."){ + if (adetail == "continuous."){ # handle numeric/integer - out <- list(type = "numeric", levels = NA_character_) + out = list(type = "numeric", levels = NA_character_) - } else if(stringr::str_detect(adetail, "^\\[ordered\\]")){ + } else if (stringr::str_detect(adetail, "^\\[ordered\\]")){ # handle ordered factors - levels <- adetail %>% + levels = + adetail %>% strSplitSingle("\\[ordered\\]") %>% magrittr::extract(2) %>% strHead(-1) %>% strSplitSingle(",") %>% stringr::str_trim() - out <- list(type = "ordered", levels = levels) + out = list(type = "ordered", levels = levels) } else { # handle unordered factors - levels <- adetail %>% + levels = + adetail %>% strHead(-1) %>% strSplitSingle(",") %>% stringr::str_trim() - out <- list(type = "factor", levels = levels) + out = list(type = "factor", levels = levels) } + return(out) } - details_cleaned <- details %>% + details_cleaned = + details %>% purrr::map(handleDetail) %>% purrr::transpose() - details_cleaned[["type"]] <- unlist(details_cleaned[["type"]]) - details_cleaned[["variable"]] <- variables + details_cleaned[["type"]] = unlist(details_cleaned[["type"]]) + details_cleaned[["variable"]] = variables - return(tibble::as_tibble(details_cleaned)) + res = tidytable::as_tidytable(details_cleaned) + return(res) } diff --git a/man/package_tidyrules.Rd b/man/package_tidyrules.Rd index cc5f450..b9aee74 100644 --- a/man/package_tidyrules.Rd +++ b/man/package_tidyrules.Rd @@ -4,6 +4,7 @@ \name{package_tidyrules} \alias{tidyrules} \alias{tidyrules-package} +\alias{package_tidyrules} \title{About 'tidyrules' package} \description{ Obtain rules as tidy dataframes diff --git a/man/reexports.Rd b/man/reexports.Rd new file mode 100644 index 0000000..9b6c624 --- /dev/null +++ b/man/reexports.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generic.R +\docType{import} +\name{reexports} +\alias{reexports} +\alias{tidy} +\title{Objects exported from other packages} +\keyword{internal} +\description{ +These objects are imported from other packages. Follow the links +below to see their documentation. + +\describe{ + \item{generics}{\code{\link[generics]{tidy}}} +}} + diff --git a/man/tidy.C5.0.Rd b/man/tidy.C5.0.Rd new file mode 100644 index 0000000..dfdc3a5 --- /dev/null +++ b/man/tidy.C5.0.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/c5.R +\name{tidy.C5.0} +\alias{tidy.C5.0} +\title{Obtain rules as rulelist/tiydtable from a C5.0 model} +\usage{ +\method{tidy}{C5.0}(x, ...) +} +\arguments{ +\item{x}{C5 model fitted with `rules = TRUE`} + +\item{...}{Other arguments (See details)} +} +\value{ +A rulelist/tidytable where each row corresponds to a rule. + The columns are: rule_nbr, trial_nbr, LHS, RHS, support, confidence, lift +} +\description{ +Each row corresponds to a rule per trial_nbr +} +\details{ +Optional named arguments: + +\itemize{ + +\item laplace(flag, default: TRUE) is supported. This computes confidence +with laplace correction as documented under 'Rulesets' here: [C5 +doc](https://www.rulequest.com/see5-unix.html). +} +} +\examples{ +c5_model = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) +summary(c5_model) +tidy(c5_model) +} diff --git a/man/tidy.cubist.Rd b/man/tidy.cubist.Rd new file mode 100644 index 0000000..48eb23e --- /dev/null +++ b/man/tidy.cubist.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cubist.R +\name{tidy.cubist} +\alias{tidy.cubist} +\title{Obtain rules as a ruleset/tidytable from a cubist model} +\usage{ +\method{tidy}{cubist}(x, ...) +} +\arguments{ +\item{x}{Cubist model} + +\item{...}{Other arguments (currently unused)} +} +\value{ +A ruleset/tidytable where each row corresponds to a rule. The columns + are: rule_nbr, committee, LHS, RHS, support, mean, min, max, error +} +\description{ +Each row corresponds to a rule per committee. +} +\details{ +When col_classes argument is missing, an educated guess is made + about class by parsing the RHS of sub-rule. This might sometimes not lead + to a parsable rule. +} +\examples{ +data("attrition", package = "modeldata") +cols_att = setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) + +cb_att = Cubist::cubist(x = attrition[, cols_att], + y = attrition[["MonthlyIncome"]] + ) +summary(cb_att) +tidy(cb_att) +} diff --git a/man/tidy.rpart.Rd b/man/tidy.rpart.Rd new file mode 100644 index 0000000..c2b1011 --- /dev/null +++ b/man/tidy.rpart.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rpart.R +\name{tidy.rpart} +\alias{tidy.rpart} +\title{Obtain rules as a ruleset/tidytable from a rpart model} +\usage{ +\method{tidy}{rpart}(x, ...) +} +\arguments{ +\item{x}{rpart model} + +\item{...}{Other arguments (currently unused)} +} +\value{ +A tidytable where each row corresponds to a rule. The columns are: + rule_nbr, LHS, RHS, support, confidence (for classification only), lift + (for classification only) +} +\description{ +Each row corresponds to a rule. A rule can be copied into + `dplyr::filter` to filter the observations corresponding to a rule +} +\details{ +NOTE: For rpart rules, one should build the model without +\bold{ordered factor} variable. We recommend you to convert \bold{ordered +factor} to \bold{factor} or \bold{integer} class. +} +\examples{ +rpart_class = rpart::rpart(Species ~ .,data = iris) +rpart_class +tidy(rpart_class) + +rpart_regr = rpart::rpart(Sepal.Length ~ .,data = iris) +rpart_regr +tidy(rpart_regr) +} diff --git a/man/tidyRules.C5.0.Rd b/man/tidyRules.C5.0.Rd deleted file mode 100644 index 2992149..0000000 --- a/man/tidyRules.C5.0.Rd +++ /dev/null @@ -1,45 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/c5.R -\name{tidyRules.C5.0} -\alias{tidyRules.C5.0} -\title{Obtain rules as a tidy tibble from a C5.0 model} -\usage{ -\method{tidyRules}{C5.0}(object, ...) -} -\arguments{ -\item{object}{Fitted model object with rules} - -\item{...}{Other arguments (See details)} -} -\value{ -A tibble where each row corresponds to a rule. The columns are: - support, confidence, lift, lhs, rhs, n_conditions -} -\description{ -Each row corresponds to a rule. A rule can be copied into - `dplyr::filter` to filter the observations corresponding to a rule -} -\details{ -Optional named arguments: - -\itemize{ - -\item laplace(flag, default: TRUE) is supported. This computes confidence -with laplace correction as documented under 'Rulesets' here: [C5 -doc](https://www.rulequest.com/see5-unix.html). - -\item language (string, default: "r"): language where the rules are parsable. -The allowed options is one among: r, python, sql - -} -} -\examples{ -data("attrition", package = "modeldata") -attrition <- tibble::as_tibble(attrition) -c5_model <- C50::C5.0(Attrition ~., data = attrition, rules = TRUE) -summary(c5_model) -tidyRules(c5_model) -} -\author{ -Srikanth KS, \email{sri.teach@gmail.com} -} diff --git a/man/tidyRules.Rd b/man/tidyRules.Rd deleted file mode 100644 index ba9f072..0000000 --- a/man/tidyRules.Rd +++ /dev/null @@ -1,31 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/generic.R -\name{tidyRules} -\alias{tidyRules} -\title{Obtain rules as a tidy tibble} -\usage{ -tidyRules(object, col_classes = NULL, ...) -} -\arguments{ -\item{object}{Fitted model object with rules} - -\item{col_classes}{Named list or a named character vector of column classes. -Column names of the data used for modeling form the names and the -respective classes for the value. One way of obtaining this is by running -`lapply(data, class)`.} - -\item{...}{Other arguments (currently unused)} -} -\value{ -A tibble where each row corresponds to a rule -} -\description{ -Each row corresponds to a rule. A rule can be copied into - `dplyr::filter` to filter the observations corresponding to a rule -} -\details{ -tidyRule supports these rule based models: C5, Cubist and rpart. -} -\author{ -Srikanth KS, \email{sri.teach@gmail.com} -} diff --git a/man/tidyRules.cubist.Rd b/man/tidyRules.cubist.Rd deleted file mode 100644 index 0cc9317..0000000 --- a/man/tidyRules.cubist.Rd +++ /dev/null @@ -1,48 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/cubist.R -\name{tidyRules.cubist} -\alias{tidyRules.cubist} -\title{Obtain rules as a tidy tibble from a cubist model} -\usage{ -\method{tidyRules}{cubist}(object, ...) -} -\arguments{ -\item{object}{Fitted model object with rules} - -\item{...}{Other arguments (currently unused)} -} -\value{ -A tibble where each row corresponds to a rule. The columns are: - support, mean, min, max, error, lhs, rhs and committee -} -\description{ -Each row corresponds to a rule. A rule can be copied into - `dplyr::filter` to filter the observations corresponding to a rule -} -\details{ -When col_classes argument is missing, an educated guess is made - about class by parsing the RHS of sub-rule. This might sometimes not lead - to a parsable rule. - - Optional named arguments: - - \itemize{ - - \item language (string, default: "r"): language where the rules are - parsable. The allowed options is one among: r, python, sql - - } -} -\examples{ -data("attrition", package = "modeldata") -attrition <- tibble::as_tibble(attrition) -cols_att <- setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) - -cb_att <- - Cubist::cubist(x = attrition[, cols_att],y = attrition[["MonthlyIncome"]]) -tr_att <- tidyRules(cb_att) -tr_att -} -\author{ -Srikanth KS, \email{sri.teach@gmail.com} -} diff --git a/man/tidyRules.rpart.Rd b/man/tidyRules.rpart.Rd deleted file mode 100644 index b570d79..0000000 --- a/man/tidyRules.rpart.Rd +++ /dev/null @@ -1,42 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/rpart.R -\name{tidyRules.rpart} -\alias{tidyRules.rpart} -\title{Obtain rules as a tidy tibble from a rpart model} -\usage{ -\method{tidyRules}{rpart}(object, ...) -} -\arguments{ -\item{object}{Fitted model object with rules} - -\item{...}{Other arguments (currently unused)} -} -\value{ -A tibble where each row corresponds to a rule. The columns are: - support, confidence, lift, LHS, RHS -} -\description{ -Each row corresponds to a rule. A rule can be copied into - `dplyr::filter` to filter the observations corresponding to a rule -} -\details{ -NOTE: For rpart rules, one should build the model without -\bold{ordered factor} variable. We recommend you to convert \bold{ordered -factor} to \bold{factor} or \bold{integer} class. - -Optional named arguments: - -\itemize{ - -\item language (string, default: "r"): language where the rules are parsable. -The allowed options is one among: r, python, sql - -} -} -\examples{ -iris_rpart <- rpart::rpart(Species ~ .,data = iris) -tidyRules(iris_rpart) -} -\author{ -Amith Kumar U R, \email{amith54@gmail.com} -} diff --git a/man/varSpec.Rd b/man/varSpec.Rd index 194e0f5..6b9b8e8 100644 --- a/man/varSpec.Rd +++ b/man/varSpec.Rd @@ -10,22 +10,19 @@ varSpec(object) \item{object}{Cubist/C5 object} } \value{ -A tibble with three columns: variable(character), type(character) and - levels(a list-column). For numeric variables, levels are set to NA. +A tidytable with three columns: variable(character), type(character) + and levels(a list-column). For numeric variables, levels are set to NA. } \description{ Obtain variable names, type (numeric, ordered, factor) and - levels as a tibble + levels as a tidytable } \examples{ data("attrition", package = "modeldata") -attrition <- tibble::as_tibble(attrition) -cols_att <- setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) +cols_att = setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) -cb_att <- - Cubist::cubist(x = attrition[, cols_att],y = attrition[["MonthlyIncome"]]) +cb_att = Cubist::cubist(x = attrition[, cols_att], + y = attrition[["MonthlyIncome"]] + ) varSpec(cb_att) } -\author{ -Srikanth KS, \email{sri.teach@gmail.com} -} diff --git a/tests/testthat/test-c5.R b/tests/testthat/test-c5.R index e35ed94..5012215 100644 --- a/tests/testthat/test-c5.R +++ b/tests/testthat/test-c5.R @@ -1,6 +1,6 @@ ################################################################################ -# This is the part of the 'tidyrules' R package hosted at -# https://github.com/talegari/tidyrules with GPL-3 license. +# This is the part of the 'tidy' R package hosted at +# https://github.com/talegari/tidy with GPL-3 license. ################################################################################ context("test-c5") @@ -8,45 +8,42 @@ context("test-c5") # setup some models ---- # attrition data("attrition", package = "modeldata") -attrition <- tibble::as_tibble(attrition) -c5_att <- C50::C5.0(Attrition ~ ., data = attrition, rules = TRUE) -tr_att <- tidyRules(c5_att) -tr_att_python = tidyRules(c5_att, language = "python") -tr_att_sql = tidyRules(c5_att, language = "sql") +c5_att = C50::C5.0(Attrition ~ ., data = attrition, rules = TRUE) +tr_att = tidy(c5_att) # attrition with trials -c5_att_2 <- C50::C5.0(Attrition ~ ., data = attrition +c5_att_2 = C50::C5.0(Attrition ~ ., data = attrition , trials = 7, rules = TRUE) -tr_att_2 <- tidyRules(c5_att_2) +tr_att_2 = tidy(c5_att_2) # ames housing # ames has some space in Sale_Type levels -ames <- AmesHousing::make_ames() +ames = AmesHousing::make_ames() ames -cb_ames <- C50::C5.0(MS_SubClass ~ ., data = ames +cb_ames = C50::C5.0(MS_SubClass ~ ., data = ames , trials = 3, rules = TRUE) -tr_ames <- tidyRules(cb_ames) +tr_ames = tidy(cb_ames) # column name has a space in it -ames <- AmesHousing::make_ames() -ames_2 <- ames -colnames(ames_2)[which(colnames(ames_2) == "Bldg_Type")] <- "Bldg Type" -colnames(ames_2)[which(colnames(ames_2) == "House_Style")] <- "House Style" -c5_ames_2 <- C50::C5.0(MS_SubClass ~ ., data = ames_2, rules = TRUE) -tr_ames_2 <- tidyRules(c5_ames_2) +ames = AmesHousing::make_ames() +ames_2 = ames +colnames(ames_2)[which(colnames(ames_2) == "Bldg_Type")] = "Bldg Type" +colnames(ames_2)[which(colnames(ames_2) == "House_Style")] = "House Style" +c5_ames_2 = C50::C5.0(MS_SubClass ~ ., data = ames_2, rules = TRUE) +tr_ames_2 = tidy(c5_ames_2) # function to check whether a rule is filterable -ruleFilterable <- function(rule, data){ +ruleFilterable = function(rule, data){ dplyr::filter(data, eval(parse(text = rule))) } # function to check whether all rules are filterable -allRulesFilterable <- function(tr, data){ - parse_status <- sapply( +allRulesFilterable = function(tr, data){ + parse_status = sapply( tr[["LHS"]] , function(arule){ - trydf <- try(ruleFilterable(arule, data) + trydf = try(ruleFilterable(arule, data) , silent = TRUE ) if(nrow(trydf) == 0){ @@ -61,10 +58,10 @@ allRulesFilterable <- function(tr, data){ # test output type ---- test_that("creates tibble", { - expect_is(tr_att, "tbl_df") - expect_is(tr_att_2, "tbl_df") - expect_is(tr_ames, "tbl_df") - expect_is(tr_ames_2, "tbl_df") + expect_is(tr_att, "rulelist") + expect_is(tr_att_2, "rulelist") + expect_is(tr_ames, "rulelist") + expect_is(tr_ames_2, "rulelist") }) # test NA ---- @@ -82,10 +79,3 @@ test_that("rules are parsable", { expect_true(all(allRulesFilterable(tr_ames, ames))) expect_true(all(allRulesFilterable(tr_ames_2, ames_2))) }) - -# test language conversion is successfull ---- -# cannot test parsability -test_that("python and SQL rule conversions work (parsability not checked)", { - expect_is(tr_att_python, "tbl_df") - expect_is(tr_att_sql, "tbl_df") -}) diff --git a/tests/testthat/test-cubist.R b/tests/testthat/test-cubist.R index cfc8828..3f60e3f 100644 --- a/tests/testthat/test-cubist.R +++ b/tests/testthat/test-cubist.R @@ -1,6 +1,6 @@ ################################################################################ -# This is the part of the 'tidyrules' R package hosted at -# https://github.com/talegari/tidyrules with GPL-3 license. +# This is the part of the 'tidy' R package hosted at +# https://github.com/talegari/tidy with GPL-3 license. ################################################################################ context("test-cubist") @@ -8,51 +8,50 @@ context("test-cubist") # setup some models ---- # attrition data("attrition", package = "modeldata") -attrition <- tibble::as_tibble(attrition) -cols_att <- setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) +cols_att = setdiff(colnames(attrition), c("MonthlyIncome", "Attrition")) -cb_att <- +cb_att = Cubist::cubist(x = attrition[, cols_att], y = attrition[["MonthlyIncome"]] ) -tr_att <- tidyRules(cb_att) +tr_att = tidy(cb_att) # attrition with commitees -cb_att_2 <- +cb_att_2 = Cubist::cubist(x = attrition[, cols_att], y = attrition[["MonthlyIncome"]], committees = 7 ) -tr_att_2 <- tidyRules(cb_att_2) +tr_att_2 = tidy(cb_att_2) # ames housing -ames <- AmesHousing::make_ames() -cb_ames <- Cubist::cubist(x = ames[, setdiff(colnames(ames), c("Sale_Price"))], +ames = AmesHousing::make_ames() +cb_ames = Cubist::cubist(x = ames[, setdiff(colnames(ames), c("Sale_Price"))], y = log10(ames[["Sale_Price"]]), committees = 3 ) -tr_ames <- tidyRules(cb_ames) +tr_ames = tidy(cb_ames) # column name has a space in it data("Boston", package = "MASS") -boston_2 <- Boston -names(boston_2)[6] <- "r m" -names(boston_2)[13] <- "l stat" -cb_boston <- Cubist::cubist(x = boston_2[, -14], y = boston_2[[14]]) -tr_boston <- tidyRules(cb_boston) +boston_2 = Boston +names(boston_2)[6] = "r m" +names(boston_2)[13] = "l stat" +cb_boston = Cubist::cubist(x = boston_2[, -14], y = boston_2[[14]]) +tr_boston = tidy(cb_boston) # function to check whether a rule is filterable -ruleFilterable <- function(rule, data){ +ruleFilterable = function(rule, data){ dplyr::filter(data, eval(parse(text = rule))) } # function to check whether all rules are filterable -allRulesFilterable <- function(tr, data){ - parse_status <- sapply( +allRulesFilterable = function(tr, data){ + parse_status = sapply( tr[["LHS"]] , function(arule){ - trydf <- try(ruleFilterable(arule, data) + trydf = try(ruleFilterable(arule, data) , silent = TRUE ) if(nrow(trydf) == 0){ @@ -65,11 +64,11 @@ allRulesFilterable <- function(tr, data){ } # evaluate RHS -evalRHS <- function(tr, data){ +evalRHS = function(tr, data){ message(deparse(substitute(data))) - with_RHS <- sapply(tr[["RHS"]], + with_RHS = sapply(tr[["RHS"]], function(x){ try(data %>% dplyr::mutate(RHS_ = eval(parse(text = x))) %>% @@ -87,10 +86,10 @@ evalRHS <- function(tr, data){ # test output type ---- test_that("creates tibble", { - expect_is(tr_att, "tbl_df") - expect_is(tr_att_2, "tbl_df") - expect_is(tr_ames, "tbl_df") - expect_is(tr_boston, "tbl_df") + expect_is(tr_att, "ruleset") + expect_is(tr_att_2, "ruleset") + expect_is(tr_ames, "ruleset") + expect_is(tr_boston, "ruleset") }) # test NA ---- diff --git a/tests/testthat/test-rpart.R b/tests/testthat/test-rpart.R index 136c042..3b844d4 100644 --- a/tests/testthat/test-rpart.R +++ b/tests/testthat/test-rpart.R @@ -6,72 +6,58 @@ context("test-rpart") # setup some models ---- -# attrition data("attrition", package = "modeldata") -attrition_1 <- attrition %>% - dplyr::mutate_if(is.ordered, function(x) x <- factor(x,ordered = F)) %>% - dplyr::mutate(Attrition = factor(Attrition, levels = c("No","Yes"))) -rpart_att <- rpart::rpart(Attrition ~ ., data = attrition_1) -tr_att <- tidyRules(rpart_att) +# classification test +attrition_class = + attrition %>% + tidytable::mutate(tidytable::across(is.ordered, ~ factor(.x, ordered = F))) %>% + tidytable::mutate(Attrition = factor(Attrition, levels = c("No", "Yes"))) -# with ordered variables -attrition_2 <- attrition %>% - dplyr::mutate(Attrition = factor(Attrition, levels = c("No","Yes"))) - -rpart_att_1 <- rpart::rpart(Attrition ~ ., data = attrition_2) - -# attrition with maxdepth -rpart_att_2 <- rpart::rpart(Attrition ~ . - , data = attrition_1) - -tr_att_2 <- tidyRules(rpart_att_2) +rpart_att = rpart::rpart(Attrition ~ ., data = attrition_class) +tr_att_class = tidy(rpart_att) # regression test -attrition_reg <- attrition %>% - dplyr::mutate_if(is.ordered, function(x) x <- factor(x,ordered = F)) %>% - dplyr::select(-Attrition) +attrition_reg = + attrition %>% + tidytable::mutate(tidytable::across(is.ordered, ~ factor(.x, ordered = F))) %>% + tidytable::select(-Attrition) -rpart_att_reg <- rpart::rpart(MonthlyIncome ~ . - , data = attrition_reg) -tr_att_reg <- tidyRules(rpart_att_reg) +rpart_att_reg = rpart::rpart(MonthlyIncome ~ ., data = attrition_reg) +tr_att_reg = tidy(rpart_att_reg) # BreastCancer data(BreastCancer, package = "mlbench") -bc <- BreastCancer %>% +bc = BreastCancer %>% dplyr::select(-Id) %>% - dplyr::mutate_if(is.ordered, function(x) x <- factor(x,ordered = F)) + dplyr::mutate_if(is.ordered, function(x) x = factor(x,ordered = F)) -bc_1m <- rpart::rpart(Class ~ ., data = bc) +bc_1m = rpart::rpart(Class ~ ., data = bc) -tr_bc_1 <- tidyRules(bc_1m) +tr_bc_1 = tidy(bc_1m) # variables with spaces -bc2 <- bc +bc2 = bc -colnames(bc2)[which(colnames(bc2) == "Cell.size")] <- "Cell size" -colnames(bc2)[which(colnames(bc2) == "Cell.shape")] <- "Cell shape" +colnames(bc2)[which(colnames(bc2) == "Cell.size")] = "Cell size" +colnames(bc2)[which(colnames(bc2) == "Cell.shape")] = "Cell shape" -bc_2m <- rpart::rpart(Class ~ ., data = bc2) +bc_2m = rpart::rpart(Class ~ ., data = bc2) -tr_bc_2 <- tidyRules(bc_2m) +tr_bc_2 = tidy(bc_2m) # function to check whether a rule is filterable -ruleFilterable <- function(rule, data){ +ruleFilterable = function(rule, data){ dplyr::filter(data, eval(parse(text = rule))) } # function to check whether all rules are filterable -allRulesFilterable <- function(tr, data){ - parse_status <- sapply( - tr[["LHS"]] - , function(arule){ - trydf <- try(ruleFilterable(arule, data) - , silent = TRUE - ) - if(nrow(trydf) == 0){ - print(arule) - } +allRulesFilterable = function(tr, data){ + parse_status = sapply( + tr[["LHS"]], + function(arule){ + trydf = try(ruleFilterable(arule, data), silent = TRUE) + if (nrow(trydf) == 0) print(arule) inherits(trydf, "data.frame") } ) @@ -80,22 +66,20 @@ allRulesFilterable <- function(tr, data){ # test for error while ordered features are present ---- test_that("check error",{ - expect_error(tidyRules(rpart_att_1))}) + expect_error(tidy(rpart_att_1))}) # test output type ---- -test_that("creates tibble", { - expect_is(tr_att, "tbl_df") - expect_is(tr_att_2, "tbl_df") - expect_is(tr_bc_1, "tbl_df") - expect_is(tr_bc_2, "tbl_df") - expect_is(tr_att_reg, "tbl_df") +test_that("creates ruleset", { + expect_is(tr_att_class, "ruleset") + expect_is(tr_bc_1, "ruleset") + expect_is(tr_bc_2, "ruleset") + expect_is(tr_att_reg, "ruleset") }) # test NA ---- test_that("Are NA present", { - expect_false(anyNA(tr_att)) - expect_false(anyNA(tr_att_2)) + expect_false(anyNA(tr_att_class)) expect_false(anyNA(tr_bc_1)) expect_false(anyNA(tr_bc_2)) expect_false(anyNA(tr_att_reg)) @@ -103,8 +87,7 @@ test_that("Are NA present", { # test parsable ---- test_that("rules are parsable", { - expect_true(all(allRulesFilterable(tr_att, attrition))) - expect_true(all(allRulesFilterable(tr_att_2, attrition))) + expect_true(all(allRulesFilterable(tr_att_class, attrition))) expect_true(all(allRulesFilterable(tr_bc_1, bc))) expect_true(all(allRulesFilterable(tr_bc_2, bc2))) expect_true(all(allRulesFilterable(tr_att_reg,attrition)))