Skip to content

Commit

Permalink
Return original controls
Browse files Browse the repository at this point in the history
  • Loading branch information
wenjie2wang committed Jan 7, 2024
1 parent 3adaf96 commit 0e65bad
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 15 deletions.
8 changes: 4 additions & 4 deletions R/abclass.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ abclass.control <- function(lambda = NULL,
alpha = 1.0,
nlambda = 50L,
lambda_min_ratio = NULL,
grouped = TRUE,
penalty_factor = NULL,
grouped = TRUE,
group_penalty = c("lasso", "scad", "mcp"),
offset = NULL,
kappa_ratio = 0.9,
Expand All @@ -175,13 +175,13 @@ abclass.control <- function(lambda = NULL,
}
structure(list(
alpha = alpha,
lambda = null2num0(lambda),
lambda = lambda,
nlambda = as.integer(nlambda),
lambda_min_ratio = lambda_min_ratio,
penalty_factor = penalty_factor,
grouped = grouped,
group_penalty = group_penalty,
penalty_factor = null2num0(penalty_factor),
offset = null2mat0(offset),
offset = offset,
standardize = standardize,
maxit = as.integer(maxit),
epsilon = epsilon,
Expand Down
9 changes: 8 additions & 1 deletion R/abclass_engine.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@
c("lasso", "scad", "mcp"))
}
## process alignment
all_alignment <- c("fraction", "lambda")
if (is.numeric(alignment)) {
alignment <- as.integer(alignment[1L])
} else if (is.character(alignment)) {
all_alignment <- c("fraction", "lambda")
alignment <- match.arg(alignment, choices = all_alignment)
alignment <- match(alignment, all_alignment) - 1L
} else {
Expand All @@ -77,6 +77,9 @@
loss_id = loss_id,
penalty_id = penalty_id)
)
ctrl$lambda <- null2num0(ctrl$lambda)
ctrl$penalty_factor = null2num0(ctrl$penalty_factor)
ctrl$offset = null2mat0(ctrl$offset)
## arguments
call_list <- c(list(x = x, y = cat_y$y, control = ctrl))
fun_to_call <- if (is_x_sparse) {
Expand All @@ -93,6 +96,10 @@
res$control <- control
if (call_list$control$nfolds == 0L) {
res$cross_validation <- NULL
} else {
res$cross_validation$alignment <- all_alignment[
res$cross_validation$alignment
]
}
if (call_list$control$nstages == 0L) {
res$et <- NULL
Expand Down
9 changes: 8 additions & 1 deletion inst/tinytest/test-abclass.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ train_y <- y[train_idx]
test_y <- y[- train_idx]

## logistic deviance loss
model1 <- abclass(train_x, train_y, nlambda = 5, grouped = FALSE)
model1 <- abclass(
x = train_x,
y = train_y,
nlambda = 5,
grouped = FALSE,
control = abclass.control(penalty_factor = runif(ncol(train_x)))
)

pred1 <- predict(model1, test_x, s = 5)
expect_true(mean(test_y == pred1) > 0.5)
expect_equivalent(dim(coef(model1, s = 5)), c(p + 1, k - 1))
Expand Down
10 changes: 5 additions & 5 deletions inst/tinytest/test-et.abclass.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ test_y <- y[- train_idx]

## without refit
model1 <- et.abclass(train_x, train_y, nstages = 2,
lambda_min_ratio = 1e-6, grouped = FALSE,
lambda_min_ratio = 1e-4, grouped = FALSE,
refit = FALSE)
expect_equivalent(dim(coef(model1)), c(p + 1, k - 1))

## with refit being TRUE
model1 <- et.abclass(train_x, train_y, nstages = 2,
lambda_min_ratio = 1e-6, grouped = TRUE,
lambda_min_ratio = 1e-4, grouped = TRUE,
refit = TRUE)
expect_equivalent(dim(coef(model1)), c(p + 1, k - 1))
pred1 <- predict(model1, test_x)
Expand All @@ -36,15 +36,15 @@ expect_true(mean(test_y == pred1) > 0.5)
## with reift as a list
## with cv
model1 <- et.abclass(train_x, train_y, nstages = 2,
lambda_min_ratio = 1e-6, grouped = TRUE,
lambda_min_ratio = 1e-4, grouped = TRUE,
refit = list(alpha = 0, nlambda = 10, nfolds = 3))
expect_equivalent(dim(coef(model1)), c(p + 1, k - 1))
pred1 <- predict(model1, test_x)
expect_true(mean(test_y == pred1) > 0.5)

## without cv
model1 <- et.abclass(train_x, train_y, nstages = 2,
lambda_min_ratio = 1e-6, grouped = TRUE,
lambda_min_ratio = 1e-4, grouped = TRUE,
refit = list(alpha = 0, nlambda = 10))
expect_equivalent(dim(coef(model1, selection = 10)), c(p + 1, k - 1))
pred1 <- predict(model1, test_x, s = 10)
Expand All @@ -55,7 +55,7 @@ expect_error(
et.abclass(train_x, train_y, penalty_factor = runif(ncol(train_x) + 1))
)

## with refit and penalty factors
## with penalty factors
gw <- runif(ncol(train_x))
model1 <- et.abclass(train_x, train_y, nstages = 2,
lambda_min_ratio = 1e-4,
Expand Down
8 changes: 4 additions & 4 deletions man/abclass.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0e65bad

Please sign in to comment.