-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ATE estimate under Adaptive Design (ADATE)
- Loading branch information
1 parent
43af29c
commit 76f2b46
Showing
3 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
#' Average Treatment Effect under Adaptive Design | ||
#' | ||
#' Parameter definition for Average Treatment Effect under Adaptive Design (ADATE): $P_{n,W}[E(Y|A=1,W)-E(Y|A=0,W)]$. Currently supports adaptive design for BINARY intervention nodes. | ||
#' @importFrom R6 R6Class | ||
#' @importFrom uuid UUIDgenerate | ||
#' @importFrom methods is | ||
#' @family Parameters | ||
#' @keywords data | ||
#' | ||
#' @return \code{Param_base} object | ||
#' | ||
#' @format \code{\link{R6Class}} object. | ||
#' | ||
#' @section Constructor: | ||
#' \code{define_param(Param_ATT, observed_likelihood, intervention_list, ..., outcome_node)} | ||
#' | ||
#' \describe{ | ||
#' \item{\code{observed_likelihood}}{A \code{\link{Likelihood}} corresponding to the observed likelihood | ||
#' } | ||
#' \item{\code{intervention_list_treatment}}{A list of objects inheriting from \code{\link{LF_base}}, representing the treatment intervention. | ||
#' } | ||
#' \item{\code{intervention_list_control}}{A list of objects inheriting from \code{\link{LF_base}}, representing the control intervention. | ||
#' } | ||
#' \item{\code{g_treat}}{vector, the actual probability of A that corresponds to treatment | ||
#' } | ||
|
||
#' \item{\code{...}}{Not currently used. | ||
#' } | ||
#' \item{\code{outcome_node}}{character, the name of the node that should be treated as the outcome | ||
#' } | ||
#' } | ||
#' | ||
|
||
#' @section Fields: | ||
#' \describe{ | ||
#' \item{\code{cf_likelihood_treatment}}{the counterfactual likelihood for the treatment | ||
#' } | ||
#' \item{\code{cf_likelihood_control}}{the counterfactual likelihood for the control | ||
#' } | ||
#' \item{\code{g_treat}}{the actual probability of A that corresponds to treatment | ||
#' } | ||
#' \item{\code{intervention_list_treatment}}{A list of objects inheriting from \code{\link{LF_base}}, representing the treatment intervention | ||
#' } | ||
#' \item{\code{intervention_list_control}}{A list of objects inheriting from \code{\link{LF_base}}, representing the control intervention | ||
#' } | ||
#' } | ||
#' @export | ||
Param_ADATE <- R6Class( | ||
classname = "Param_ADATE", | ||
portable = TRUE, | ||
class = TRUE, | ||
inherit = Param_base, | ||
public = list( | ||
initialize = function(observed_likelihood, intervention_list_treatment, intervention_list_control, g_treat, outcome_node = "Y") { | ||
super$initialize(observed_likelihood, list(), outcome_node) | ||
if (!is.null(observed_likelihood$censoring_nodes[[outcome_node]])) { | ||
# add delta_Y=0 to intervention lists | ||
outcome_censoring_node <- observed_likelihood$censoring_nodes[[outcome_node]] | ||
censoring_intervention <- define_lf(LF_static, outcome_censoring_node, value = 1) | ||
intervention_list_treatment <- c(intervention_list_treatment, censoring_intervention) | ||
intervention_list_control <- c(intervention_list_control, censoring_intervention) | ||
} | ||
private$.g_treat <- g_treat | ||
|
||
private$.cf_likelihood_treatment <- CF_Likelihood$new(observed_likelihood, intervention_list_treatment) | ||
private$.cf_likelihood_control <- CF_Likelihood$new(observed_likelihood, intervention_list_control) | ||
}, | ||
clever_covariates = function(tmle_task = NULL, fold_number = "full") { | ||
if (is.null(tmle_task)) { | ||
tmle_task <- self$observed_likelihood$training_task | ||
} | ||
|
||
intervention_nodes <- union(names(self$intervention_list_treatment), names(self$intervention_list_control)) | ||
|
||
pA <- self$observed_likelihood$get_likelihoods(tmle_task, intervention_nodes, fold_number) | ||
cf_pA_treatment <- self$cf_likelihood_treatment$get_likelihoods(tmle_task, intervention_nodes, fold_number) | ||
cf_pA_control <- self$cf_likelihood_control$get_likelihoods(tmle_task, intervention_nodes, fold_number) | ||
|
||
g_treat <- self$g_treat | ||
|
||
HA_treatment <- cf_pA_treatment / g_treat | ||
HA_control <- cf_pA_control / (1 - g_treat) | ||
|
||
# collapse across multiple intervention nodes | ||
if (!is.null(ncol(HA_treatment)) && ncol(HA_treatment) > 1) { | ||
HA_treatment <- apply(HA_treatment, 1, prod) | ||
} | ||
|
||
# collapse across multiple intervention nodes | ||
if (!is.null(ncol(HA_control)) && ncol(HA_control) > 1) { | ||
HA_control <- apply(HA_control, 1, prod) | ||
} | ||
|
||
HA <- HA_treatment - HA_control | ||
|
||
HA <- bound(HA, c(-40, 40)) | ||
return(list(Y = HA)) | ||
}, | ||
estimates = function(tmle_task = NULL, fold_number = "full") { | ||
if (is.null(tmle_task)) { | ||
tmle_task <- self$observed_likelihood$training_task | ||
} | ||
|
||
intervention_nodes <- union(names(self$intervention_list_treatment), names(self$intervention_list_control)) | ||
|
||
# clever_covariates happen here (for this param) only, but this is repeated computation | ||
HA <- self$clever_covariates(tmle_task, fold_number)[[self$outcome_node]] | ||
|
||
|
||
# todo: make sure we support updating these params | ||
pA <- self$observed_likelihood$get_likelihoods(tmle_task, intervention_nodes, fold_number) | ||
cf_pA_treatment <- self$cf_likelihood_treatment$get_likelihoods(tmle_task, intervention_nodes, fold_number) | ||
cf_pA_control <- self$cf_likelihood_control$get_likelihoods(tmle_task, intervention_nodes, fold_number) | ||
|
||
# todo: extend for stochastic | ||
cf_task_treatment <- self$cf_likelihood_treatment$enumerate_cf_tasks(tmle_task)[[1]] | ||
cf_task_control <- self$cf_likelihood_control$enumerate_cf_tasks(tmle_task)[[1]] | ||
|
||
Y <- tmle_task$get_tmle_node(self$outcome_node, impute_censoring = TRUE) | ||
|
||
EY <- self$observed_likelihood$get_likelihood(tmle_task, self$outcome_node, fold_number) | ||
EY1 <- self$observed_likelihood$get_likelihood(cf_task_treatment, self$outcome_node, fold_number) | ||
EY0 <- self$observed_likelihood$get_likelihood(cf_task_control, self$outcome_node, fold_number) | ||
|
||
psi <- mean(EY1 - EY0) | ||
IC <- HA * (Y - EY) | ||
|
||
result <- list(psi = psi, IC = IC) | ||
return(result) | ||
} | ||
), | ||
active = list( | ||
name = function() { | ||
# param_form <- sprintf("ATE[%s_{%s}-%s_{%s}]", self$outcome_node, self$cf_likelihood_treatment$name, self$outcome_node, self$cf_likelihood_control$name) | ||
param_form <- "ADATE[Y]" | ||
return(param_form) | ||
}, | ||
g_treat = function(){ | ||
return (private$.g_treat) | ||
}, | ||
cf_likelihood_treatment = function() { | ||
return(private$.cf_likelihood_treatment) | ||
}, | ||
cf_likelihood_control = function() { | ||
return(private$.cf_likelihood_control) | ||
}, | ||
intervention_list_treatment = function() { | ||
return(self$cf_likelihood_treatment$intervention_list) | ||
}, | ||
intervention_list_control = function() { | ||
return(self$cf_likelihood_control$intervention_list) | ||
}, | ||
update_nodes = function() { | ||
return(c(self$outcome_node)) | ||
} | ||
), | ||
private = list( | ||
.type = "ADATE", | ||
.g_treat = NULL, | ||
.cf_likelihood_treatment = NULL, | ||
.cf_likelihood_control = NULL, | ||
.supports_outcome_censoring = TRUE | ||
) | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
#' Defines a TML Estimator for Average Treatment Effect under Adaptive Design | ||
#' | ||
#' @importFrom R6 R6Class | ||
#' @importFrom tmle3 tmle3_Spec define_lf tmle3_Update Targeted_Likelihood | ||
#' Param_TSM point_tx_likelihood | ||
#' | ||
#' @export | ||
tmle3_Spec_ADATE <- R6::R6Class( | ||
classname = "tmle3_Spec_ADATE", | ||
portable = TRUE, | ||
class = TRUE, | ||
inherit = tmle3_Spec, | ||
public = list( | ||
initialize = function(treatment_level, control_level, g_treat, g_adapt, ...) { | ||
super$initialize( | ||
treatment_level = treatment_level, | ||
control_level = control_level, | ||
g_treat = g_treat, ... | ||
) | ||
}, | ||
make_params = function(tmle_task, likelihood, ...) { | ||
g_treat <-self$options$g_treat | ||
if (!(is.vector(g_treat) & | ||
tmle_task$nrow == length(g_treat))) { | ||
msg <- paste("`g_treat` must be vectors", | ||
"with a length of `tmle_task$nrow`") | ||
stop(msg) | ||
} | ||
|
||
treatment_value <- self$options$treatment_level | ||
control_value <- self$options$control_level | ||
A_levels <- tmle_task$npsem[["A"]]$variable_type$levels | ||
if (!is.null(A_levels)) { | ||
treatment_value <- factor(treatment_value, levels = A_levels) | ||
control_value <- factor(control_value, levels = A_levels) | ||
} | ||
treatment <- define_lf(LF_static, "A", value = treatment_value) | ||
control <- define_lf(LF_static, "A", value = control_value) | ||
adate <- Param_ADATE$new(likelihood, treatment, control, g_treat) | ||
tmle_params <- list(adate) | ||
return(tmle_params) | ||
}, | ||
make_updater = function() { | ||
updater <- tmle3_Update$new(cvtmle = TRUE) | ||
} | ||
), | ||
active = list(), | ||
private = list() | ||
) | ||
|
||
################################################################################ | ||
|
||
#' Mean Outcome under a Candidate Adaptive Design | ||
#' | ||
#' O = (W, A, Y) | ||
#' W = Covariates | ||
#' A = Treatment (binary or categorical) | ||
#' Y = Outcome (binary or bounded continuous) | ||
#' | ||
#' @importFrom sl3 make_learner Lrnr_mean | ||
#' @param treatment_level the level of A that corresponds to treatment | ||
#' @param control_level the level of A that corresponds to a control or reference level | ||
#' @param g_treat the actual probability of A that corresponds to treatment | ||
#' @export | ||
tmle_ADATE <- function(treatment_level, control_level, g_treat) { | ||
# TODO: unclear why this has to be in a factory function | ||
tmle3_Spec_ADATE$new(treatment_level, control_level, g_treat) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
context("Test the TML Estimator for the Mean Outcome under a Counterfactual Adaptive Design") | ||
|
||
library(sl3) | ||
library(tmle3) | ||
library(uuid) | ||
library(assertthat) | ||
library(data.table) | ||
library(future) | ||
|
||
set.seed(1234) | ||
|
||
## simulate simple data for TML for adaptive design | ||
n_obs <- 5000 # number of observations | ||
|
||
## baseline covariates -- simple, binary | ||
W <- rnorm(n_obs, 0, 1) | ||
|
||
## create treatment based on baseline W | ||
g_fair <- rep(0.5, n_obs) | ||
g_treatment <- g_fair / 2 * (W > 0.5) + (1 - g_fair / 2) * (W < 0.5) | ||
A <- sapply(g_treatment, rbinom, n = 1, size = 1) | ||
|
||
EY1 <- W + W^2 | ||
EY0 <- W | ||
Y <- A*EY1 + (1-A)*EY0 + rnorm(n_obs, mean = 0, sd = 1) | ||
|
||
## organize data and nodes for tmle3 | ||
data <- data.table(W, A, Y) | ||
node_list <- list(W = "W", A = "A", Y = "Y") | ||
|
||
# learners used for conditional expectation regression (e.g., outcome) | ||
mean_lrnr <- Lrnr_mean$new() | ||
glm_lrnr <- Lrnr_glm$new() | ||
sl_lrnr <- Lrnr_sl$new( | ||
learners = list(mean_lrnr, glm_lrnr), | ||
metalearner = Lrnr_nnls$new() | ||
) | ||
learner_list <- list(A = mean_lrnr, Y = sl_lrnr) | ||
|
||
# Test 1 | ||
## Define tmle_spec | ||
tmle_spec <- tmle3_Spec_ADATE$new( | ||
treatment_level = 1, | ||
control_level = 0, | ||
g_treat = g_treatment) | ||
|
||
## Define tmle task | ||
tmle_task <- tmle_spec$make_tmle_task(data, node_list) | ||
|
||
## Make initial likelihood | ||
initial_likelihood <- tmle_spec$make_initial_likelihood( | ||
tmle_task, | ||
learner_list | ||
) | ||
|
||
## Create targeted_likelihood object | ||
targeted_likelihood <- Targeted_Likelihood$new(initial_likelihood) | ||
|
||
## Define tmle param | ||
tmle_params <- tmle_spec$make_params(tmle_task, targeted_likelihood) | ||
|
||
## Run TMLE | ||
tmle_fit <- fit_tmle3( | ||
tmle_task, targeted_likelihood, tmle_params, | ||
targeted_likelihood$updater | ||
) | ||
tmle_fit | ||
|
||
## Truth | ||
truth_1 <- mean(EY1-EY0) | ||
|
||
test_that("TMLE CI includes truth", { | ||
expect_lte(abs(truth_1 - tmle_fit$summary$tmle_est), tmle_fit$summary$se * 1.96) | ||
}) | ||
|