diff --git a/.gitignore b/.gitignore index 9ff0d41..e22386e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ calmr.Rproj sketch.R docs revdep/ +.vscode/ diff --git a/DESCRIPTION b/DESCRIPTION index b7082d7..081ac80 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -42,7 +42,7 @@ Encoding: UTF-8 Language: en-US LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 Collate: 'HD2022.R' 'HDI2020.R' diff --git a/R/anccr_helpers.R b/R/anccr_helpers.R index 35c0409..4587e6e 100644 --- a/R/anccr_helpers.R +++ b/R/anccr_helpers.R @@ -47,23 +47,28 @@ set_reward_parameters <- function(parameters, rewards = c("US")) { transitions <- mapping$transitions[[trial_name]] period_funcs <- mapping$period_functionals[[trial_name]] # sample start of the trial - if (timings$use_exponential) { - new_ts <- min( - with(timings$trial_ts, max_ITI[trial == trial_name]), - stats::rexp( - 1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name]) - ), - timings$time_resolution - ) + if (!timings$sample_timings) { + # no sampling: use mean ITI + new_ts <- with(timings$trial_ts, mean_ITI[trial == trial_name]) } else { - new_ts <- stats::runif(1) * - with(timings$trial_ts, mean_ITI[trial == trial_name]) * - 0.4 + - with( - timings$trial_ts, - mean_ITI[trial == trial_name] - ) * 0.8 + if (timings$use_exponential) { + new_ts <- min( + with(timings$trial_ts, max_ITI[trial == trial_name]), + stats::rexp( + 1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name]) + ) + ) + } else { + new_ts <- stats::runif(1) * + with(timings$trial_ts, mean_ITI[trial == trial_name]) * + 0.4 + + with( + timings$trial_ts, + mean_ITI[trial == trial_name] + ) * 0.8 + } } + running_time <- running_time + new_ts for (p in seq_along(period_funcs)) { eventlog <- rbind( diff --git a/R/get_timings.R b/R/get_timings.R index 254e057..e06b8fe 100644 --- a/R/get_timings.R +++ b/R/get_timings.R @@ -97,20 +97,21 @@ get_timings <- function(design, model) { .default_global_timings <- function() { list( "use_exponential" = TRUE, - "time_resolution" = 0.5 + "time_resolution" = 0.5, + "sample_timings" = TRUE ) } .model_timings <- function(model) { timings_map <- list( "ANCCR" = list( - "global" = c("use_exponential"), + "global" = c("use_exponential", "sample_timings"), "transitions" = c("transition_delay"), "periods" = c(), "trials" = c("post_trial_delay", "mean_ITI", "max_ITI") ), "TD" = list( - "global" = c("use_exponential", "time_resolution"), + "global" = c("use_exponential", "time_resolution", "sample_timings"), "transitions" = c("transition_delay"), "periods" = c("stimulus_duration"), "trials" = c("post_trial_delay", "mean_ITI", "max_ITI") diff --git a/R/td_helpers.R b/R/td_helpers.R index 182b124..3ff3a4b 100644 --- a/R/td_helpers.R +++ b/R/td_helpers.R @@ -25,22 +25,27 @@ trial_name <- experience$tn[ti] transitions <- mapping$transitions[[trial_name]] period_funcs <- mapping$period_functionals[[trial_name]] - # sample start of the trial - if (timings$use_exponential) { - new_ts <- min( - with(timings$trial_ts, max_ITI[trial == trial_name]), - stats::rexp( - 1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name]) - ) - ) + # sample start of the trial if needed + if (!timings$sample_timings) { + # no sampling: use mean ITI + new_ts <- with(timings$trial_ts, mean_ITI[trial == trial_name]) } else { - new_ts <- stats::runif(1) * - with(timings$trial_ts, mean_ITI[trial == trial_name]) * - 0.4 + - with( - timings$trial_ts, - mean_ITI[trial == trial_name] - ) * 0.8 + if (timings$use_exponential) { + new_ts <- min( + with(timings$trial_ts, max_ITI[trial == trial_name]), + stats::rexp( + 1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name]) + ) + ) + } else { + new_ts <- stats::runif(1) * + with(timings$trial_ts, mean_ITI[trial == trial_name]) * + 0.4 + + with( + timings$trial_ts, + mean_ITI[trial == trial_name] + ) * 0.8 + } } running_time <- running_time + new_ts for (p in seq_along(period_funcs)) { diff --git a/tests/testthat/test-anccr_tests.R b/tests/testthat/test-anccr_tests.R index 5b5ea16..a06501b 100644 --- a/tests/testthat/test-anccr_tests.R +++ b/tests/testthat/test-anccr_tests.R @@ -28,6 +28,15 @@ test_that("can run without exponential", { ) }) +test_that("can run without sampling timings", { + nots <- tims + nots$sample_timings <- FALSE + expect_no_error(run_experiment(df, + parameters = pars, + timings = nots, model = "ANCCR" + )) +}) + test_that("can run with timed alpha", { talpha <- pars talpha$use_timed_alpha <- 1