Skip to content

Commit

Permalink
added option to avoid sampling timings for time-based models
Browse files Browse the repository at this point in the history
  • Loading branch information
victor-navarro committed Jun 7, 2024
1 parent bdcbf00 commit 985aa1f
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 34 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ calmr.Rproj
sketch.R
docs
revdep/
.vscode/
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Encoding: UTF-8
Language: en-US
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Collate:
'HD2022.R'
'HDI2020.R'
Expand Down
35 changes: 20 additions & 15 deletions R/anccr_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,28 @@ set_reward_parameters <- function(parameters, rewards = c("US")) {
transitions <- mapping$transitions[[trial_name]]
period_funcs <- mapping$period_functionals[[trial_name]]
# sample start of the trial
if (timings$use_exponential) {
new_ts <- min(
with(timings$trial_ts, max_ITI[trial == trial_name]),
stats::rexp(
1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name])
),
timings$time_resolution
)
if (!timings$sample_timings) {
# no sampling: use mean ITI
new_ts <- with(timings$trial_ts, mean_ITI[trial == trial_name])
} else {
new_ts <- stats::runif(1) *
with(timings$trial_ts, mean_ITI[trial == trial_name]) *
0.4 +
with(
timings$trial_ts,
mean_ITI[trial == trial_name]
) * 0.8
if (timings$use_exponential) {
new_ts <- min(
with(timings$trial_ts, max_ITI[trial == trial_name]),
stats::rexp(
1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name])
)
)
} else {
new_ts <- stats::runif(1) *
with(timings$trial_ts, mean_ITI[trial == trial_name]) *
0.4 +
with(
timings$trial_ts,
mean_ITI[trial == trial_name]
) * 0.8
}
}

running_time <- running_time + new_ts
for (p in seq_along(period_funcs)) {
eventlog <- rbind(
Expand Down
7 changes: 4 additions & 3 deletions R/get_timings.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,21 @@ get_timings <- function(design, model) {
.default_global_timings <- function() {
list(
"use_exponential" = TRUE,
"time_resolution" = 0.5
"time_resolution" = 0.5,
"sample_timings" = TRUE
)
}

.model_timings <- function(model) {
timings_map <- list(
"ANCCR" = list(
"global" = c("use_exponential"),
"global" = c("use_exponential", "sample_timings"),
"transitions" = c("transition_delay"),
"periods" = c(),
"trials" = c("post_trial_delay", "mean_ITI", "max_ITI")
),
"TD" = list(
"global" = c("use_exponential", "time_resolution"),
"global" = c("use_exponential", "time_resolution", "sample_timings"),
"transitions" = c("transition_delay"),
"periods" = c("stimulus_duration"),
"trials" = c("post_trial_delay", "mean_ITI", "max_ITI")
Expand Down
35 changes: 20 additions & 15 deletions R/td_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,27 @@
trial_name <- experience$tn[ti]
transitions <- mapping$transitions[[trial_name]]
period_funcs <- mapping$period_functionals[[trial_name]]
# sample start of the trial
if (timings$use_exponential) {
new_ts <- min(
with(timings$trial_ts, max_ITI[trial == trial_name]),
stats::rexp(
1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name])
)
)
# sample start of the trial if needed
if (!timings$sample_timings) {
# no sampling: use mean ITI
new_ts <- with(timings$trial_ts, mean_ITI[trial == trial_name])
} else {
new_ts <- stats::runif(1) *
with(timings$trial_ts, mean_ITI[trial == trial_name]) *
0.4 +
with(
timings$trial_ts,
mean_ITI[trial == trial_name]
) * 0.8
if (timings$use_exponential) {
new_ts <- min(
with(timings$trial_ts, max_ITI[trial == trial_name]),
stats::rexp(
1, 1 / with(timings$trial_ts, mean_ITI[trial == trial_name])
)
)
} else {
new_ts <- stats::runif(1) *
with(timings$trial_ts, mean_ITI[trial == trial_name]) *
0.4 +
with(
timings$trial_ts,
mean_ITI[trial == trial_name]
) * 0.8
}
}
running_time <- running_time + new_ts
for (p in seq_along(period_funcs)) {
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-anccr_tests.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ test_that("can run without exponential", {
)
})

test_that("can run without sampling timings", {
nots <- tims
nots$sample_timings <- FALSE
expect_no_error(run_experiment(df,
parameters = pars,
timings = nots, model = "ANCCR"
))
})

test_that("can run with timed alpha", {
talpha <- pars
talpha$use_timed_alpha <- 1
Expand Down

0 comments on commit 985aa1f

Please sign in to comment.