-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.R
211 lines (176 loc) · 6.82 KB
/
run.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
library(clustermq)
library(dplyr)
library(magrittr)
# Compile a Stan model and find a posterior mode with a single BFGS run.
#
# `model_code`: String containing the Stan program (the text of a .stan file).
# `data`:       List object to be passed to Stan. In most cases this will come
#               from the configuration generated by Covidestim.
# `iter`:       Maximum number of BFGS iterations (defaults to 1e4, the value
#               previously hard-coded here).
#
# Returns the result of `optimizing` with `as_vector = FALSE`.
f <- function(model_code, data, iter = 1e4) {
  # `optimizing` and `stan_model` come from `rstan`; the assumption is that
  # `rstan` has already been loaded (`run()` below lists it in `pkgs`).
  optimizing(
    # Compile the model on the cluster. Don't save the .dso, so that the
    # compiled model is not reused.
    object = stan_model(save_dso = FALSE, model_code = model_code),
    data = data,
    algorithm = "BFGS",
    iter = iter,
    verbose = TRUE,
    # No trailing comma after the last argument: an empty argument would be
    # forwarded into `...` and rejected.
    as_vector = FALSE
  )
}
# Fit a Stan model to one configuration.
#
# By default, run the BFGS optimizer `tries` times and return the successful
# run with the highest log posterior. When `sampler` is TRUE, instead run NUTS
# sampling once (3 chains, 2000 iterations) and return the fit's summary.
#
# `model_code`: String containing the Stan program.
# `data`:       List object to be passed to Stan.
# `tries`:      Number of BFGS runs to attempt. All of them are attempted;
#               there is no early exit on the first good result.
# `iter`:       Maximum number of BFGS iterations per run.
# `timeout`:    Per-run limit passed to `R.utils::withTimeout` (seconds).
# `sampler`:    Logical flag; TRUE requires at least 3 cores available.
#
# Returns either the `rstan::optimizing` result (with `as_vector = FALSE`) of
# the best successful run, or, when `sampler` is TRUE, the `$summary` matrix
# of `rstan::summary(fit)`. Stops with an error if no BFGS run succeeded or if
# the best log posterior is infinite.
#
# Side effect: sets the global rstan option `auto_write = TRUE`.
fMultiple <- function(
  model_code,
  data,
  tries = 10,
  iter = 6e3,
  timeout = 5*60,
  sampler = FALSE # TRUE requires at least 3 cores available
) {
  rstan_options(auto_write = TRUE)
  model <- stan_model(model_code = model_code)

  if (isTRUE(sampler)) {
    startTime <- Sys.time()
    fit <- rstan::sampling(
      object = model,
      data = data,
      control = list(adapt_delta = .98, max_treedepth = 14),
      chains = 3,
      iter = 2000,
      thin = 1,
      warmup = round((2/3)*2000)
    )
    result <- rstan::summary(fit)$summary
    endTime <- Sys.time()
    # There is no try index in this branch, and the summary matrix carries no
    # return code, so report only the elapsed time.
    message(glue::glue(
      'Finished sampling in {dt}',
      dt = prettyunits::pretty_dt(endTime - startTime)
    ))
    return(result)
  }

  # One BFGS run. No `seed` is passed, so each call uses whatever seed
  # `rstan::optimizing` picks by default; `i` only labels the log message.
  runOptimizerWithSeed <- function(i) {
    startTime <- Sys.time()
    result <- rstan::optimizing(
      object = model,
      data = data,
      algorithm = "BFGS",
      iter = iter,
      as_vector = FALSE # Otherwise you get a sloppy list structure
    )
    endTime <- Sys.time()
    message(glue::glue(
      'Finished try #{i} in {dt} with exit code {ec}',
      dt = prettyunits::pretty_dt(endTime - startTime),
      ec = result$return_code
    ))
    result
  }

  # One BFGS run under a time limit. Returns NULL when the run times out or
  # throws any other error, and logs which of the two happened. The
  # `TimeoutException` handler is listed first so it takes precedence over the
  # generic `error` handler.
  runOptimizerWithSeedInTime <- function(i, timeout) {
    tryCatch(
      R.utils::withTimeout(runOptimizerWithSeed(i), timeout = timeout),
      TimeoutException = function(c) {
        message(glue::glue('Abandoned try #{i} due to timeout'))
        NULL
      },
      error = function(c) {
        message(glue::glue(
          'Abandoned try #{i} due to error: {msg}',
          msg = conditionMessage(c)
        ))
        NULL
      }
    )
  }

  # A run is usable when it produced a result (not NULL), reported return
  # code 0 (success for `rstan::optimizing`), and has a non-NA/NaN log
  # posterior. A NaN would otherwise poison `max()` below.
  isSuccessfulRun <- function(r) {
    !is.null(r) &&
      isTRUE(r$return_code[1] == 0) &&
      isTRUE(!is.na(r$value[1]))
  }

  # `lapply` keeps NULL entries, so list positions stay aligned with try
  # numbers. `seq_len` makes `tries = 0` yield no runs rather than two.
  results <- lapply(seq_len(tries), runOptimizerWithSeedInTime, timeout = timeout)

  succeeded <- vapply(results, isSuccessfulRun, logical(1))
  if (!any(succeeded)) {
    stop("All BFGS runs timed out or failed!")
  }

  successful_results <- results[succeeded]
  successful_tries <- which(succeeded)

  # The log posterior at the mode found by each successful run.
  opt_vals <- vapply(successful_results, function(r) r$value, numeric(1))
  best_val <- max(opt_vals)

  # An infinite log posterior (+Inf would win the comparison; all -Inf means
  # nothing is usable) is not a valid mode, so report the offending try
  # numbers instead of returning it.
  if (is.infinite(best_val)) {
    stop(glue::glue(
      'The value of the log posterior was infinite for these runs:\n{runs}',
      runs = paste(successful_tries[opt_vals == best_val], collapse = ", ")
    ))
  }

  # Return the successful result with the highest log posterior. If several
  # runs tie, `which.max` picks the first of them.
  successful_results[[which.max(opt_vals)]]
}
# Use ClusterMQ to connect to the cluster, compile the model, and run it.
# This function can easily be modified to perform various experiments. See
# the docs: `?clustermq::Q`. Worker logs will be found in `~/`.
#
# `f`:               Function called once per element of `tests$config`. It
#                    receives that element as `data`, plus the constant
#                    `model_code`.
# `tests`:           Data frame with a `config` column.
# `codePath`:        Path to the .stan file. Its contents are sent to every
#                    call as `model_code`.
# `jobs_per_worker`: Number of calls batched into each job (`job_size`).
# `time_per_run`:    Time budget per call. It is multiplied by
#                    `jobs_per_worker` to fill the scheduler template's `time`
#                    field. NOTE(review): units depend on the template;
#                    presumably minutes.
# `cores`:           Value for the scheduler template's `cores` field.
#
# Returns `tests` with an added `result` column holding what `Q` returned.
# `fail_on_error = FALSE` asks `Q` not to abort the whole run when individual
# calls fail (see `?clustermq::Q`).
run <- function(f, tests, codePath, jobs_per_worker = 4, time_per_run = 12, cores = 1) {
  # Read the Stan program with base R. `read_file` (readr) is not loaded by
  # this script.
  model_code <- paste(readLines(codePath, warn = FALSE), collapse = "\n")

  result <- Q(
    f,
    data = tests$config,
    const = list(model_code = model_code),
    job_size = jobs_per_worker,
    log_worker = TRUE,
    pkgs = c('rstan', 'glue', 'prettyunits'),
    fail_on_error = FALSE,
    template = list(
      time = jobs_per_worker * time_per_run,
      cores = cores
    )
  )
  mutate(tests, result = result)
}
# Regions to test: every US state (by name) and a fixed sample of counties
# (presumably 5-digit county FIPS codes).
states <- state.name
counties <- c(
  "24027", "06023", "41039", "11001", "47007", "16035", "17019", "40017",
  "01027", "46021", "51630", "32009", "54081", "01085", "06047", "55075",
  "22053", "48173", "13001", "01041", "04009", "13185", "38047", "16077",
  "37057", "55031", "53027", "05047", "18095", "17009", "02185", "22089",
  "18099", "01009", "46102", "21025", "36037", "28005", "48075", "37097",
  "13193", "46093", "48255", "29063", "51171", "19055", "13197", "40101",
  "01029", "18117", "51081", "05065", "39069", "13133", "48105", "17053",
  "26131", "16057", "48051", "45047", "33013", "55001", "34041", "28079",
  "40141", "02100", "45077", "13033", "40099", "29011", "19071", "31185",
  "35013", "53007", "28139", "20141", "31023", "22095", "47057", "37121",
  "55071", "35051", "48097", "48445", "45003", "22083", "20029", "48285",
  "18013", "17105", "21213", "51181", "29201", "26115", "01043", "35031",
  "48125", "19063", "54013", "40115", "08077", "50005", "45035", "20119",
  "40069"
)

codePath <- "../covidestim/inst/stan/stan_program_default.stan"
# Target worker count: each run below puts ceiling(n_tests / ncores) tests
# into each job.
ncores <- 190

# Build the table of test configurations for a vector of regions.
# `purrr::map` is namespace-qualified because purrr is not attached above.
# NOTE(review): `testset`, `getInputs`, and `getConfigs` are not defined in
# this file; they are assumed to already exist in the session. Confirm.
buildTests <- function(regions) {
  purrr::map(
    regions,
    ~mutate(testset, region = ., d = list(getInputs(.)))
  ) %>%
    bind_rows() %>%
    as_tibble() %>%
    getConfigs()
}

tests_states <- buildTests(states)
tests_counties <- buildTests(counties)

# Save each result set as soon as its run finishes, so a failure in the
# county run does not lose the state results.
test_results_states <- run(
  f = fMultiple,
  tests = tests_states,
  codePath = codePath,
  jobs_per_worker = ceiling(nrow(tests_states)/ncores),
  time_per_run = 20
)
saveRDS(test_results_states, 'test_results_states.RDS')

test_results_counties <- run(
  f = fMultiple,
  tests = tests_counties,
  codePath = codePath,
  jobs_per_worker = ceiling(nrow(tests_counties)/ncores),
  time_per_run = 20
)
saveRDS(test_results_counties, 'test_results_counties.RDS')