Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve access to <distribution> names #398

Merged
merged 11 commits into from
Oct 21, 2024
44 changes: 30 additions & 14 deletions R/accessors.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ get_parameters.epiparameter <- function(x, ...) {
)
}

# convert to meanlog and sdlog names
# if truncated get underlying distribution type
params <- .clean_params(
prob_distribution = family(x),
prob_distribution = family(x, base_dist = is_truncated(x)),
prob_distribution_params = params
)
} else {
Expand Down Expand Up @@ -138,21 +138,37 @@ get_citation.multi_epiparameter <- function(x, ...) {
multi_bibentry
}

#' Gets the distributions names from a mixture distribution
#' [distributional::dist_mixture()]
#'
#' @param x An `<epiparameter>` object.
#' Get the underlying distributions names from a `<distribution>` object from
#' the \pkg{distributional} package in \R distribution naming convention.
#'
#' @details Get and standardise distribution name. For untransformed
#' distributions (e.g. [distributional::dist_gamma()]) it will return the
#' distribution name. For transformed distributions (e.g.
#' [distributional::dist_mixture()]) it will get the name of the underlying
#' distribution(s) by default (`base_dist = TRUE`). Distribution names are
#' returned in the \R naming style (e.g. lognormal is `"lnorm"`). When
#' `base_dist = FALSE` transformed distributions return the name of the
#' transformation (e.g. `"mixture"`).
#'
#' @param x An `<distribution>` object.
#' @param base_dist A boolean `logical` for whether to return the name of a
#' transformed distribution (e.g. `"mixture"` or `"truncated"`) or the
#' underlying distribution type (e.g. `"gamma"` or `"lnorm"`). Default is
#' `TRUE`.
#'
#' @return A `character` vector.
#' @keywords internal
#' @noRd
.get_mixture_family <- function(x) {
assert_epiparameter(x)
fam <- vapply(
unclass(unclass(x$prob_distribution)[[1]])[[1]],
family,
FUN.VALUE = character(1)
)
.distributional_family <- function(x, base_dist = TRUE) {
if (family(x) %in% c("mixture", "truncated") && base_dist) {
fam <- vapply(
distributional::parameters(x)$dist[[1]],
family,
FUN.VALUE = character(1),
USE.NAMES = FALSE
)
} else {
fam <- family(x)
}
fam <- vapply(
fam, function(x) {
switch(x,
Expand Down
2 changes: 1 addition & 1 deletion R/coercion.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ is_epiparameter_df <- function(x) {
distributional::parameters(x$prob_distribution[[1]])
)
}
if (identical(stats::family(prob_dist), "truncated")) {
if (identical(family(prob_dist), "truncated")) {
truncation <- distributional::parameters(x$prob_distribution)$upper
} else {
truncation <- NA_real_
Expand Down
85 changes: 25 additions & 60 deletions R/epiparameter.R
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,11 @@ format.epiparameter <- function(x, ...) {
fam <- family(x)
# isTRUE to control for family returning NA for unparameterised
if (isTRUE(fam == "mixture")) {
fam <- paste(fam, toString(.get_mixture_family(x)), sep = ": ")
fam <- paste(
fam,
toString(.distributional_family(x$prob_distribution)),
sep = ": "
)
}
writeLines(sprintf(dist_string, fam))
} else {
Expand Down Expand Up @@ -682,15 +686,8 @@ discretise.epiparameter <- function(x, ...) {
)
prob_dist_params <- prob_dist_params[-idx]

# trunc dist family is truncated so get prob dist by unclassing dist and
# extracting name
list_dist <- unclass(x$prob_distribution)
prob_dist <- gsub(
pattern = "dist_",
replacement = "",
x = class(list_dist[[1]][[1]])[1],
fixed = TRUE
)
# get underlying distribution family that's truncated
prob_dist <- .distributional_family(x$prob_distribution)
}

# standardise distribution parameter names
Expand Down Expand Up @@ -727,6 +724,10 @@ discretise.default <- function(x, ...) {
#'
#' @param object An `<epiparameter>` object.
#' @inheritParams stats::family
#' @param base_dist A boolean `logical` for whether to return the name of a
#' transformed distribution (e.g. `"mixture"` or `"truncated"`) or the
#' underlying distribution type (e.g. `"gamma"` or `"lnorm"`). Default is
#' `FALSE`.
#'
#' @return A character string with the name of the distribution, or `NA` when
#' the `<epiparameter>` object is unparameterised.
Expand Down Expand Up @@ -756,38 +757,19 @@ discretise.default <- function(x, ...) {
#' )
#' )
#' family(ep)
family.epiparameter <- function(object, ...) {
if (inherits(object$prob_distribution, "distcrete")) {
prob_dist <- object$prob_distribution$name
} else if (inherits(object$prob_distribution, "distribution")) {
if (is_truncated(object)) {
prob_dist <- gsub(
pattern = "dist_",
replacement = "",
x = class(unclass(unclass(object$prob_distribution)[[1]])[[1]])[1],
fixed = TRUE
)
} else {
prob_dist <- stats::family(object$prob_distribution)
}
} else if (is.character(object$prob_distribution)) {
prob_dist <- object$prob_distribution
} else {
return(NA)
}
family.epiparameter <- function(object, ..., base_dist = FALSE) {
checkmate::assert_logical(base_dist, any.missing = FALSE, len = 1)
if (inherits(object$prob_distribution, "distcrete"))
return(object$prob_distribution$name)

prob_dist <- switch(prob_dist,
lognormal = "lnorm",
negbin = "nbinom",
geometric = "geom",
poisson = "pois",
normal = "norm",
exponential = "exp",
prob_dist
)
if (inherits(object$prob_distribution, "distribution"))
return(.distributional_family(object$prob_distribution, base_dist))

# return prob dist
prob_dist
if (is.character(object$prob_distribution))
return(object$prob_distribution)

# return NA when not <distcrete>, <distribution> or character
return(NA)
}

#' Check if distribution in `<epiparameter>` is truncated
Expand Down Expand Up @@ -829,24 +811,7 @@ is_truncated <- function(x) {
"is_truncated only works for `<epiparameter> objects`" =
is_epiparameter(x)
)

# distcrete distributions cannot be truncated
if (inherits(x$prob_distribution, "distcrete")) {
return(FALSE)
}

# unparameterised objects cannot be truncated
# dont use is_parameterised due to infinite recursion
if (is.na(x$prob_distribution) || is.character(x$prob_distribution)) {
return(FALSE)
}

# use stats::family instead of epiparameter::family to check truncated
if (identical(stats::family(x$prob_distribution), "truncated")) {
return(TRUE)
} else {
return(FALSE)
}
return(identical(family(x), "truncated"))
}

#' Check if distribution in `<epiparameter>` is continuous
Expand Down Expand Up @@ -885,13 +850,13 @@ is_truncated <- function(x) {
#' is_continuous(ep)
is_continuous <- function(x) {
stopifnot(
"is_truncated only works for `<epiparameter> objects`" =
"is_continuous only works for `<epiparameter> objects`" =
is_epiparameter(x)
)
# get individual distributions out of mixture to check if continuous
# isTRUE to control for family returning NA for unparameterised
if (isTRUE(family(x) == "mixture")) {
fam <- .get_mixture_family(x)
fam <- .distributional_family(x$prob_distribution)
} else {
fam <- family(x)
}
Expand Down
35 changes: 35 additions & 0 deletions man/dot-distributional_family.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/family.epiparameter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 29 additions & 3 deletions tests/testthat/test-accessors.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,24 @@ test_that("get_citation produces warnings with extra arguments", {
)
})

test_that(".get_mixture_family works as expected", {
test_that(".distributional_family works as expected for untransformed", {
ep <- suppressMessages(
epiparameter_db(
disease = "Ebola",
epi_name = "serial interval",
single_epiparameter = TRUE
)
)
expect_identical(.distributional_family(ep$prob_distribution), "gamma")
})

test_that(".distributional_family works as expected for transformed", {
ebola_si <- suppressMessages(
epiparameter_db(disease = "Ebola", epi_name = "serial interval")
)
ep <- aggregate(ebola_si)
expect_identical(
.get_mixture_family(ep),
.distributional_family(ep$prob_distribution),
rep("gamma", times = length(ebola_si))
)
incub <- suppressMessages(
Expand All @@ -148,7 +159,22 @@ test_that(".get_mixture_family works as expected", {
)
ep <- aggregate(incub)
expect_identical(
.get_mixture_family(ep),
.distributional_family(ep$prob_distribution),
c(rep("lnorm", 2), "gamma", rep("lnorm", 2))
)

ep <- epiparameter(
disease = "Ebola",
epi_name = "SI",
prob_distribution = create_prob_distribution(
prob_distribution = "lnorm",
prob_distribution_params = c(meanlog = 2, sdlog = 2),
truncation = 10
)
)
expect_identical(.distributional_family(ep$prob_distribution), "lnorm")
expect_identical(
.distributional_family(ep$prob_distribution, base_dist = FALSE),
"truncated"
)
})
16 changes: 15 additions & 1 deletion tests/testthat/test-epiparameter.R
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,21 @@ test_that("family works as expected for distributional truncated", {
truncation = 10
)
))
expect_identical(family(ep), "weibull")
expect_identical(family(ep), "truncated")
})

test_that("family works for distributional truncated with base_dist = TRUE", {
# message about missing citation suppressed
ep <- suppressMessages(epiparameter(
disease = "ebola",
epi_name = "incubation_period",
prob_distribution = create_prob_distribution(
prob_distribution = "weibull",
prob_distribution_params = c(shape = 1, scale = 1),
truncation = 10
)
))
expect_identical(family(ep, base_dist = TRUE), "weibull")
})

test_that("is_truncated works as expected for continuous distributions", {
Expand Down
Loading