Skip to content

Commit

Permalink
Update brsmatch() and coxpsmatch() to handle NA rows (#17)
Browse files Browse the repository at this point in the history
* Update `brsmatch()` to handle NA rows
* Update `coxpsmatch()` to handle NA rows
* Minor internal code update to .brsmatch for cleaner code
* Add tests to check for this going forward
* Increment version and NEWS.md
  • Loading branch information
skent259 authored Feb 3, 2024
1 parent 2ad94f2 commit e068b02
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 16 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: rsmatch
Title: Matching Methods for Time-Varying Observational Studies
Version: 0.2.0.9000
Version: 0.2.0.9001
Authors@R: c(
person("Sean", "Kent", , "[email protected]", role = c("aut", "cre", "cph"),
comment = c(ORCID = "0000-0001-8697-9069")),
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# rsmatch (development version)
* Update `brsmatch()` and `coxpsmatch()` to handle NA rows via removing them

# rsmatch 0.2.0

Expand Down
38 changes: 24 additions & 14 deletions R/brsmatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ brsmatch <- function(
data[[trt_time]] <- data[[trt_time]] - 1
}

id_list <- unique(data[[id]]) # compute before any NA removal

# Remove NA rows except those in `trt_time` column, with a warning
na_action <- stats::na.omit(data[, setdiff(colnames(data), trt_time)])
na_rows <- attributes(na_action)$na.action
if (!is.null(na_rows)) {
rlang::warn(c(
"ID, time, and covariates should not have NA entries.",
i = paste("Removed", length(na_rows), "rows.")
))
data <- data[-na_rows, ]
}


if (!is.null(exact_match)) {
data_split <- split(data, data[, exact_match, drop = FALSE])
covariates <- setdiff(covariates, exact_match)
Expand Down Expand Up @@ -151,7 +165,7 @@ brsmatch <- function(
)
}

.output_pairs(matched_ids, id = id, id_list = unique(data[[id]]))
.output_pairs(matched_ids, id = id, id_list = id_list)
}

.brsmatch <- function(
Expand All @@ -168,27 +182,25 @@ brsmatch <- function(
optimizer <- options$optimizer
verbose <- options$verbose

if (verbose) {
rlang::inform("Computing distance from data...")
.print_if <- function(condition, message, ...) {
if (condition) {
rlang::inform(message, ...)
}
}

.print_if(verbose, "Computing distance from data...")
edges <- .compute_distances(data, id, time, trt_time, covariates, options)

bal <- NULL
if (balance) {
if (verbose) {
rlang::inform("Building balance columns from data...")
}
.print_if(verbose, "Building balance columns from data...")
bal <- .balance_columns(data, id, time, trt_time, balance_covariates)
}

if (verbose) {
rlang::inform("Constructing optimization model...")
}
.print_if(verbose, "Constructing optimization model...")
model <- .rsm_optimization_model(n_pairs, edges, bal, optimizer, verbose, balance)

if (verbose) {
rlang::inform("Preparing to run optimization model")
}
.print_if(verbose, "Preparing to run optimization model")
if (optimizer == "gurobi") {
res <- gurobi::gurobi(model, params = list(OutputFlag = 1 * verbose))
matches <- res$x[grepl("f", model$varnames)]
Expand All @@ -202,8 +214,6 @@ brsmatch <- function(
max = model$max,
control = list(verbose = verbose, presolve = TRUE)
)
# res <- with(model, Rglpk::Rglpk_solve_LP(obj, mat, dir, rhs, types = types, max = max,
# control = list(verbose = verbose, presolve = TRUE)))
matches <- res$solution[grepl("f", model$varnames)]
}

Expand Down
15 changes: 14 additions & 1 deletion R/coxpsmatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ coxpsmatch <- function(
data[[trt_time]] <- as.numeric(data[[trt_time]])
}

id_list <- unique(data[[id]]) # compute before any NA removal

# Remove NA rows except those in `trt_time` column, with a warning
na_action <- stats::na.omit(data[, setdiff(colnames(data), trt_time)])
na_rows <- attributes(na_action)$na.action
if (!is.null(na_rows)) {
rlang::warn(c(
"ID, time, and covariates should not have NA entries.",
i = paste("Removed", length(na_rows), "rows.")
))
data <- data[-na_rows, ]
}

if (!is.null(exact_match)) {
balance_split <- split(data, data[, exact_match, drop = FALSE])
matches <- NULL
Expand All @@ -94,7 +107,7 @@ coxpsmatch <- function(
}

colnames(matches)[1:2] <- c("trt_id", "all_id")
return(.output_pairs(matches, id = id, id_list = unique(data[[id]])))
return(.output_pairs(matches, id = id, id_list = id_list))
}

#' Propensity Score Matching with Time-Dependent Covariates
Expand Down
40 changes: 40 additions & 0 deletions tests/testthat/test-brsmatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,43 @@ test_that("`brsmatch()` works for different input values.", {
options = list(time_lag = TRUE)
)
})


test_that("brsmatch works when some input are NA", {
df1 <- data.frame(
id = rep(1:3, each = 3),
time = rep(1:3, 3),
trt_time = rep(c(2, 3, NA), each = 3),
X1 = c(2, 2, 2, 3, 3, 3, 9, 9, 9),
X2 = rep(c("a", "a", "b"), each = 3),
X3 = c(9, 4, 5, 6, 7, 2, 3, 4, 8),
X4 = c(8, 9, 4, 5, 6, 7, 2, 3, 4)
)

check_for_glpk()
pairs1 <- brsmatch(n_pairs = 1, data = df1)

expect_equal(nrow(pairs1), length(unique(df1$id)))

# Check when trt type "all" is removed"
df2 <- df1
df2$X3[5:6] <- NA

pairs2 <- brsmatch(n_pairs = 1, data = df2) %>%
expect_warning("should not have NA")

expect_equal(nrow(pairs2), length(unique(df1$id)))

# Check when trt type "trt" is removed"
df3 <- df1
df3$X1[1:3] <- NA

pairs3 <- brsmatch(n_pairs = 1, data = df3) %>%
expect_warning("should not have NA")

expect_equal(nrow(pairs3), length(unique(df1$id)))

# NOTE: this still isn't graceful if the NA removes too many rows, but we
# can't hold everyone's hand all the time...
})

41 changes: 41 additions & 0 deletions tests/testthat/test-coxpsmatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,44 @@ test_that("`coxpsmatch()` works when there are no never-treated individuals", {
"ghost value"
)
})


test_that("coxpsmatch works when some input are NA", {
df1 <- data.frame(
id = rep(1:3, each = 3),
time = rep(1:3, 3),
trt_time = rep(c(2, 3, NA), each = 3),
X1 = c(2, 2, 2, 3, 3, 3, 9, 9, 9),
X2 = rep(c("a", "a", "b"), each = 3),
X3 = c(9, 4, 5, 6, 7, 2, 3, 4, 8),
X4 = c(8, 9, 4, 5, 6, 7, 2, 3, 4)
)

check_for_coxpsmatch_packages()

pairs1 <- coxpsmatch(n_pairs = 1, data = df1) %>%
expect_warning()

expect_equal(nrow(pairs1), length(unique(df1$id)))

# Check when trt type "all" is removed"
df2 <- df1
df2$X3[5:6] <- NA

pairs2 <- coxpsmatch(n_pairs = 1, data = df2) %>%
expect_warning("should not have NA")

expect_equal(nrow(pairs2), length(unique(df1$id)))

# Check when trt type "trt" is removed"
df3 <- df1
df3$X1[1:3] <- NA

pairs3 <- brsmatch(n_pairs = 1, data = df3) %>%
expect_warning("should not have NA")

expect_equal(nrow(pairs3), length(unique(df1$id)))

# NOTE: this still isn't graceful if the NA removes too many rows, but we
# can't hold everyone's hand all the time...
})

0 comments on commit e068b02

Please sign in to comment.