| Title: | Optimal Transport Weights for Causal Inference |
|---|---|
| Description: | Uses optimal transport distances to find probabilistic matching estimators for causal inference. These methods are described in Dunipace, Eric (2021) <doi:10.48550/arXiv.2109.01991>. The package will build the weights, estimate treatment effects, and calculate confidence intervals via the methods described in the paper. The package also supports several other methods as described in the help files. |
| Authors: | Eric Dunipace [aut, cre] |
| Maintainer: | Eric Dunipace <[email protected]> |
| License: | GPL (==3.0) |
| Version: | 1.0.4 |
| Built: | 2026-06-06 06:56:40 UTC |
| Source: | https://github.com/ericdunipace/causalot |
Barycentric Projection outcome estimation
barycentric_projection(formula, data, weights, separate.samples.on = "z", penalty = NULL, cost_function = NULL, p = 2, debias = FALSE, cost.online = "auto", diameter = NULL, niter = 1000L, tol = 1e-07, ...)
formula |
A formula object specifying the outcome and covariates. |
data |
A data.frame of the data to use in the model. |
weights |
Either a vector of weights, one for each observation, or an object of class causalWeights. |
separate.samples.on |
The variable in the data denoting the treatment indicator, which determines how the samples are separated for the optimal transport calculation. |
penalty |
The penalty parameter to use in the optimal transport calculation. By default it is |
cost_function |
A user supplied cost function. If supplied, must take arguments |
p |
The power to raise the cost function. Default is 2.0. For user supplied cost functions, the cost will not be raised by this power unless the user so specifies. |
debias |
Should debiased barycentric projections be used? See details. |
cost.online |
Should an online cost algorithm be used? Default is "auto", which selects an online cost algorithm when the sample size in each group specified by |
diameter |
The diameter of the covariate space, if known. |
niter |
The maximum number of iterations to run the optimal transport problems. |
tol |
The tolerance for convergence of the optimal transport problems. |
... |
Not used at this time. |
The barycentric projection uses the dual potentials from the optimal transport distance between the two samples to calculate projections from one sample into another. For example, in the sample of controls, we may wish to know their outcome had they been treated. In general, we then seek to minimize
$$\operatorname{argmin}_{\eta} \sum_{i,j} \operatorname{cost}(\eta_i, y_j)\, \pi_{ij},$$
where $\pi_{ij}$ is the primal solution from the optimal transport problem.
These values can also be de-biased using the solutions from running an optimal transport problem of one sample against itself. Details are listed in Pooladian et al. (2022) https://arxiv.org/abs/2202.08919.
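To make the minimization concrete, here is a minimal base R sketch. It is not the package's implementation: it assumes a squared Euclidean cost, for which the minimizer for each source unit is the plan-weighted average of the target outcomes, and it works from an explicit transport plan rather than the dual potentials. The function name `barycentric_map` and the toy plan are hypothetical.

```r
# Barycentric projection under a squared Euclidean cost (illustrative only).
# `plan` is a hypothetical n_a x n_b transport plan; `y_b` holds the outcomes
# observed in the target sample.
barycentric_map <- function(plan, y_b) {
  stopifnot(is.matrix(plan), ncol(plan) == length(y_b))
  # For cost (eta_i - y_j)^2, the minimizer of sum_j plan[i, j] * (eta_i - y_j)^2
  # is the plan-weighted mean of y_b for each source unit i.
  as.numeric(plan %*% y_b) / rowSums(plan)
}

# Toy plan: unit 1 sends all of its mass to target 1,
# unit 2 splits its mass evenly between targets 2 and 3.
plan <- matrix(c(0.5, 0,    0,
                 0,   0.25, 0.25), nrow = 2, byrow = TRUE)
y_b <- c(1, 2, 4)
eta <- barycentric_map(plan, y_b)
print(eta)
```

With other cost functions the minimizer is generally no longer a weighted mean, so this shortcut applies only to the squared Euclidean case assumed here.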
An object of class "bp" which is a list with slots:
potentials The dual potentials from calculating the optimal transport distance
penalty The value of the penalty parameter used in calculating the optimal transport distance
cost_function The cost function used to calculate the distances between units.
cost_alg A character vector denoting whether an L1 distance, a squared Euclidean distance, or another distance metric was used.
p The power to which the cost matrix was raised if not using a user supplied cost function.
debias Whether barycentric projections should be debiased.
tensorized TRUE/FALSE denoting whether to use offline cost matrices.
data An object of class dataHolder with the data used to calculate the optimal transport distance.
y_a The outcome vector in the first sample.
y_b The outcome vector in the second sample.
x_a The covariate matrix in the first sample.
x_b The covariate matrix in the second sample.
a The empirical measure in the first sample.
b The empirical measure in the second sample.
terms The terms object from the formula.
if (torch::torch_is_installed()) {
  set.seed(23483)
  n <- 2^5
  pp <- 6
  overlap <- "low"
  design <- "A"
  estimate <- "ATT"
  power <- 2
  data <- causalOT::Hainmueller$new(n = n, p = pp, design = design, overlap = overlap)
  data$gen_data()
  weights <- causalOT::calc_weight(x = data, z = NULL, y = NULL, estimand = estimate, method = "NNM")
  df <- data.frame(y = data$get_y(), z = data$get_z(), data$get_x())
  # spell out the full argument name `weights` instead of relying on partial matching
  fit <- causalOT::barycentric_projection(y ~ ., data = df, weights = weights, separate.samples.on = "z", niter = 2)
  inherits(fit, "bp")
}
Estimate causal weights
calc_weight(x, z, estimand = c("ATC", "ATT", "ATE"), method = supported_methods(), options = NULL, weights = NULL, ...)
x |
A numeric matrix of covariates. You can also pass an object of class dataHolder or DataSim, which will make argument |
z |
A binary treatment indicator. |
estimand |
The estimand of interest. One of "ATT","ATC", or "ATE". |
method |
The method to estimate the causal weights. Must be one of the methods returned by |
options |
The options for the solver. Specific options depend on the solver you will be using, and you can use the solver-specific options functions as detailed below. |
weights |
The sample weights. Should be |
... |
Not used at this time. |
We detail some of the particulars of the function arguments below.
Causal Optimal Transport (COT): This is the main method of the package. This method relies on various solvers depending on the particular options chosen. Please see cotOptions() for more details.
Energy Balancing Weights (EnergyBW): This is equivalent to COT with an infinite penalty parameter, options = list(lambda = Inf). Uses the same solver and options as COT, cotOptions().
Nearest Neighbor Matching with replacement (NNM): This is equivalent to COT with a penalty parameter of 0, options = list(lambda = 0). Uses the same solver and options as COT, cotOptions().
Synthetic Control Method (SCM): The SCM method is equivalent to an OT problem from a different angle. See scmOptions().
Entropy Balancing Weights (EntropyBW): This method balances chosen functions of the covariates specified in the data argument, x, as proposed by Hainmueller (2012). See entBWOptions() for more details.
Stable Balancing Weights (SBW): Entropy Balancing Weights with a different penalty parameter, proposed by Zubizarreta (2015). See sbwOptions() for more details.
Covariate Balancing Propensity Score (CBPS): The CBPS method of Imai and Ratkovic. The options argument is passed to the function CBPS().
Logistic Regression or Probit Regression (Logistic, Probit): Historically the main methods for implementing inverse probability weights. Options are passed directly to the glm function from R.
An object of class causalWeights
set.seed(23483)
n <- 2^5
p <- 6

#### get data ####
data <- Hainmueller$new(n = n, p = p)
data$gen_data()
x <- data$get_x()
z <- data$get_z()

if (torch::torch_is_installed()) {
  # estimate weights
  weights <- calc_weight(x = x, z = z, estimand = "ATE", method = "COT", options = list(lambda = 0))

  # we can also use the DataSim object directly
  weightsDS <- calc_weight(x = data, z = NULL, estimand = "ATE", method = "COT", options = list(lambda = 0))

  all.equal(weights@w0, weightsDS@w0)
  all.equal(weights@w1, weightsDS@w1)
}
causalWeights class
This object is returned by the calc_weight function in this package. The slots can be accessed as any S4 object. There is no publicly accessible constructor function.
w0 A slot with the weights for the control group, with one entry per control observation. Weights sum to 1.
w1 The weights for the treated group, with one entry per treated observation. Weights sum to 1.
estimand A character denoting the estimand targeted by the weights. One of "ATT", "ATC", or "ATE".
info A slot to store a variety of information for inference. Currently under development.
method A character denoting the method used to estimate the weights.
penalty A list of the selected penalty parameters, if relevant.
data The dataHolder object containing the original data.
call The call used to construct the weights.
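As a reminder of what "accessed as any S4 object" means in practice, here is a small self-contained sketch. Because causalWeights has no public constructor, the sketch defines a hypothetical stand-in class, `toyWeights`, with a few slots of the same names; it only illustrates the `@` and `slot()` accessors and is not the package's class definition.

```r
library(methods)

# Hypothetical stand-in class with a few slots named like those of causalWeights.
setClass("toyWeights",
         representation(w0 = "numeric", w1 = "numeric", estimand = "character"))

tw <- new("toyWeights", w0 = rep(0.25, 4), w1 = rep(0.5, 2), estimand = "ATT")

tw@w0                  # direct slot access with `@`
slot(tw, "estimand")   # programmatic access by slot name
slotNames(tw)          # list the available slots
```

The same accessors should apply to an object returned by calc_weight(), for example `weights@w0` as used in the calc_weight examples.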
Extract treatment effect estimate
## S3 method for class 'causalEffect'
coef(object, ...)
object |
An object of class causalEffect |
... |
Not used |
A number corresponding to the estimated treatment effect
# set-up data
set.seed(1234)
data <- Hainmueller$new()
data$gen_data()

# calculate quantities
weight <- calc_weight(data, method = "Logistic", estimand = "ATE")
tx_eff <- estimate_effect(causalWeights = weight)

all.equal(coef(tx_eff), c(estimate = tx_eff@estimate))
Options available for the COT method
cotOptions(lambda = NULL, delta = NULL, opt.direction = c("dual", "primal"), debias = TRUE, p = 2,
  cost.function = NULL, cost.online = "auto", diameter = NULL, balance.formula = NULL,
  quick.balance.function = TRUE, grid.length = 7L, torch.optimizer = torch::optim_rmsprop,
  torch.scheduler = torch::lr_multiplicative, niter = 2000, nboot = 100L, lambda.bootstrap = 0.05,
  tol = 1e-04, device = NULL, dtype = NULL, ...)
lambda |
The penalty parameter for the entropy penalized optimal transport. Default is NULL. Can be a single number or a set of numbers to try. |
delta |
The bound for balancing functions if they are being used. Only available for biased entropy penalized optimal transport. Can be a single number or a set of numbers to try. |
opt.direction |
Should the optimizer solve the primal or dual problems. Should be one of "dual" or "primal" with a default of "dual" since it is typically faster. |
debias |
Should debiased optimal transport be used? TRUE or FALSE. |
p |
The power of the cost function to use for the cost. |
cost.function |
A function to calculate the pairwise costs. Should take arguments |
cost.online |
Should an online cost algorithm be used? One of "auto", "online", or "tensorized". "tensorized" is the offline option. |
diameter |
The diameter of the covariate space, if known. Default is NULL. |
balance.formula |
Formula for the balancing functions. |
quick.balance.function |
TRUE or FALSE denoting whether balance function constraints should be selected via a linear program (TRUE) or just checked for feasibility (FALSE). Default is TRUE. |
grid.length |
The number of penalty parameters to explore in a grid search if none are provided in arguments |
torch.optimizer |
The torch optimizer to use for methods using debiased entropy penalized optimal transport. If |
torch.scheduler |
The scheduler for the optimizer. Defaults to |
niter |
The number of iterations to run the solver |
nboot |
The number of iterations for the bootstrap to select the final penalty parameters. |
lambda.bootstrap |
The penalty parameter to use for the bootstrap hyperparameter selection of lambda. |
tol |
The tolerance for convergence |
device |
An object of class |
dtype |
An object of class |
... |
Arguments passed to the solvers. See details |
A list of class cotOptions with the following slots
lambda The penalty parameter for the optimal transport distance
delta The constraint for the balancing functions
opt.direction Whether to solve the primal or dual optimization problems
debias TRUE or FALSE if debiased optimal transport distances are used
balance.formula The formula giving how to generate the balancing functions.
quick.balance.function TRUE or FALSE whether quick balance functions will be run.
grid.length The number of parameters to check in a grid search of best parameters
p The power of the cost function
cost.online Whether online costs are used
cost.function The user supplied cost function if supplied.
diameter The diameter of the covariate space.
torch.optimizer The torch optimizer used for Sinkhorn Divergences
torch.scheduler The scheduler for the torch optimizer
solver.options The arguments to be passed to the torch.optimizer
scheduler.options The arguments to be passed to the torch.scheduler
osqp.options Arguments passed to the osqp function if quick balance functions are used.
niter The number of iterations to run the solver
nboot The number of bootstrap samples
lambda.bootstrap The penalty parameter to use for the bootstrap hyperparameter selection.
tol The tolerance for convergence.
device An object of class torch_device.
dtype An object of class torch_dtype.
The function is setup to direct the COT optimizer to run two basic methods: debiased entropy penalized optimal transport (Sinkhorn Divergences) or entropy penalized optimal transport (Sinkhorn Distances).
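The optimal transport problem solved is $\min_w OT_\lambda(w, b)$, where
$$OT_\lambda(w, b) = \min_{P} \sum_{i,j} C(x_i, x_j) P_{ij} + \lambda \sum_{i,j} P_{ij} \log(P_{ij}),$$
such that the rows of the matrix $P$ sum to $w$ and the columns sum to $b$. In this case $C(x_i, x_j)$ is the cost between units i and j.
The Sinkhorn Divergence solves
$$\min_w OT_\lambda(w, b) - \tfrac{1}{2} OT_\lambda(w, w) - \tfrac{1}{2} OT_\lambda(b, b).$$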
The solver for this function uses the torch package in R and by default will use the optim_rmsprop solver. Your desired torch optimizer can be passed via torch.optimizer with a scheduler passed via torch.scheduler. GPU support is available as detailed in the torch package. Additional arguments in ... are passed as extra arguments to the torch optimizer and schedulers as appropriate.
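For intuition about the entropy penalized problem with fixed marginals, the following base R sketch runs the textbook Sinkhorn matrix-scaling iteration. This is a swapped-in, simplified technique and not the package's torch-based solver: both marginals `a` and `b` are held fixed, there is no debiasing, and the function name `sinkhorn_plan` and the toy data are hypothetical.

```r
# Textbook Sinkhorn iterations for entropy penalized optimal transport
# between two fixed marginals (illustrative only).
sinkhorn_plan <- function(cost, a, b, lambda, niter = 200L) {
  stopifnot(nrow(cost) == length(a), ncol(cost) == length(b))
  K <- exp(-cost / lambda)   # Gibbs kernel
  u <- rep(1, length(a))
  v <- rep(1, length(b))
  for (i in seq_len(niter)) {
    u <- a / as.numeric(K %*% v)           # rescale to match row marginals
    v <- b / as.numeric(crossprod(K, u))   # rescale to match column marginals
  }
  # Plan P = diag(u) %*% K %*% diag(v), built with recycling.
  u * K * rep(v, each = length(a))
}

set.seed(1)
x <- matrix(runif(8), 4, 2)
y <- matrix(runif(6), 3, 2)
# Half the squared Euclidean distance between every pair of rows.
cost <- 0.5 * outer(rowSums(x^2), rowSums(y^2), "+") - x %*% t(y)
a <- rep(1 / 4, 4)
b <- rep(1 / 3, 3)
P <- sinkhorn_plan(cost, a, b, lambda = 1)
round(rowSums(P), 6)
round(colSums(P), 6)
```

For small values of lambda the kernel `exp(-cost / lambda)` can underflow, which is one reason practical solvers typically work in the log domain.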
There may be certain functions of the covariates that we wish to balance within some tolerance, $\delta$. For these functions $B$, we will desire
$$\frac{\left| \sum_{i: Z_i = 0} w_i B(x_i) - \frac{1}{n_1} \sum_{j: Z_j = 1} B(x_j) \right|}{\sigma} \leq \delta,$$
where in this case we are targeting balance with the treatment group for the ATT. $\sigma$ is the pooled standard deviation prior to balancing.
The cost function specifies pairwise distances. If argument cost.function is NULL, the function will default to using $L_p^p$ distances with a default $p = 2$ supplied by the argument p. So for p = 2, the cost between units $x_i$ and $x_j$ will be
$$C(x_i, x_j) = \frac{1}{2} \| x_i - x_j \|_2^2.$$
If cost.function is provided, it should be a function that takes arguments x1, x2, and p: function(x1, x2, p){...}.
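The required signature can be illustrated with a small base R sketch. The function name `lp_cost` is hypothetical, and the `1 / p` scaling is an assumption chosen so that p = 2 gives half the squared Euclidean distance. The text above does not say whether the package hands plain matrices or torch tensors to cost.function, so treat this as a sketch of the arithmetic only, not as a drop-in argument.

```r
# Hypothetical cost function with the (x1, x2, p) signature.
# Returns an nrow(x1) x nrow(x2) matrix of pairwise costs.
lp_cost <- function(x1, x2, p = 2) {
  x1 <- as.matrix(x1)
  x2 <- as.matrix(x2)
  stopifnot(ncol(x1) == ncol(x2))
  cost <- matrix(0, nrow(x1), nrow(x2))
  # Accumulate |x1[i, k] - x2[j, k]|^p over the covariate columns k.
  for (k in seq_len(ncol(x1))) {
    cost <- cost + abs(outer(x1[, k], x2[, k], "-"))^p
  }
  cost / p   # assumed scaling: p = 2 gives 0.5 * squared Euclidean distance
}

x1 <- matrix(c(0, 0,
               1, 1,
               2, 0), ncol = 2, byrow = TRUE)
x2 <- matrix(c(3, 4,
               0, 0), ncol = 2, byrow = TRUE)
cost_mat <- lp_cost(x1, x2, p = 2)
print(cost_mat)
```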
if (torch::torch_is_installed()) {
  opts1 <- cotOptions(lambda = 1e3, torch.optimizer = torch::optim_rmsprop)
  opts2 <- cotOptions(lambda = NULL)
  opts3 <- cotOptions(lambda = seq(0.1, 100, length.out = 7))
}
CRASH3 data example
Returns the CRASH3 data. Note that gen_data() will initialize the fixed data for x and y, but z is generated from Binom(0.5).
causalOT::DataSim -> CRASH3
site_id The site of the observation in terms of the original RCT (the site ID for the observations).
gen_data()
Draws new treatment indicators. x and y data are fixed.
CRASH3$gen_data()
gen_x()
Sets up the covariate data. This data is fixed.
CRASH3$gen_x()
gen_y()
Sets up the outcome data. This data is fixed.
CRASH3$gen_y()
gen_z()
Sets up the treatment indicator. Drawn as Z ~ Binom(0.5)
CRASH3$gen_z()
new()
Initializes the CRASH3 object.
CRASH3$new(n = NULL, p = NULL, param = list(), design = NA_character_, ...)
n Not used. Maintained for symmetry with other DataSim objects.
p Not used. Maintained for symmetry with other DataSim objects.
param Not used. Maintained for symmetry with other DataSim objects.
design Not used.
... Not used.
crash <- CRASH3$new()
crash$gen_data()
crash$get_n()
crash$site_id
clone()
The objects of this class are cloneable with this method.
CRASH3$clone(deep = FALSE)
deep Whether to make a deep clone.
dataHolder
dataHolder(x, z, y = NA_real_, weights = NA_real_)
x |
the covariate data. Can be a matrix, an object of class |
z |
the treatment indicator |
y |
the outcome data |
weights |
the empirical distribution of the sample |
Creates an object used internally by the causalOT package for data management.
Returns an object of class dataHolder with slots
x matrix. A matrix of confounders.
z integer. The treatment indicator, taking values 0 or 1.
y numeric. The outcome data.
n0 integer. The number of observations where z==0
n1 integer. The number of observations where z==1
weights numeric. The empirical distribution of the full sample.
x <- matrix(0, 100, 10)
z <- stats::rbinom(100, 1, 0.5)

# don't need to provide outcome
# function will assume each observation gets equal mass
dataHolder(x = x, z = z)
R6 Data Generating Parent Class
Can be used to make your own data simulation class. Should have the same slots listed in this class at a minimum, but you can add your own, of course. An easy way to do this is to make your class inherit from this one. See the example.
An R6 object
get_x()
Gets the covariate data
DataSim$get_x()
get_y()
Gets the outcome vector
DataSim$get_y()
get_z()
Gets the treatment indicator
DataSim$get_z()
get_n()
Gets the number of observations
DataSim$get_n()
get_x1()
Gets the covariate data for the treated individuals
DataSim$get_x1()
get_x0()
Gets the covariate data for the control individuals
DataSim$get_x0()
get_p()
Gets the dimensionality of the covariate data
DataSim$get_p()
get_tau()
Gets the individual treatment effects
DataSim$get_tau()
gen_data()
Generates the data. Default is an empty function
DataSim$gen_data()
clone()
The objects of this class are cloneable with this method.
DataSim$clone(deep = FALSE)
deep Whether to make a deep clone.
MyClass <- R6::R6Class("MyClass", inherit = DataSim, public = list(), private = list())
Function to turn a data.frame into a dataHolder object.
df2dataHolder(treatment.formula, outcome.formula = NA_character_, data, weights = NA_real_)
treatment.formula |
a formula specifying the treatment indicator and covariates. Required. |
outcome.formula |
an optional formula specifying the outcome function. |
data |
a data.frame with the data |
weights |
optional vector of sampling weights for the data |
This will take the formulas specified and transform the data.frame into a dataHolder object that is used internally by the causalOT package. If you do not specify an outcome formula, take care not to include the outcome in the data.frame. Otherwise, the function may include the outcome as a covariate, which is not appropriate during the design phase of causal inference.
If both outcome.formula and treatment.formula are specified, it will assume you are in the design phase, and create a combined covariate matrix to balance on the assumed treatment and outcome models.
If you are in the outcome phase of estimation, you can just provide a dummy formula for the treatment.formula like "z ~ 0" just so the function can identify the treatment indicator appropriately in the data creation phase.
Returns an object of class dataHolder.
set.seed(20348)
n <- 15
d <- 3
x <- matrix(stats::rnorm(n * d), n, d)
z <- rbinom(n, 1, prob = 0.5)
y <- rnorm(n)
weights <- rep(1 / n, n)
df <- data.frame(x, z, y)
dh <- df2dataHolder(
  treatment.formula = "z ~ .",
  outcome.formula = "y ~ .",
  data = df,
  weights = weights
)
Options for the Entropy Balancing Weights
entBWOptions(delta = NULL, grid.length = 20L, nboot = 1000L, ...)
delta |
A number or vector of tolerances for the balancing functions. Default is NULL which will use a grid search |
grid.length |
The number of values to try in the grid search |
nboot |
The number of bootstrap samples to run during the grid search. |
... |
Arguments passed on to lbfgsb3c() |
A list of class entBWOptions with slots
delta Delta values to try
grid.length The number of parameters to try
nboot Number of bootstrap samples
solver.options A list of options passed to lbfgsb3c()
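This method will balance functions of the covariates within some tolerance, $\delta$. For these functions $B$, we will desire
$$\frac{\left| \sum_{i: Z_i = 0} w_i B(x_i) - \frac{1}{n_1} \sum_{j: Z_j = 1} B(x_j) \right|}{\sigma} \leq \delta,$$
where in this case we are targeting balance with the treatment group for the ATT. $\sigma$ is the pooled standard deviation prior to balancing.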
opts <- entBWOptions(delta = 0.1)
Effective Sample Size
ESS(x)
## S4 method for signature 'numeric'
ESS(x)
## S4 method for signature 'causalWeights'
ESS(x)
x |
Either a vector of weights summing to 1 or an object of class causalWeights |
Calculates the effective sample size as described by Kish (1965).
However, this calculation has some problems and the PSIS()
function should be used instead.
Either a number denoting the effective sample size or if x is of class
causalWeights, then returns a list of both values in the treatment
and control groups.
ESS(numeric): default ESS method for numeric vectors
ESS(causalWeights): ESS method for objects of class causalWeights
x <- rep(1/100, 100)
ESS(x)
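For reference, Kish's effective sample size is commonly written as the squared sum of the weights divided by the sum of squared weights. The sketch below computes that quantity in base R; the helper name `kish_ess` is hypothetical, and the sketch is not claimed to reproduce every detail of ESS().

```r
# Kish's effective sample size: (sum of weights)^2 / (sum of squared weights).
kish_ess <- function(w) {
  stopifnot(is.numeric(w), all(w >= 0), sum(w) > 0)
  sum(w)^2 / sum(w^2)
}

ess_equal  <- kish_ess(rep(1 / 100, 100))        # equal weights
ess_skewed <- kish_ess(c(0.97, rep(0.01, 3)))    # one dominant weight
print(c(equal = ess_equal, skewed = ess_skewed))
```

Equal weights recover the nominal sample size, while a single dominant weight pushes the value toward 1.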
Estimate treatment effects
estimate_effect(causalWeights, x = NULL, y = NULL, model.function, estimate.separately = TRUE, augment.estimate = FALSE, normalize.weights = TRUE, ...)
causalWeights |
An object of class causalWeights |
x |
A dataHolder, matrix, data.frame, or object of class DataSim. See calc_weight for more details how to input the data. If |
y |
The outcome vector. |
model.function |
The modeling function to use, if desired. Must take arguments "formula", "data", and "weights". Other arguments passed via |
estimate.separately |
Should the outcome model be estimated separately in each treatment group? TRUE or FALSE. |
augment.estimate |
Should an augmented, doubly robust estimator be used? |
normalize.weights |
Should the weights in the |
... |
Pass additional arguments to the outcome modeling functions. |
An object of class causalEffect.
if (torch::torch_is_installed()) {
  # set-up data
  data <- Hainmueller$new()
  data$gen_data()

  # calculate quantities
  weight <- calc_weight(data, method = "COT", estimand = "ATT", options = list(lambda = 0))
  tx_eff <- estimate_effect(causalWeights = weight)

  # get estimate
  print(tx_eff@estimate)
  all.equal(coef(tx_eff), c(estimate = tx_eff@estimate))
}
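In the simplest case, with no outcome model and weights normalized within each group, a weighting estimator reduces to a difference of weighted means. The base R sketch below shows that arithmetic under those assumptions; the helper name `weighted_effect` and the toy data are hypothetical, and this is not claimed to be what estimate_effect() does internally.

```r
# Difference of weighted outcome means (illustrative only).
# w0 and w1 are weights for the control and treated units, respectively.
weighted_effect <- function(y, z, w0, w1) {
  stopifnot(length(w0) == sum(z == 0), length(w1) == sum(z == 1))
  w0 <- w0 / sum(w0)   # normalize within the control group
  w1 <- w1 / sum(w1)   # normalize within the treated group
  sum(w1 * y[z == 1]) - sum(w0 * y[z == 0])
}

y <- c(1, 2, 3, 5, 7)
z <- c(0, 0, 0, 1, 1)
tau_hat <- weighted_effect(y, z, w0 = c(1, 1, 2), w1 = c(1, 1))
print(tau_hat)
```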
Hainmueller data example
Generates the data as described in Hainmueller (2012).
causalOT::DataSim -> Hainmueller
gen_data()
Generates the data
Hainmueller$gen_data()
gen_x()
Generates the covariate data
Hainmueller$gen_x()
gen_y()
Generates the outcome data
Hainmueller$gen_y()
gen_z()
Generates the treatment indicator
Hainmueller$gen_z()
new()
Generates the Hainmueller R6 class
Hainmueller$new( n = 100, p = 6, param = list(), design = "A", overlap = "low", ... )
n The number of observations
p The dimensions of the covariates. Fixed to 6.
param The data generating parameters fed as a list.
design One of "A" or "B". See details.
overlap One of "high", "low", or "medium". See details.
... Extra arguments. Currently unused.
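Design "A" is the setting where the outcome is generated from a linear model, $Y(0) = Y(1) = X_1 + X_2 + X_3 - X_4 + X_5 + X_6 + \eta$, and design "B" is where the outcome is generated from the non-linear model $Y(0) = Y(1) = (X_1 + X_2 + X_5)^2 + \eta$.
The treatment indicator is generated from $Z = 1(X_1 + 2 X_2 - 2 X_3 - X_4 - 0.5 X_5 + X_6 + \nu > 0)$, where $\nu$ depends on the overlap selected. If overlap is "high", then $\nu \sim N(0, 100)$. If overlap is "low", then $\nu \sim N(0, 30)$. Finally, if overlap is "medium", then $\nu$ is drawn from a $\chi^2$ distribution with 5 degrees of freedom that is scaled and centered to have mean 0.5 and variance 67.6.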
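The scaling and centering for the "medium" overlap setting can be worked out directly. The sketch below assumes the distribution with 5 degrees of freedom is chi-squared, as in Hainmueller (2012), which has mean 5 and variance 10; the helper name `scale_chisq5` is hypothetical and the code is not taken from the package.

```r
# Affine map sending a chi-squared(df) draw (mean df, variance 2 * df)
# to the target mean and variance.
scale_chisq5 <- function(draws, target_mean = 0.5, target_var = 67.6, df = 5) {
  (draws - df) * sqrt(target_var / (2 * df)) + target_mean
}

set.seed(42)
nu <- scale_chisq5(stats::rchisq(1e5, df = 5))
print(c(mean = mean(nu), var = var(nu)))
```

The multiplier is sqrt(67.6 / 10) = 2.6, so each draw is shifted by 5, stretched by 2.6, and then moved to the target mean of 0.5.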
An object of class DataSim.
data <- Hainmueller$new(n = 100, p = 6, design = "A", overlap = "low")
data$gen_data()
print(data$get_x()[1:2, ])
get_design()
Returns the chosen design parameters
Hainmueller$get_design()
get_pscore()
Returns the true propensity score
Hainmueller$get_pscore()
clone()
The objects of this class are cloneable with this method.
Hainmueller$clone(deep = FALSE)
deep Whether to make a deep clone.
LaLonde data example
Returns the LaLonde data as used by Dehejia and Wahba. Note the data
is fixed and gen_data() will just initialize the fixed data.
causalOT::DataSim -> LaLonde
gen_data()
Sets up the data
LaLonde$gen_data()
get_tau()
Returns the experimental treatment effect, $1794
LaLonde$get_tau()
gen_x()
Sets up the covariate data
LaLonde$gen_x()
gen_y()
Sets up the outcome data
LaLonde$gen_y()
gen_z()
Sets up the treatment indicator
LaLonde$gen_z()
new()
Initializes the LaLonde object.
LaLonde$new(n = NULL, p = NULL, param = list(), design = "NSW", ...)
n Not used. Maintained for symmetry with other DataSim objects.
p Not used. Maintained for symmetry with other DataSim objects.
param Not used. Maintained for symmetry with other DataSim objects.
design One of "NSW" or "Full". "NSW" uses the original experimental data from the job training program, while "Full" uses the treated individuals from LaLonde's study and compares them to individuals from the Current Population Survey as controls.
... Not used.
nsw <- LaLonde$new(design = "NSW")
nsw$gen_data()
nsw$get_n()

obs.study <- LaLonde$new(design = "Full")
obs.study$gen_data()
obs.study$get_n()
get_design()
Returns the chosen design parameters
LaLonde$get_design()
clone()
The objects of this class are cloneable with this method.
LaLonde$clone(deep = FALSE)
deep Whether to make a deep clone.
This function will calculate the difference in means between treatment groups standardized by the pooled standard-deviation of the respective covariates.
mean_balance(x = NULL, z = NULL, weights = NULL, ...)
x |
Either a matrix, an object of class dataHolder, or an object of class DataSim |
z |
An integer vector denoting the treatment of each observation. Can be null if |
weights |
An object of class causalWeights. |
... |
Not used at this time. |
A vector of mean balances
n <- 100
p <- 6
x <- matrix(stats::rnorm(n * p), n, p)
z <- stats::rbinom(n, 1, 0.5)
weights <- calc_weight(x = x, z = z, estimand = "ATT", method = "Logistic")
mb <- mean_balance(x = x, z = z, weights = weights)
print(mb)
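The quantity described above can be sketched by hand in base R. The helper name `std_mean_diff` is hypothetical, and two choices here are assumptions rather than documented behavior of mean_balance(): the pooled standard deviation is taken as the square root of the average of the two unweighted group variances, and the difference is reported with its sign.

```r
# Weighted difference in covariate means between treated and control units,
# divided by a pooled standard deviation (illustrative only).
std_mean_diff <- function(x, z, w0 = NULL, w1 = NULL) {
  x <- as.matrix(x)
  x0 <- x[z == 0, , drop = FALSE]
  x1 <- x[z == 1, , drop = FALSE]
  if (is.null(w0)) w0 <- rep(1 / nrow(x0), nrow(x0))
  if (is.null(w1)) w1 <- rep(1 / nrow(x1), nrow(x1))
  # Assumed pooling rule: average the unweighted group variances.
  pooled_sd <- sqrt((apply(x0, 2, var) + apply(x1, 2, var)) / 2)
  # w * x recycles the weights down each column, i.e. it scales the rows.
  (colSums(w1 * x1) - colSums(w0 * x0)) / pooled_sd
}

x <- matrix(c(0, 2, 4, 6), ncol = 1)
z <- c(0, 0, 1, 1)
smd_unweighted <- std_mean_diff(x, z)
smd_weighted <- std_mean_diff(x, z, w0 = c(0, 1))  # all control mass on x = 2
print(c(unweighted = smd_unweighted, weighted = smd_weighted))
```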
Constructor for an R6 Measure object.
Measure(x, weights = NULL, probability.measure = TRUE, adapt = c("none", "weights", "x"), balance.functions = NA_real_, target.values = NA_real_, dtype = NULL, device = NULL)
x |
The data points |
weights |
The empirical measure. If NULL, assigns equal weight to each observation |
probability.measure |
Is the empirical measure a probability measure? Default is TRUE. |
adapt |
Should we try to adapt the data ("x"), the weights ("weights"), or neither ("none"). Default is "none". |
balance.functions |
A matrix of functions of the covariates to target for mean balance. If NULL and target.values are provided, will use the data in x. |
target.values |
The targets for the balance functions. Should be the same length as columns in balance.functions. |
dtype |
The torch_tensor dtype or NULL. |
device |
The device to have the data on. Should be result of torch::torch_device() or NULL. |
An R6 class for representing empirical measures (data + weights) with optional gradient-based adaptation via torch.
Use Measure() to construct a measure. The returned object supports
active bindings like $weights and $x, and methods like $detach(). See below for defined methods and fields.
Returns a Measure object
balance_functions The functions of the data that we want to adjust towards the targets.
balance_target The values the balance_functions are targeting.
adapt What aspect of the data will be adapted. One of "none", "weights", or "x".
device The torch::torch_device() of the data.
dtype The torch::torch_dtype of the data.
n The rows of the covariates, x.
d The columns of the covariates, x.
probability_measure Is the measure a probability measure?
grad Gets or sets the gradient.
init_weights Returns the initial value of the weights.
init_data Returns the initial value of the data.
requires_grad Checks or turns on/off the gradient.
weights Gets or sets the weights.
x Gets or sets the data.
detach()
generates a deep clone of the object without gradients.
Measure_$detach()
get_weight_parameters()
Makes a copy of the weights parameters.
Measure_$get_weight_parameters()
print()
Prints the measure object.
Measure_$print(...)
... Not used.
new()
Constructor function.
Measure_$new(
x,
weights = NULL,
probability.measure = TRUE,
adapt = c("none", "weights", "x"),
balance.functions = NA_real_,
target.values = NA_real_,
dtype = NULL,
device = NULL
)
x: The data points
weights: The empirical measure. If NULL, assigns equal weight to each observation
probability.measure: Is the empirical measure a probability measure? Default is TRUE.
adapt: Should we try to adapt the data ("x"), the weights ("weights"), or neither ("none"). Default is "none".
balance.functions: A matrix of functions of the covariates to target for mean balance. If NULL and target.values are provided, will use the data in x.
target.values: The targets for the balance functions. Should be the same length as columns in balance.functions.
dtype: The torch::torch_dtype or NULL.
device: The device to have the data on. Should be result of torch::torch_device() or NULL.
clone()
The objects of this class are cloneable with this method.
Measure_$clone(deep = FALSE)
deep: Whether to make a deep clone.
if (torch::torch_is_installed()) {
  m <- Measure(x = matrix(0, 10, 2),
               adapt = "none",
               device = torch::torch_device("cpu"),
               dtype = torch::torch_double())
  print(m)
  m$x
  m$x <- matrix(1, 10, 2) # must have same dimensions
  m$x
  m$weights
  m$weights <- 1:10 / sum(1:10)
  m$weights

  # with gradients
  m <- Measure(x = matrix(0, 10, 2),
               adapt = "weights",
               device = torch::torch_device("cpu"),
               dtype = torch::torch_double())
  m$requires_grad # TRUE
  m$requires_grad <- "none" # turns off
  m$requires_grad # FALSE
  m$requires_grad <- "x"
  m$requires_grad # TRUE

  m <- Measure(matrix(0, 10, 2),
               adapt = "none",
               device = torch::torch_device("cpu"),
               dtype = torch::torch_double())
  m$grad # NULL

  m <- Measure(matrix(0, 10, 2),
               adapt = "weights",
               device = torch::torch_device("cpu"),
               dtype = torch::torch_double())
  loss <- sum(m$weights * 1:10)
  loss$backward()
  m$grad
  # note the weights gradient is on the log softmax scale
  # and the first parameter is fixed for identifiability
  m$grad <- rep(1, 9)
  m$grad
}
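The closing comment in the example above says the weights gradient lives on the log softmax scale with the first parameter fixed, which is why a measure with 10 observations takes a gradient of length 9. The base R sketch below illustrates one parameterization with that property. The helper name params_to_weights and the choice to pin the first logit at 0 are assumptions for illustration; the package's internal parameterization may differ in detail.

```r
# Hypothetical helper (not part of causalOT): map n - 1 free parameters
# to n weights that are positive and sum to one.
params_to_weights <- function(theta) {
  logits <- c(0, theta)            # first logit pinned at 0 for identifiability
  logits <- logits - max(logits)   # shift for numerical stability
  exp(logits) / sum(exp(logits))   # softmax
}

theta <- rep(0, 9)                 # 9 free parameters for 10 observations
w <- params_to_weights(theta)
length(w)                          # number of weights
sum(w)                             # total mass
```

Without the pinned logit, adding a constant to every parameter would leave the weights unchanged, so the parameters could not be uniquely recovered from the weights.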
Optimal Transport Distance
ot_distance(x1, x2 = NULL, a = NULL, b = NULL, penalty, p = 2, cost = NULL,
  debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

## S3 method for class 'causalWeights'
ot_distance(x1, x2 = NULL, a = NULL, b = NULL, penalty, p = 2, cost = NULL,
  debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

## S3 method for class 'matrix'
ot_distance(x1, x2, a = NULL, b = NULL, penalty, p = 2, cost = NULL,
  debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

## S3 method for class 'array'
ot_distance(x1, x2, a = NULL, b = NULL, penalty, p = 2, cost = NULL,
  debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

## S3 method for class 'torch_tensor'
ot_distance(x1, x2, a = NULL, b = NULL, penalty, p = 2, cost = NULL,
  debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)
x1 |
Either an object of class causalWeights or a matrix of the covariates in the first sample |
x2 |
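A matrix of the covariates in the second sample. May be NULL when x1 is an object of class causalWeights. |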
a |
Empirical measure of the first sample. If NULL, assumes each observation gets equal mass. Ignored for objects of class causalWeights. |
b |
Empirical measure of the second sample. If NULL, assumes each observation gets equal mass. Ignored for objects of class causalWeights. |
penalty |
The penalty of the optimal transport distance to use. If missing or NULL, the function will try to guess a suitable value depending on whether debias is TRUE or FALSE. |
p |
The power to raise the cost matrix by. Default is 2. |
cost |
Supply your own cost function. Should take arguments x1, x2, and p. |
debias |
TRUE or FALSE. Should the debiased optimal transport distances be used? |
online.cost |
How to calculate the distance matrix. One of "auto", "tensorized", or "online". |
diameter |
The diameter of the metric space, if known. Default is NULL. |
niter |
The maximum number of iterations for the Sinkhorn updates |
tol |
The tolerance for convergence |
For objects of class matrix, numeric value giving the optimal transport distance. For objects of class causalWeights, results are returned as a list for before ('pre') and after adjustment ('post').
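To make the roles of penalty, p, debias, niter, and tol concrete, here is a minimal base R sketch of an entropy-regularized optimal transport cost computed with log-domain Sinkhorn updates, together with its debiased version. It is not the causalOT implementation. The helper names row_lse, sinkhorn_cost, and sinkhorn_divergence are hypothetical, and two things are assumed: that penalty acts as the entropic regularization strength, and that the cost is the Euclidean distance raised to the power p.

```r
# Row-wise log-sum-exp, computed stably.
row_lse <- function(mat) {
  mx <- apply(mat, 1, max)
  mx + log(rowSums(exp(mat - mx)))
}

# Hypothetical helper: entropy-regularized OT cost between two weighted samples.
sinkhorn_cost <- function(x1, x2, a = NULL, b = NULL, penalty = 1, p = 2,
                          niter = 1000, tol = 1e-07) {
  n <- nrow(x1)
  m <- nrow(x2)
  if (is.null(a)) a <- rep(1 / n, n)   # equal mass when no measure is given
  if (is.null(b)) b <- rep(1 / m, m)
  # Assumed cost: pairwise Euclidean distance raised to the power p.
  d <- as.matrix(stats::dist(rbind(x1, x2)))
  cost <- d[seq_len(n), n + seq_len(m), drop = FALSE]^p
  f <- rep(0, n)                       # dual potential for the first sample
  g <- rep(0, m)                       # dual potential for the second sample
  for (iter in seq_len(niter)) {
    f_old <- f
    G <- matrix(g + penalty * log(b), n, m, byrow = TRUE)
    f <- -penalty * row_lse((G - cost) / penalty)
    Fm <- matrix(f + penalty * log(a), n, m)
    g <- -penalty * row_lse(t((Fm - cost) / penalty))
    if (max(abs(f - f_old)) < tol) break   # potentials have stopped moving
  }
  sum(a * f) + sum(b * g)              # dual objective value
}

# Debiased version: subtract the two self-transport terms.
sinkhorn_divergence <- function(x1, x2, a = NULL, b = NULL, ...) {
  sinkhorn_cost(x1, x2, a, b, ...) -
    0.5 * sinkhorn_cost(x1, x1, a, a, ...) -
    0.5 * sinkhorn_cost(x2, x2, b, b, ...)
}

x1 <- matrix(c(0, 0, 1, 0, 0, 1), ncol = 2, byrow = TRUE)
x2 <- x1 + 2                           # the same cloud, shifted
d_same <- sinkhorn_divergence(x1, x1, penalty = 1)
d_diff <- sinkhorn_divergence(x1, x2, penalty = 1)
```

In this sketch the debiased quantity is zero when both samples coincide, whereas the raw regularized cost is not. That is the usual motivation for debiasing.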
ot_distance(causalWeights): method for causalWeights class
ot_distance(matrix): method for matrices
ot_distance(array): method for arrays
ot_distance(torch_tensor): method for torch_tensors
if (torch::torch_is_installed()) {
  x <- matrix(stats::rnorm(10 * 5), 10, 5)
  z <- stats::rbinom(10, 1, 0.5)
  weights <- calc_weight(x = x, z = z, method = "Logistic", estimand = "ATT")
  ot1 <- ot_distance(x1 = weights, penalty = 100,
                     p = 2, debias = TRUE, online.cost = "auto",
                     diameter = NULL)
  ot2 <- ot_distance(x1 = x[z == 0, ], x2 = x[z == 1, ],
                     a = weights@w0 / sum(weights@w0), b = weights@w1,
                     penalty = 100, p = 2, debias = TRUE, online.cost = "auto",
                     diameter = NULL)
  all.equal(ot1$post, ot2)
}
User-facing constructor for an R6 OTProblem object.
OTProblem(measure_1, measure_2, ...)
measure_1 |
An object of class Measure |
measure_2 |
An object of class Measure |
... |
Not used at this time |
An R6 class for creating optimal transport problems with two Measure objects.
Use OTProblem() to construct an object of class
OTProblem. The component objects must be of class Measure.
The process of solving an OT problem involves three steps:
(1) setting up the problem by creating Measure objects and combining them into an OTProblem object,
(2) choosing the hyperparameters for the problem, and
(3) solving the problem by minimizing the objective function.
The first step is done by creating Measure objects and
then combining them into an OTProblem object using
the $add(), $subtract(), $multiply(), and $divide() methods.
The second step is done by calling the $setup_arguments() method on the OTProblem object.
The third step is done by calling the $solve() method on the OTProblem object.
An R6 object of class OTProblem.
device: the torch::torch_device() of the data.
dtype: the torch::torch_dtype of the data.
selected_delta: the delta value selected after choose_hyperparameters
selected_lambda: the lambda value selected after choose_hyperparameters
loss: Prints the current value of the objective. Only available after the solve method has been run.
penalty: Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the setup_arguments function.
add()
adds o2 to the OTProblem
OTProblem_$add(o2)
o2: A number or object of class OTProblem
# example code
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
y <- matrix(2, 100, 10)
m2 <- Measure(x = y,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
z <- matrix(3,102, 10)
m3 <- Measure(x = z,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
# setup OT problems
ot1 <- OTProblem(m1, m2)
ot2 <- OTProblem(m3, m2)
print(ot1)
print(ot2)
ot1$add(ot2)
print(ot1)
print(ot2)
}
subtract()
subtracts o2 from OTProblem
OTProblem_$subtract(o2)
o2: A number or object of class OTProblem
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
y <- matrix(2, 100, 10)
m2 <- Measure(x = y,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
z <- matrix(3,102, 10)
m3 <- Measure(x = z,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
# setup OT problems
ot1 <- OTProblem(m1, m2)
ot2 <- OTProblem(m3, m2)
print(ot1)
print(ot2)
ot1$subtract(ot2)
print(ot1)
print(ot2)
}
multiply()
multiplies OTProblem by o2
OTProblem_$multiply(o2)
o2: A number or object of class OTProblem
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
y <- matrix(2, 100, 10)
m2 <- Measure(x = y,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
z <- matrix(3,102, 10)
m3 <- Measure(x = z,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
# setup OT problems
ot1 <- OTProblem(m1, m2)
ot2 <- OTProblem(m3, m2)
print(ot1)
print(ot2)
ot1$multiply(ot2)
print(ot1)
print(ot2)
}
divide()
divides OTProblem by argument o2
OTProblem_$divide(o2)
o2: A number or object of class OTProblem
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
y <- matrix(2, 100, 10)
m2 <- Measure(x = y,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
z <- matrix(3,102, 10)
m3 <- Measure(x = z,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
# setup OT problems
ot1 <- OTProblem(m1, m2)
ot2 <- OTProblem(m3, m2)
print(ot1)
print(ot2)
ot1$divide(ot2)
print(ot1)
print(ot2)
}
print()
prints the OT problem object
OTProblem_$print(...)
...: Not used at this time
new()
Constructor method
OTProblem_$new(measure_1, measure_2)
An R6 object of class OTProblem
setup_arguments()
Sets up the OT problems for the OTProblem object. This should be run before choose_hyperparameters and solve.
OTProblem_$setup_arguments(
  lambda,
  delta,
  grid.length = 7L,
  cost.function = NULL,
  p = 2,
  cost.online = "auto",
  debias = TRUE,
  diameter = NULL,
  ot_niter = 1000L,
  ot_tol = 0.001
)
lambda: The penalty parameters to try for the OTProblem. If not provided, the function will select some.
delta: The constraint parameters to try for the balance function problems, if any.
grid.length: The number of hyperparameters to try if not provided.
cost.function: The cost function for the data. Can be any function that takes arguments x1, x2, p. Defaults to the Euclidean distance.
p: The power to raise the cost matrix by. Default is 2.
cost.online: Should online costs be used? Default is "auto"; "tensorized" stores the cost matrix in memory, while "online" calculates it on the fly.
debias: Should a debiased OTProblem be used? Defaults to TRUE.
diameter: Diameter of the cost function.
ot_niter: Number of iterations to run the solver.
ot_tol: The tolerance for convergence of the objective function.
returns the object invisibly
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
y <- matrix(2, 100, 10)
m2 <- Measure(x = y, adapt = "weights",
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
ot <- OTProblem(m1, m2)
ot$setup_arguments(lambda = 1000)
}
solve()
Solve the OTProblem at each parameter value. Must run setup_arguments first.
OTProblem_$solve(
niter = 1000L,
tol = 1e-05,
optimizer = c("torch", "frank-wolfe"),
torch_optim = torch::optim_lbfgs,
torch_scheduler = torch::lr_reduce_on_plateau,
torch_args = NULL,
osqp_args = NULL,
quick.balance.function = TRUE
)
niter: The number of iterations to run the solver at each combination of hyperparameter values.
tol: The tolerance for convergence.
optimizer: The optimizer to use. One of "torch" or "frank-wolfe".
torch_optim: The torch_optimizer to use. Default is torch::optim_lbfgs.
torch_scheduler: The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau.
torch_args: Arguments passed to the torch optimizer and scheduler.
osqp_args: Arguments passed to osqp::osqpSettings() if appropriate.
quick.balance.function: Should osqp::osqp() be used to select balance function constraints (delta) or not. Default is TRUE.
returns the object invisibly
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
y <- matrix(2, 100, 10)
m2 <- Measure(x = y, adapt = "weights",
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
ot <- OTProblem(m1, m2)
ot$setup_arguments(lambda = 1000)
ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
}
choose_hyperparameters()
Selects the hyperparameter values through a bootstrap algorithm
OTProblem_$choose_hyperparameters(
  n_boot_lambda = 100L,
  n_boot_delta = 1000L,
  lambda_bootstrap = Inf
)
n_boot_lambda: The number of bootstrap iterations to run when selecting lambda
n_boot_delta: The number of bootstrap iterations to run when selecting delta
lambda_bootstrap: The penalty parameter to use when selecting lambda. Higher numbers run faster.
returns the object invisibly
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x,
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
y <- matrix(2, 100, 10)
m2 <- Measure(x = y, adapt = "weights",
device = torch::torch_device("cpu"),
dtype = torch::torch_double())
ot <- OTProblem(m1, m2)
ot$setup_arguments(lambda = c(1,1000))
ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
ot$choose_hyperparameters(n_boot_lambda = 2, n_boot_delta = 10, lambda_bootstrap = 100)
}
info()
Provides diagnostics after the solve and choose_hyperparameters methods have been run.
OTProblem_$info()
a list with slots
loss: the final loss values
iterations: The number of iterations run for each combination of parameters
balance.function.differences: The final differences in the balance functions
hyperparam.metrics: A list of the bootstrap evaluation for delta and lambda values
if (torch::torch_is_installed()) {
ot$info()
}
clone()
The objects of this class are cloneable with this method.
OTProblem_$clone(deep = FALSE)
deep: Whether to make a deep clone.
if (torch::torch_is_installed()) {
  # setup measures
  x <- matrix(1, 100, 10)
  m1 <- Measure(x = x,
                device = torch::torch_device("cpu"),
                dtype = torch::torch_double())
  y <- matrix(2, 100, 10)
  m2 <- Measure(x = y, adapt = "weights",
                device = torch::torch_device("cpu"),
                dtype = torch::torch_double())
  z <- matrix(3, 102, 10)
  m3 <- Measure(x = z,
                device = torch::torch_device("cpu"),
                dtype = torch::torch_double())

  # setup OT problems
  ot1 <- OTProblem(m1, m2,
                   device = torch::torch_device("cpu"),
                   dtype = torch::torch_double())
  ot2 <- OTProblem(m3, m2,
                   device = torch::torch_device("cpu"),
                   dtype = torch::torch_double())

  # you can add or subtract OTProblem objects into
  # a new OTProblem
  ot <- 0.5 * ot1 + 0.5 * ot2
  print(ot)

  # Then you choose the hyperparameters
  ot$setup_arguments(lambda = 1000)

  # then you can solve the objective function
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
}
plot.causalWeights
## S3 method for class 'causalWeights'
plot(x, r_eff = NULL, penalty, p = 2, cost = NULL, debias = TRUE,
  online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07, ...)
x |
A causalWeights object |
r_eff |
The |
penalty |
The penalty of the optimal transport distance to use. If missing or NULL, the function will try to guess a suitable value depending on whether debias is TRUE or FALSE. |
p |
The power to raise the cost matrix by. Default is 2. |
cost |
Supply your own cost function. Should take arguments x1, x2, and p. |
debias |
TRUE or FALSE. Should the debiased optimal transport distances be used? |
online.cost |
How to calculate the distance matrix. One of "auto", "tensorized", or "online". |
diameter |
The diameter of the metric space, if known. Default is NULL. |
niter |
The maximum number of iterations for the Sinkhorn updates |
tol |
The tolerance for convergence |
... |
Not used at this time |
The plot method first calls summary.causalWeights on the causalWeights object and then plots the diagnostics from that summary object.
The plot method returns an invisible object of class summary_causalWeights.
A dataset evaluating treatments for post-partum hemorrhage. The data contain treatment groups receiving misoprostol vs potential controls from other locations that received only oxytocin. The data is stored as a numeric matrix.
data(pph)
A matrix with 802 rows and 17 variables
The variables are as follows:
cum_blood_20m. The outcome variable denoting cumulative blood loss in mL 20 minutes after the diagnosis of post-partum hemorrhage (650 – 2000).
tx. The treatment indicator of whether an individual received misoprostol (1) or oxytocin (0).
age. the mother's age in years (15 – 43).
no_educ. whether a woman had no education (1) or some education (0).
num_livebirth. the number of previous live births.
cur_married. whether a mother is currently married (1 = yes, 0 = no).
gest_age. the gestational age of the fetus in weeks (35 – 43).
prev_pphyes. whether the woman has had a previous post-partum hemorrhage.
hb_test. the woman's hemoglobin in g/dL (7 – 15).
induced_laboryes. whether labor was induced (1 = yes, 0 = no).
augmented_laboryes. whether labor was augmented (1 = yes, 0 = no).
early_cordclampyes. whether the umbilical cord was clamped early (1 = yes, 0 = no).
control_cordtractionyes. whether cord traction was controlled (1 = yes, 0 = no).
uterine_massageyes. whether a uterine massage was given (1 = yes, 0 = no).
placenta. whether placenta was delivered before treatment given (1 = yes, 0 = no).
bloodlossattx. amount of blood lost when treatment given (500 mL – 1800 mL)
sitecode. Which site is the individual from? (1 = Cairo, Egypt; 2 = Turkey; 3 = Hocmon, Vietnam; 4 = Cuchi, Vietnam; 5 = Burkina Faso).
Data from the following Harvard Dataverse:
Winikoff, Beverly, 2019, "Two randomized controlled trials of misoprostol for the treatment of postpartum hemorrhage", https://doi.org/10.7910/DVN/ETHH4N, Harvard Dataverse, V1.
The data was originally analyzed in
Blum, J. et al. Treatment of post-partum haemorrhage with sublingual misoprostol versus oxytocin in women receiving prophylactic oxytocin: a double-blind, randomised, non-inferiority trial. The Lancet 375, 217–223 (2010).
Predict method for barycentric projection models
## S3 method for class 'bp'
predict(object, newdata = NULL, source.sample, cost_function = NULL,
  niter = 1000, tol = 1e-07, ...)
object |
An object of class "bp" |
newdata |
a data.frame containing new observations |
source.sample |
a vector giving the sample each observation arises from |
cost_function |
a cost metric between observations |
niter |
number of iterations to run the barycentric projection for powers > 2. |
tol |
Tolerance on the optimization problem for projections with powers > 2. |
... |
Dots passed to the lbfgs method in the torch package. |
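The niter and tol arguments only apply to powers above 2, which suggests that lower powers have a direct solution. For a squared Euclidean cost (p = 2), a standard result is that the barycentric projection of an outcome is the transport-plan-weighted mean of the outcomes in the other sample. The base R sketch below illustrates that special case only. The transport plan and outcome values are invented for illustration and nothing here calls the package.

```r
# Hypothetical transport plan between 2 source and 3 target observations.
# Rows index source units, columns index target units.
plan <- matrix(c(0.3, 0.1, 0.1,
                 0.0, 0.2, 0.3), nrow = 2, byrow = TRUE)
y_target <- c(1, 2, 4)   # outcomes observed in the target sample

# p = 2: each projected outcome is a weighted mean of target outcomes,
# with weights given by the corresponding row of the plan.
bary_proj <- as.numeric(plan %*% y_target / rowSums(plan))
bary_proj
```

For powers without such a closed form, each projection has to be found by numerical minimization, which is presumably where the iteration limit and tolerance come in.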
if (torch::torch_is_installed()) {
  set.seed(23483)
  n <- 2^5
  pp <- 6
  overlap <- "low"
  design <- "A"
  estimate <- "ATT"
  power <- 2
  data <- causalOT::Hainmueller$new(n = n, p = pp,
                                    design = design, overlap = overlap)
  data$gen_data()
  weights <- causalOT::calc_weight(x = data,
                                   z = NULL, y = NULL,
                                   estimand = estimate,
                                   method = "NNM")
  df <- data.frame(y = data$get_y(), z = data$get_z(), data$get_x())

  # undebiased
  fit <- causalOT::barycentric_projection(y ~ ., data = df,
                                          weights = weights,
                                          separate.samples.on = "z",
                                          niter = 2)

  # debiased
  fit_d <- causalOT::barycentric_projection(y ~ ., data = df,
                                            weights = weights,
                                            separate.samples.on = "z",
                                            debias = TRUE, niter = 2)

  # predictions, without new data
  undebiased_predictions <- predict(fit, source.sample = df$z)
  debiased_predictions <- predict(fit_d, source.sample = df$z)

  isTRUE(all.equal(unname(undebiased_predictions), df$y)) # FALSE
  isTRUE(all.equal(unname(debiased_predictions), df$y)) # TRUE
}
print.dataHolder
## S3 method for class 'dataHolder'
print(x, ...)
x |
dataHolder object |
... |
Not used |
Pareto-Smoothed Importance Sampling
PSIS(x, r_eff = NULL, ...)

## S4 method for signature 'numeric'
PSIS(x, r_eff = NULL, ...)

## S4 method for signature 'causalWeights'
PSIS(x, r_eff = NULL, ...)

## S4 method for signature 'list'
PSIS(x, r_eff = NULL, ...)

PSIS_diag(x, ...)

## S4 method for signature 'numeric'
PSIS_diag(x, r_eff = NULL)

## S4 method for signature 'causalWeights'
PSIS_diag(x, r_eff = NULL)

## S4 method for signature 'causalPSIS'
PSIS_diag(x, ...)

## S4 method for signature 'list'
PSIS_diag(x, r_eff = NULL)

## S4 method for signature 'psis'
PSIS_diag(x, r_eff = NULL)
x |
For |
r_eff |
A vector of relative effective sample size with one estimate per observation. If providing
an object of class causalWeights, should be a list of vectors with one vector for each
sample. See psis() from the |
... |
Arguments passed to the psis() function. |
Acts as a wrapper to the psis() function from the loo package. It
is built to handle the data types found in this package. This method is preferred to the ESS()
function in causalOT since the latter is prone to error (infinite variances) but will not give any good indication that the estimates
are problematic.
For PSIS(), returns a list. See psis() from loo for a description of the outputs. Will give the log of the
smoothed weights in slot log_weights, and in the slot diagnostics, it will give
the pareto_k parameter (see the pareto-k-diagnostic page) and
the n_eff estimates. PSIS_diag() returns the diagnostic slot from an object of class "psis".
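The contrast between an effective sample size and the Pareto diagnostic can be sketched on toy data. The heavy-tailed ratios below are hypothetical, and the Kish formula `1 / sum(w^2)` is an assumption about how an effective sample size might be computed; it may differ from the package's ESS(). The call to `loo::psis()` runs only if the loo package is installed.

```r
set.seed(42)

# Hypothetical importance ratios with a Pareto-type tail (shape 1 / 0.9),
# so the variance of the ratios is infinite.
ratios <- 1 / runif(500)^0.9
w <- ratios / sum(ratios)

# Kish effective sample size: always returns a finite number and gives
# no signal that the tail is too heavy for it to be meaningful.
kish_ess <- 1 / sum(w^2)

# Pareto-smoothed importance sampling estimates the tail shape directly.
if (requireNamespace("loo", quietly = TRUE)) {
  fit <- suppressWarnings(loo::psis(log(ratios), r_eff = 1))
  k_hat <- fit$diagnostics$pareto_k   # large values flag unreliable weights
  n_eff <- fit$diagnostics$n_eff
}
```

The effective sample size alone would be reported without complaint, whereas the estimated `pareto_k` is the quantity to inspect when deciding whether the weights can be trusted.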
PSIS(numeric): numeric weights
PSIS(causalWeights): object of class causalWeights
PSIS(list): list of weights
PSIS_diag(numeric): numeric weights
PSIS_diag(causalWeights): object of class causalWeights diagnostics
PSIS_diag(causalPSIS): diagnostics from the output of a previous call to PSIS
PSIS_diag(list): a list of objects
PSIS_diag(psis): output of PSIS function
x <- runif(100)
w <- x/sum(x)
res <- PSIS(x = w, r_eff = 1)
PSIS_diag(res)
Options for the SBW method
sbwOptions(delta = NULL, grid.length = 20L, nboot = 1000L, ...)
delta |
A number or vector of tolerances for the balancing functions. Default is NULL, which will use a grid search. |
grid.length |
The number of values to try in the grid search |
nboot |
The number of bootstrap samples to run during the grid search. |
... |
Arguments passed on to osqpSettings() |
A list of class sbwOptions with slots
delta Delta values to try
grid.length The number of parameters to try
sumto1 Forced to be TRUE. Weights will always sum to 1.
nboot Number of bootstrap samples
solver.options A list with arguments passed to osqpSettings()
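This method will balance functions of the covariates within some tolerance, $\delta$. For these functions $B$, we will desire
$$\left| \frac{\sum_{i: Z_i = 0} w_i B(x_i) - \frac{1}{n_1} \sum_{j: Z_j = 1} B(x_j)}{\sigma} \right| \leq \delta,$$
where in this case we are targeting balance with the treatment group for the ATT. $\sigma$ is the pooled standard deviation prior to balancing.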
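The balance condition described above can be checked numerically for any candidate set of control weights. This is a minimal base R sketch on hypothetical toy data: the balancing function `B` is taken to be the identity, uniform weights stand in for the output of the SBW solver, and the pooled standard deviation is assumed to be the square root of the average of the two group variances, which may differ from the package's definition.

```r
set.seed(1)

# Hypothetical data: one covariate x and a binary treatment indicator z.
n <- 200
z <- rbinom(n, 1, 0.4)
x <- rnorm(n, mean = 0.5 * z)

# Balancing function; the identity targets the first moment.
B <- function(x) x

# Placeholder control weights that sum to 1 (an SBW fit would replace these).
n0 <- sum(z == 0)
w <- rep(1 / n0, n0)

# Pooled standard deviation prior to balancing (assumed definition).
sigma <- sqrt((var(x[z == 1]) + var(x[z == 0])) / 2)

# Standardized difference between the weighted control mean and the
# treated mean, as targeted for the ATT.
std_diff <- (sum(w * B(x[z == 0])) - mean(B(x[z == 1]))) / sigma

delta <- 0.1
balanced <- abs(std_diff) <= delta
```

A smaller `delta` demands tighter balance and therefore typically more variable weights, which is the trade-off the grid search over `delta` explores.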
opts <- sbwOptions(delta = 0.1)
Options for the SCM Method
scmOptions(...)
... |
Arguments passed to the osqpSettings() function which solves the problem. |
Options for the solver used in the optimization of the Synthetic Control Method of Abadie and Gardeazabal (2003).
A list with arguments to pass to osqpSettings()
opts <- scmOptions()
Summary diagnostics for causalWeights
print.summary_causalWeights
plot.summary_causalWeights
## S3 method for class 'causalWeights'
summary(
  object,
  r_eff = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07,
  ...
)

## S3 method for class 'summary_causalWeights'
print(x, ...)

## S3 method for class 'summary_causalWeights'
plot(x, ...)
object |
an object of class causalWeights |
r_eff |
The r_eff used in the PSIS calculation. See |
penalty |
The penalty parameter to use |
p |
The power of the Lp distance to use. Overridden by argument |
cost |
A user supplied cost function. Should take arguments |
debias |
Should debiased optimal transport distances be used? TRUE or FALSE. |
online.cost |
Should the cost be calculated online? One of "auto", "tensorized", or "online". |
diameter |
The diameter of the covariate space. Default is NULL. |
niter |
The number of iterations to run when calculating the optimal transport distances. |
tol |
The convergence tolerance for the optimal transport distances. |
... |
Not used |
x |
an object of class "summary_causalWeights" |
The summary method returns an object of class "summary_causalWeights".
print(summary_causalWeights): print method
plot(summary_causalWeights): plot method
if(torch::torch_is_installed()) {
  n <- 2^6
  p <- 6
  overlap <- "high"
  design <- "A"
  estimand <- "ATE"
  #### get simulation functions ####
  original <- Hainmueller$new(n = n, p = p, design = design, overlap = overlap)
  original$gen_data()
  weights <- calc_weight(x = original, estimand = estimand, method = "Logistic")
  s <- summary(weights)
  plot(s)
}
Supported Methods
supported_methods()
A character list with supported methods. Note "COT" is the same as "Wasserstein". We provide the second name for backwards compatibility.
supported_methods()
Get the variance of a causalEffect
## S3 method for class 'causalEffect'
vcov(object, ...)
object |
An object of class causalEffect |
... |
Passed on to the sandwich estimator if there is a model fit that supports one |
The variance of the treatment effect as a matrix
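One common use of the returned variance matrix is a normal-approximation (Wald) confidence interval. The sketch below uses a hypothetical point estimate and a hypothetical 1 x 1 variance matrix in place of real package output, and the normal approximation is an assumption; the package's own interval methods may work differently.

```r
# Hypothetical stand-ins for a treatment effect estimate and the
# matrix returned by vcov() on a causalEffect object.
estimate <- 1.8
V <- matrix(0.25, nrow = 1, ncol = 1)

# The standard error is the square root of the diagonal entry.
se <- sqrt(diag(V))

# Two-sided Wald interval at the chosen confidence level.
level <- 0.95
z_crit <- qnorm(1 - (1 - level) / 2)
ci <- c(lower = estimate - z_crit * se,
        upper = estimate + z_crit * se)
```

Because the variance is returned as a matrix, `diag()` is used to extract it rather than indexing a scalar, which keeps the same code valid if more than one effect is ever reported.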
# set-up data
set.seed(1234)
data <- Hainmueller$new()
data$gen_data()

# calculate quantities
weight <- calc_weight(data, estimand = "ATT", method = "Logistic")
tx_eff <- estimate_effect(causalWeights = weight)
vcov(tx_eff)