Skip to content

Commit

Permalink
Merge pull request #112 from statasaurus/master
Browse files Browse the repository at this point in the history
Updating the Formatting of dist_mixture
  • Loading branch information
mitchelloharawild authored Jun 12, 2024
2 parents 116d8da + 9dbad9d commit 31ccd0c
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 9 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ Imports:
stats,
numDeriv,
utils,
lifecycle
lifecycle,
pillar
Suggests:
testthat (>= 2.1.0),
covr,
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ S3method(format,dist_wrap)
S3method(format,distribution)
S3method(format,hdr)
S3method(format,hilo)
S3method(format,pillar_distribution)
S3method(format,support_region)
S3method(generate,dist_bernoulli)
S3method(generate,dist_beta)
Expand Down Expand Up @@ -342,6 +343,7 @@ S3method(median,distribution)
S3method(parameters,dist_default)
S3method(parameters,dist_wrap)
S3method(parameters,distribution)
S3method(pillar_shaft,distribution)
S3method(print,dist_default)
S3method(quantile,dist_bernoulli)
S3method(quantile,dist_beta)
Expand Down Expand Up @@ -498,6 +500,10 @@ import(vctrs)
importFrom(generics,generate)
importFrom(lifecycle,deprecate_soft)
importFrom(lifecycle,deprecated)
importFrom(pillar,get_max_extent)
importFrom(pillar,new_ornament)
importFrom(pillar,new_pillar_shaft)
importFrom(pillar,pillar_shaft)
importFrom(stats,density)
importFrom(stats,family)
importFrom(stats,median)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
* `support()` now shows whether the interval of support is open or
closed (@venpopov, #97)

## Improvements

* `dist_mixture()` now displays the components of the mixture when the output
width is sufficiently wide (@statasaurus, #112).

# distributional 0.4.0

## Breaking changes
Expand Down
27 changes: 27 additions & 0 deletions R/distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ format.distribution <- function(x, ...){
out
}

#' @importFrom pillar pillar_shaft new_pillar_shaft get_max_extent
#' @export
pillar_shaft.distribution <- function(x, ...) {
dist = format(x)
dist_min = format(x, width = 30)

pillar::new_pillar_shaft(
list(dist = dist,
dist_min = dist_min),
width = pillar::get_max_extent(dist),
min_width = pillar::get_max_extent(dist_min),
class = "pillar_distribution"
)
}

#' @export
#' @importFrom pillar new_ornament
format.pillar_distribution <- function(x, width, ...) {
if (get_max_extent(x$dist) <= width) {
ornament <- x$dist
} else {
ornament <- x$dist_min
}

pillar::new_ornament(ornament, align = "right")
}

#' @export
`dimnames<-.distribution` <- function(x, value){
attr(x, "vars") <- value
Expand Down
15 changes: 10 additions & 5 deletions R/mixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@ dist_mixture <- function(..., weights = numeric()){
}

#' @export
format.dist_mixture <- function(x, ...){
sprintf(
"mixture(n=%i)",
length(x[["dist"]])
)
format.dist_mixture <- function(x, width = getOption("width"), ...){
dists <- lapply(x[["dist"]], format) |>
unlist()

dist_info <- paste0(x[["w"]], "*", dists) |>
paste0(collapse = ", ")

long_dist <- paste0("mixture(", dist_info, ")")
short_dist <- paste0("mixture(n=", length(dists), ")")
ifelse(nchar(long_dist) <= width, long_dist, short_dist)
}

#' @export
Expand Down
7 changes: 4 additions & 3 deletions tests/testthat/test-mixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ test_that("Mixture of Normals", {
dist <- dist_mixture(dist_normal(0, 1), dist_normal(10, 4), weights = c(0.5, 0.5))

# format
expect_equal(format(dist), "mixture(n=2)")
expect_equal(format(dist), "mixture(0.5*N(0, 1), 0.5*N(10, 16))")

# quantiles
expect_equal(quantile(dist, 0.5), 2, tolerance = 1e-5)
Expand All @@ -27,7 +27,7 @@ test_that("Mixture of different distributions", {
dist <- dist_mixture(dist_normal(0, 1), dist_student_t(10), weights = c(0.3, 0.7))

# format
expect_equal(format(dist), "mixture(n=2)")
expect_equal(format(dist), "mixture(0.3*N(0, 1), 0.7*t(10, 0, 1))")

# quantiles
expect_equal(quantile(dist, 0.5), 0, tolerance = 1e-5)
Expand All @@ -52,7 +52,8 @@ test_that("Mixture of point masses", {
dist <- dist_mixture(dist_degenerate(1), dist_degenerate(2), dist_degenerate(3), weights = c(0.1, 0.2, 0.7))

# format
expect_equal(format(dist), "mixture(n=3)")
expect_equal(format(dist, width = 10), "mixture(n=3)")
expect_equal(format(dist), "mixture(0.1*1, 0.2*2, 0.7*3)")

# quantiles
expect_equal(quantile(dist, c(0, 0.1, 0.3, 1))[[1]], c(1, 1:3), tolerance = .Machine$double.eps^0.25)
Expand Down

0 comments on commit 31ccd0c

Please sign in to comment.