Title: | Hardware-Accelerated Rerandomization for Improved Balance |
---|---|
Description: | Provides hardware-accelerated tools for performing rerandomization and randomization testing in experimental research. Using a 'JAX' backend, the package enables exact rerandomization inference even for large experiments with hundreds of billions of possible randomizations. Key functionalities include generating pools of acceptable rerandomizations based on covariate balance, conducting exact randomization tests, and performing pre-analysis evaluations to determine optimal rerandomization acceptance thresholds. The package supports various hardware acceleration frameworks including 'CPU', 'CUDA', and 'METAL', making it versatile across accelerated computing environments. This allows researchers to efficiently implement stringent rerandomization designs and conduct valid inference even with large sample sizes. The package is partly based on Jerzak and Goldstein (2023) <doi:10.48550/arXiv.2310.00861>. |
Authors: | Fucheng Warren Zhu [aut], Aniket Sachin Kamat [aut], Connor Jerzak [aut, cre], Rebecca Goldstein [aut] |
Maintainer: | Connor Jerzak <[email protected]> |
License: | GPL-3 |
Version: | 0.2 |
Built: | 2025-01-14 23:18:47 UTC |
Source: | https://github.com/cjerzak/fastrerandomize-software |
Builds the backend environment for fastrerandomize: a conda environment in which 'JAX' and 'NumPy' are installed. Users can alternatively create such an environment themselves.
build_backend(conda_env = "fastrerandomize", conda = "auto")
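conda_env |
(default = "fastrerandomize") Name of the conda environment to create. |
conda |
(default = "auto") Path to a conda executable, or "auto". |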
Invisibly returns NULL; this function is used for its side effects of creating and configuring a conda environment for fastrerandomize.
This function requires an Internet connection.
To find the path of the Python executable currently on your search path, use Sys.which("python").
## Not run:
# Create a conda environment named "fastrerandomize"
# and install the required Python packages (jax, numpy, etc.)
build_backend(conda_env = "fastrerandomize", conda = "auto")

# To specify a particular conda path:
# build_backend(conda_env = "fastrerandomize", conda = "/usr/local/bin/conda")
## End(Not run)
This function checks if 'Python' and 'JAX' can be accessed via 'reticulate'. If not, it returns 'NULL' and prints a message suggesting to run 'build_backend()'.
check_jax_availability(conda_env = "fastrerandomize", conda = "auto")
conda_env |
A character string specifying the name of the conda environment. Default is '"fastrerandomize"'. |
conda |
The path to a conda executable, or '"auto"'. Default is '"auto"'. |
Returns 'TRUE' (invisibly) if both 'Python' and 'JAX' are available; otherwise returns 'NULL'.
## Not run:
check_jax_availability()
## End(Not run)
Create an S3 object of class fastrerandomize_randomizations that stores the randomizations (and optionally balance statistics) generated by functions such as generate_randomizations.
fastrerandomize_class( randomizations, balance = NULL, fastrr_env = NULL, call = NULL )
randomizations |
A matrix or array where each row (or slice) represents one randomization. |
balance |
A numeric vector or similar object holding balance statistics for each randomization, or NULL (the default). |
fastrr_env |
Associated 'fastrr_env' environment. |
call |
The function call, if you wish to store it for reference (optional). |
An object of class fastrerandomize_randomizations.
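The constructor follows R's standard S3 pattern: a list of fields tagged with a class attribute, so that generics such as print() dispatch to class-specific methods. A minimal sketch of that pattern is below; the class name toy_randomizations and the function new_toy_randomizations are hypothetical and do not reproduce the package's internals.

```r
# Hypothetical constructor: a list of fields tagged with an S3 class attribute
new_toy_randomizations <- function(randomizations, balance = NULL) {
  stopifnot(is.matrix(randomizations))
  if (!is.null(balance)) stopifnot(length(balance) == nrow(randomizations))
  structure(list(randomizations = randomizations, balance = balance),
            class = "toy_randomizations")
}

# print() dispatches here because the method is named print.<class>
print.toy_randomizations <- function(x, ...) {
  cat("Randomizations:", nrow(x$randomizations), "x", ncol(x$randomizations), "\n")
  if (!is.null(x$balance)) cat("Balance range:", format(range(x$balance)), "\n")
  invisible(x)
}

# Two assignments of 4 units, each with 2 treated
W <- rbind(c(1L, 0L, 1L, 0L), c(0L, 1L, 0L, 1L))
obj <- new_toy_randomizations(W, balance = c(0.3, 0.3))
print(obj)
```

Keeping the fields in a plain list makes them accessible with $, while the class attribute lets summary and plot methods be added later without changing the stored data.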
Constructor for fastrerandomize randomization test objects
fastrerandomize_test(p_value, FI, tau_obs, fastrr_env = NULL, call = NULL, ...)
p_value |
A numeric value representing the p-value of the test. |
FI |
A numeric vector (length 2) representing the fiducial interval, or NULL if none was computed. |
tau_obs |
A numeric value (or vector) representing the estimated treatment effect. |
fastrr_env |
Associated 'fastrr_env' environment. |
call |
An optional function call, stored for reference. |
... |
Other slots you may want to store (e.g. additional diagnostics). |
An object of class fastrerandomize_test.
This function generates randomizations for experimental design using either exact enumeration or Monte Carlo sampling methods. It provides a unified interface to both approaches while handling memory and computational constraints appropriately.
generate_randomizations( n_units, n_treated, X = NULL, randomization_accept_prob, threshold_func = NULL, max_draws = 10^6, batch_size = 1000, randomization_type = "monte_carlo", approximate_inv = TRUE, file = NULL, return_type = "R", verbose = TRUE, conda_env = "fastrerandomize", conda_env_required = TRUE )
n_units |
An integer specifying the total number of experimental units. |
n_treated |
An integer specifying the number of units to be assigned to treatment. |
X |
A numeric matrix of covariates used for balance checking. Cannot be NULL. |
randomization_accept_prob |
A numeric value between 0 and 1 specifying the probability threshold for accepting randomizations based on balance. |
threshold_func |
A 'JAX' function that computes a balance measure for each randomization. Only used for Monte Carlo sampling. |
max_draws |
An integer specifying the maximum number of randomizations to draw in Monte Carlo sampling. |
batch_size |
An integer specifying batch size for Monte Carlo processing. |
randomization_type |
A string specifying the type of randomization: either "exact" or "monte_carlo". Default is "monte_carlo". |
approximate_inv |
A logical value indicating whether to use an approximate inverse
(diagonal of the covariance matrix) instead of the full matrix inverse when computing
balance metrics. This can speed up computations for high-dimensional covariates.
Default is TRUE. |
file |
A string specifying where to save candidate randomizations (if saving, not returning). |
return_type |
A string specifying the format of the returned randomizations and balance
measures. Default is "R". |
verbose |
A logical value indicating whether to print progress information. Default is TRUE. |
conda_env |
A character string specifying the name of the conda environment to use
via 'reticulate'. Default is "fastrerandomize". |
conda_env_required |
A logical indicating whether the specified conda environment
must be strictly used. Default is TRUE. |
The function supports two methods of generating randomizations:
Exact enumeration: Generates all possible randomizations (memory intensive but exact).
Monte Carlo sampling: Generates randomizations through sampling (more memory efficient).
For large problems (e.g., X with >20 rows), Monte Carlo sampling is recommended.
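The case for Monte Carlo sampling follows from how quickly the number of possible assignments grows. The sketch below counts them with base R's choose() and estimates the memory needed to store all of them, assuming one byte (int8) per unit; the helper names n_possible and gb_exact are hypothetical.

```r
# Number of distinct assignments with n_treated of n_units units treated
n_possible <- function(n_units, n_treated) choose(n_units, n_treated)

# Gigabytes needed to store every assignment, assuming one byte (int8) per unit
gb_exact <- function(n_units, n_treated) n_possible(n_units, n_treated) * n_units / 1e9

n_possible(20, 10)   # 184756
n_possible(30, 15)   # 155117520
n_possible(100, 50)  # about 1.01e29
gb_exact(20, 10)     # about 0.0037 GB
gb_exact(30, 15)     # about 4.65 GB
```

Going from 20 to 30 units raises the storage requirement from roughly 3.7 MB to about 4.65 GB, and at 100 units full enumeration is out of reach entirely.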
Returns an S3 object with slots:
assignments
An array of the accepted randomizations, in which each row is one treatment assignment vector.
balance_measures
A numeric vector containing the balance measure for each corresponding randomization.
fastrr_env
The fastrerandomize environment.
file_output
If file is specified, results are saved to the given file path instead of being returned.
generate_randomizations_exact for the exact enumeration method.
generate_randomizations_mc for the Monte Carlo sampling method.
## Not run:
# Generate synthetic data
X <- matrix(rnorm(20*5), 20, 5)

# Generate randomizations using exact enumeration
RandomizationSet_Exact <- generate_randomizations(
  n_units = nrow(X),
  n_treated = round(nrow(X)/2),
  X = X,
  randomization_accept_prob = 0.1,
  randomization_type = "exact")

# Generate randomizations using Monte Carlo sampling
RandomizationSet_MC <- generate_randomizations(
  n_units = nrow(X),
  n_treated = round(nrow(X)/2),
  X = X,
  randomization_accept_prob = 0.1,
  randomization_type = "monte_carlo",
  max_draws = 100000,
  batch_size = 1000)
## End(Not run)
Generates all possible treatment assignments for a completely randomized experiment, optionally filtering them based on covariate balance criteria. The function can generate either all possible randomizations or a subset that meets specified balance thresholds using Hotelling's T-squared statistic.
generate_randomizations_exact( n_units, n_treated, X = NULL, randomization_accept_prob = 1, approximate_inv = TRUE, threshold_func = NULL, verbose = TRUE, conda_env = "fastrerandomize", conda_env_required = TRUE )
n_units |
An integer specifying the total number of experimental units |
n_treated |
An integer specifying the number of units to be assigned to treatment |
X |
A numeric matrix of covariates where rows represent units and columns
represent different covariates. Default is NULL. |
randomization_accept_prob |
A numeric value between 0 and 1 specifying the quantile threshold for accepting randomizations based on balance statistics. Default is 1 (accept all randomizations). |
approximate_inv |
A logical value indicating whether to use an approximate inverse
(diagonal of the covariance matrix) instead of the full matrix inverse when computing
balance metrics. This can speed up computations for high-dimensional covariates.
Default is TRUE. |
threshold_func |
A function that calculates balance statistics for candidate
randomizations. Default is NULL, in which case Hotelling's T-squared statistic is used. |
verbose |
A logical value indicating whether to print progress information. Default is TRUE. |
conda_env |
A character string specifying the name of the conda environment to use
via 'reticulate'. Default is "fastrerandomize". |
conda_env_required |
A logical indicating whether the specified conda environment
must be strictly used. Default is TRUE. |
The function works in two main steps:
1. Generates all possible combinations of treatment assignments given n_units and n_treated.
2. If covariates (X) are provided, filters these combinations based on balance criteria using the specified threshold function.
The balance filtering process uses Hotelling's T-squared statistic by default to measure multivariate balance between treatment and control groups. Randomizations are accepted if their balance measure is below the specified quantile threshold.
The function returns a list with two elements:
candidate_randomizations: an array of randomization vectors
M_candidate_randomizations: an array of their balance measures.
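The two steps and the quantile rule can be sketched in base R as follows. This is an illustrative sketch, not the package's 'JAX' implementation: it assumes the Mahalanobis-type form of the statistic used in the rerandomization literature (Morgan and Rubin, 2012), M = (n1 n0 / n) d' S^{-1} d, where d is the treated-minus-control difference in covariate means and S is the full-sample covariance, and it mimics approximate_inv = TRUE by inverting only the diagonal of S. The function names balance_stat and enumerate_and_filter are hypothetical.

```r
# Mahalanobis-type balance statistic for a 0/1 assignment vector w
balance_stat <- function(w, X, S_inv) {
  n1 <- sum(w)
  n0 <- length(w) - n1
  d <- colMeans(X[w == 1, , drop = FALSE]) - colMeans(X[w == 0, , drop = FALSE])
  (n1 * n0 / length(w)) * drop(t(d) %*% S_inv %*% d)
}

# Step 1: enumerate all assignments; Step 2: keep those at or below the quantile
enumerate_and_filter <- function(X, n_treated, accept_prob = 1,
                                 approximate_inv = TRUE) {
  n <- nrow(X)
  S <- cov(X)
  # Diagonal approximation inverts only the variances; otherwise use the full inverse
  S_inv <- if (approximate_inv) diag(1 / diag(S), ncol(X)) else solve(S)
  idx <- combn(n, n_treated)  # each column lists the treated units of one assignment
  W <- t(apply(idx, 2, function(ix) { w <- integer(n); w[ix] <- 1L; w }))
  # A closure avoids clashing with apply()'s own argument named X
  M <- apply(W, 1, function(w) balance_stat(w, X, S_inv))
  keep <- M <= quantile(M, probs = accept_prob, type = 1)
  list(candidate_randomizations = W[keep, , drop = FALSE],
       M_candidate_randomizations = M[keep])
}

set.seed(1)
X <- matrix(rnorm(60), nrow = 10)  # 10 units, 6 covariates
res <- enumerate_and_filter(X, n_treated = 5, accept_prob = 0.25)
nrow(res$candidate_randomizations)
```

Because swapping treatment and control only flips the sign of d, each assignment and its complement share the same M, so the accepted count can slightly exceed accept_prob times the number of assignments.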
This function requires 'JAX' and 'NumPy' to be installed and accessible through the reticulate package.
Hotelling, H. (1931). The generalization of Student's ratio. The Annals of Mathematical Statistics, 2(3), 360-378.
generate_randomizations for full randomization generation function.
generate_randomizations_mc for the Monte Carlo version.
## Not run:
# Generate synthetic data
X <- matrix(rnorm(60), nrow = 10)  # 10 units, 6 covariates

# Generate balanced randomizations with covariates
BalancedRandomizations <- generate_randomizations_exact(
  n_units = 10,
  n_treated = 5,
  X = X,
  randomization_accept_prob = 0.25  # Keep top 25% most balanced
)
## End(Not run)
This function performs sampling with replacement to generate randomizations in a memory-efficient way. It processes randomizations in batches to avoid memory issues and filters them based on covariate balance. The function uses JAX for fast computation and memory management.
generate_randomizations_mc( n_units, n_treated, X, randomization_accept_prob = 1, threshold_func = NULL, max_draws = 1e+05, batch_size = 1000, approximate_inv = TRUE, verbose = TRUE, conda_env = "fastrerandomize", conda_env_required = TRUE )
n_units |
An integer specifying the total number of experimental units. |
n_treated |
An integer specifying the number of units to be assigned to treatment. |
X |
A numeric matrix of covariates used for balance checking. Cannot be NULL. |
randomization_accept_prob |
A numeric value between 0 and 1 specifying the probability threshold for accepting randomizations based on balance. Default is 1. |
threshold_func |
A JAX function that computes a balance measure for each randomization. Must be vectorized using |
max_draws |
An integer specifying the maximum number of randomizations to draw. |
batch_size |
An integer specifying how many randomizations to process at once. Lower values use less memory but may be slower. |
approximate_inv |
A logical value indicating whether to use an approximate inverse
(diagonal of the covariance matrix) instead of the full matrix inverse when computing
balance metrics. This can speed up computations for high-dimensional covariates.
Default is TRUE. |
verbose |
A logical value indicating whether to print detailed information about batch-processing progress and GPU memory usage. Default is TRUE. |
conda_env |
A character string specifying the name of the conda environment to use
via 'reticulate'. Default is "fastrerandomize". |
conda_env_required |
A logical indicating whether the specified conda environment
must be strictly used. Default is TRUE. |
The function works by:
Generating batches of random permutations.
Computing balance measures for each permutation using the provided threshold function.
Keeping only the top permutations that meet the acceptance probability threshold.
Managing memory by clearing unused objects and caches between batches.
The function uses smaller data types (int8, float16) where possible to reduce memory usage. It also includes assertions to verify array shapes and dimensions throughout.
The function returns a list with two elements:
candidate_randomizations: an array of randomization vectors
M_candidate_randomizations: an array of their balance measures.
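The batching and filtering steps can be sketched in base R as a streaming top-k selection: each batch is scored, merged with the pool retained so far, and only the ceiling(max_draws * randomization_accept_prob) most balanced draws are kept, so memory stays bounded by the pool plus one batch. This is a hedged illustration that assumes a Mahalanobis-type balance statistic with a diagonal inverse option; it is not the package's 'JAX' code, the function name mc_randomizations is hypothetical, and assignments are drawn with replacement, so duplicates are possible.

```r
mc_randomizations <- function(X, n_treated, accept_prob, max_draws = 1e4,
                              batch_size = 1000, approximate_inv = TRUE) {
  n <- nrow(X)
  n1 <- n_treated
  n0 <- n - n_treated
  S <- cov(X)
  # Diagonal approximation inverts only the variances; otherwise use the full inverse
  S_inv <- if (approximate_inv) diag(1 / diag(S), ncol(X)) else solve(S)
  k <- max(1, ceiling(max_draws * accept_prob))  # size of the retained pool
  kept_W <- matrix(integer(0), nrow = 0, ncol = n)
  kept_M <- numeric(0)
  drawn <- 0
  while (drawn < max_draws) {
    b <- min(batch_size, max_draws - drawn)
    # Each row is a uniformly random assignment with exactly n1 treated units
    W <- t(replicate(b, { w <- integer(n); w[sample.int(n, n1)] <- 1L; w }))
    # Treated-minus-control covariate mean differences for the whole batch
    D <- (W %*% X) / n1 - ((1L - W) %*% X) / n0
    # Row-wise quadratic forms d' S_inv d, scaled as in the balance statistic
    M <- (n1 * n0 / n) * rowSums((D %*% S_inv) * D)
    # Merge with the running pool and keep the k smallest balance measures
    all_W <- rbind(kept_W, W)
    all_M <- c(kept_M, M)
    ord <- order(all_M)[seq_len(min(k, length(all_M)))]
    kept_W <- all_W[ord, , drop = FALSE]
    kept_M <- all_M[ord]
    drawn <- drawn + b
  }
  list(candidate_randomizations = kept_W, M_candidate_randomizations = kept_M)
}

set.seed(2)
X <- matrix(rnorm(100 * 5), 100, 5)
res <- mc_randomizations(X, n_treated = 50, accept_prob = 0.01,
                         max_draws = 1e4, batch_size = 1000)
```

Keeping the k best draws across all batches is equivalent to applying the acceptance quantile to the full set of draws, without ever holding all max_draws assignments in memory at once.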
generate_randomizations for full randomization generation function.
generate_randomizations_exact for the exact version.
## Not run:
# Generate synthetic data
X <- matrix(rnorm(100*5), 100, 5)  # 5 covariates

# Generate 1000 randomizations for 100 units with 50 treated
rand_less_strict <- generate_randomizations_mc(
  n_units = 100,
  n_treated = 50,
  X = X,
  randomization_accept_prob = 0.01,
  max_draws = 100000,
  batch_size = 1000)

# Use a stricter balance criterion
rand_more_strict <- generate_randomizations_mc(
  n_units = 100,
  n_treated = 50,
  X = X,
  randomization_accept_prob = 0.001,
  max_draws = 1000000,
  batch_size = 1000)
## End(Not run)
Plots a histogram of the accepted balance measures stored in a fastrerandomize_randomizations object.
## S3 method for class 'fastrerandomize_randomizations' plot(x, ...)
x |
An object of class fastrerandomize_randomizations. |
... |
Further graphical parameters passed to the underlying plotting function. |
No return value. This function is called for the side effect of generating a histogram of the accepted balance measures of an object with class fastrerandomize_randomizations.
Plots a simple visualization of the observed effect and the fiducial interval (if present) on a horizontal axis.
## S3 method for class 'fastrerandomize_test' plot(x, ...)
x |
An object of class fastrerandomize_test. |
... |
Further graphical parameters passed to the underlying plotting function. |
No return value; called for the side effect of plotting an object of class fastrerandomize_test.
Print method for fastrerandomize_randomizations objects
## S3 method for class 'fastrerandomize_randomizations' print(x, ...)
x |
An object of class fastrerandomize_randomizations. |
... |
Further arguments passed to or from other methods. |
Prints an object of class fastrerandomize_randomizations.
Print method for fastrerandomize_test objects
## S3 method for class 'fastrerandomize_test' print(x, ...)
x |
An object of class fastrerandomize_test. |
... |
Further arguments passed to or from other methods. |
No return value; prints an object of class fastrerandomize_test.
This function prints messages prefixed with the current timestamp in a standardized format. Messages can be suppressed using the quiet parameter.
print2(text, quiet = FALSE)
text |
A character string containing the message to be printed. |
quiet |
A logical value indicating whether to suppress output. Default is FALSE. |
The function prepends the current timestamp in "YYYY-MM-DD HH:MM:SS" format to the provided message.
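The timestamp behavior described above amounts to formatting Sys.time() with the pattern "%Y-%m-%d %H:%M:%S". A minimal sketch follows, using the hypothetical name print2_sketch so it is not mistaken for the package function; the exact spacing of the package's output is an assumption.

```r
# Hypothetical re-creation of a timestamped, suppressible message printer
print2_sketch <- function(text, quiet = FALSE) {
  if (!quiet) {
    cat(format(Sys.time(), "%Y-%m-%d %H:%M:%S"), text, "\n")
  }
  invisible(NULL)
}

print2_sketch("Processing started")
print2_sketch("This won't show", quiet = TRUE)  # prints nothing
```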
No return value, called for side effect of printing with timestamp.
Sys.time for the underlying timestamp functionality.
# Print a basic message with timestamp
print2("Processing started")

# Suppress output
print2("This won't show", quiet = TRUE)

# Use in a loop
for (i in 1:3) {
  print2(sprintf("Processing item %d", i))
}
Data from a field experiment studying moral hazard in tenancy contracts in agriculture.
After subsetting, this dataset includes observations on 968 experimental units with the following variables of interest: household composition, treatment assignment, and agricultural outcomes.
data(QJEData)
A data frame with 968 rows and 7 columns:
Numeric (integer). Number of children in the household. Larger numbers may reflect increased household labor needs and different investment or effort incentives.
Numeric/binary. Whether the household head is currently married (1) or not (0). Marital status may influence decision-making and risk preferences in farming.
Numeric (integer). Household size. Differences in family labor availability or consumption needs can influence effort levels and thus relate to moral hazard in production decisions.
Numeric. The ratio of adult men to adult women in the household. Imbalances in the male–female ratio can affect labor division and investment decisions.
Numeric/binary. Primary treatment indicator (e.g., whether a farmer is offered a specific tenancy contract or cost-sharing arrangement).
Numeric. Crop yield per square meter (e.g., kilograms of output per square meter). This is a principal outcome measure for evaluating productivity and treatment impact on farm performance.
Numeric/binary. Indicator for whether fertilizer was used (1) or not (0). This measures input investment—a key mechanism in moral hazard models (farmers may alter input use under different contracts).
Burchardi, K.B., Ghatak, M., & Johanssen, A. (2019). Moral hazard: Experimental evidence from tenancy contracts. The Quarterly Journal of Economics, 134(1), 281-347.
Fast randomization test
randomization_test( obsW = NULL, obsY = NULL, X = NULL, alpha = 0.05, candidate_randomizations = NULL, candidate_randomizations_array = NULL, n0_array = NULL, n1_array = NULL, randomization_accept_prob = 1, findFI = FALSE, c_initial = 2, max_draws = 10^6, batch_size = 10^5, randomization_type = "monte_carlo", approximate_inv = TRUE, file = NULL, verbose = TRUE, conda_env = "fastrerandomize", conda_env_required = TRUE )
obsW |
A numeric vector of observed treatment assignments (1 = treated, 0 = control). |
obsY |
An optional numeric vector of observed outcomes. If not provided, the function assumes a NULL value. |
X |
A numeric matrix of covariates. |
alpha |
The significance level for the test. Default is 0.05. |
candidate_randomizations |
A numeric matrix of candidate randomizations. |
candidate_randomizations_array |
An optional 'JAX' array of candidate randomizations. If not provided, the function coerces candidate_randomizations into a 'JAX' array. |
n0_array |
An optional array specifying the number of control units. |
n1_array |
An optional array specifying the number of treated units. |
randomization_accept_prob |
A numeric scalar or vector of probabilities for accepting each randomization. Default is 1. |
findFI |
A logical value indicating whether to find the fiducial interval. Default is FALSE. |
c_initial |
A numeric value representing the initial criterion for the randomization. Default is 2. |
max_draws |
An integer specifying the maximum number of candidate randomizations
to generate (or to consider) for the test. Default is 10^6. |
batch_size |
An integer specifying the batch size for Monte Carlo sampling.
Batches are processed one at a time for memory efficiency. Default is 10^5. |
randomization_type |
A string specifying the type of randomization for the test. Allowed values are "exact" or "monte_carlo". Default is "monte_carlo". |
approximate_inv |
A logical value indicating whether to use an approximate inverse
(diagonal of the covariance matrix) instead of the full matrix inverse when computing
balance metrics. This can speed up computations for high-dimensional covariates.
Default is TRUE. |
file |
A character string specifying the path (including filename) where candidate
randomizations will be saved or loaded from. Default is NULL. |
verbose |
A logical value indicating whether to print progress information. Default is TRUE. |
conda_env |
A character string specifying the name of the conda environment to use
via 'reticulate'. Default is "fastrerandomize". |
conda_env_required |
A logical indicating whether the specified conda environment
must be strictly used. Default is TRUE. |
Returns an S3 object with slots:
p_value
A numeric value or vector representing the p-value of the test (or the expected p-value under the prior structure specified in the function inputs).
FI
A numeric vector representing the fiducial interval if findFI = TRUE.
tau_obs
A numeric value or vector representing the estimated treatment effect(s).
fastrr_env
The fastrerandomize environment.
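The p-value and fiducial interval can be illustrated with a sharp-null randomization test on the difference in means. Under the hypothesis that every unit's effect equals tau0, control outcomes are imputed as Y - tau0 * W, the statistic is recomputed for every candidate assignment, and the p-value is the share of candidates at least as extreme as the observed one; inverting the test over a grid of tau0 values gives an interval. This base R sketch is an assumption about the general technique, not the package's implementation (its own interval search, controlled by c_initial, may differ), and the names diff_means, rand_pvalue and fiducial_interval are hypothetical.

```r
# Difference in means between treated (w == 1) and control (w == 0) units
diff_means <- function(w, y) mean(y[w == 1]) - mean(y[w == 0])

# Two-sided randomization p-value for the sharp null "every effect equals tau0"
rand_pvalue <- function(obsW, obsY, candidates, tau0 = 0) {
  y0 <- obsY - tau0 * obsW            # imputed control potential outcomes
  t_obs <- diff_means(obsW, y0)       # equals tau_obs - tau0
  # Under the null, Y(w) = y0 + tau0 * w, so each candidate's centered
  # statistic is diff_means(w, y0)
  t_cand <- apply(candidates, 1, diff_means, y = y0)
  mean(abs(t_cand) >= abs(t_obs))
}

# Interval by test inversion: the grid values of tau0 that are not rejected
fiducial_interval <- function(obsW, obsY, candidates, alpha = 0.05, grid) {
  p <- vapply(grid, function(t0) rand_pvalue(obsW, obsY, candidates, t0),
              numeric(1))
  accepted <- grid[p > alpha]
  if (length(accepted) == 0) return(c(NA_real_, NA_real_))
  range(accepted)
}

set.seed(3)
n <- 24
cand <- t(replicate(1000, { w <- integer(n); w[sample.int(n, n / 2)] <- 1L; w }))
obsW <- cand[1, ]                     # observed assignment is one of the candidates
obsY <- rnorm(n, mean = 2 * obsW)
p0 <- rand_pvalue(obsW, obsY, cand)
fi <- fiducial_interval(obsW, obsY, cand, alpha = 0.05, grid = seq(-1, 5, by = 0.1))
```

Including the observed assignment among the candidates guarantees a p-value of at least one over the number of candidates; the interval's resolution is limited by the grid spacing.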
Zhang, Y. and Zhao, Q. (2023). What is a randomization test? Journal of the American Statistical Association, 118(544), 2928-2942.
generate_randomizations for randomization generation function.
## Not run:
# A small synthetic demonstration with 24 units, 12 treated and 12 controls:

# Generate pre-treatment covariates
X <- matrix(rnorm(24*2), ncol = 2)

# Generate candidate randomizations
RandomizationSet_MC <- generate_randomizations(
  n_units = nrow(X),
  n_treated = round(nrow(X)/2),
  X = X,
  randomization_accept_prob = 0.1,
  randomization_type = "monte_carlo",
  max_draws = 100000,
  batch_size = 1000
)

# Generate outcome
W <- RandomizationSet_MC$randomizations[1,]
obsY <- rnorm(nrow(X), mean = 2 * W)

# Perform randomization test
results_base <- randomization_test(
  obsW = W,
  obsY = obsY,
  X = X,
  candidate_randomizations = RandomizationSet_MC$randomizations
)
print(results_base)

# Perform randomization test and compute the fiducial interval
result_fi <- randomization_test(
  obsW = W,
  obsY = obsY,
  X = X,
  candidate_randomizations = RandomizationSet_MC$randomizations,
  findFI = TRUE
)
print(result_fi)
## End(Not run)
Summary method for fastrerandomize_randomizations objects
## S3 method for class 'fastrerandomize_randomizations' summary(object, ...)
object |
An object of class fastrerandomize_randomizations. |
... |
Further arguments passed to or from other methods. |
A list with summary statistics, printed by default.
Summary method for fastrerandomize_test objects
## S3 method for class 'fastrerandomize_test' summary(object, ...)
object |
An object of class fastrerandomize_test. |
... |
Further arguments passed to or from other methods. |
Returns an (invisible) list with a summary of fastrerandomize_test
class objects.
Data from a re-analysis of the Youth Opportunities Program anti-poverty RCT in Uganda, with satellite imagery neural representations linked to RCT units.
data(YOPData)
A list containing two data frames:
Treatment, outcome, and geolocation information
CLIP-RSICD neural embeddings of satellite imagery
Blattman, C., Fiala, N. and Martinez, S. (2020). The Long-term Impacts of Grants on Poverty: Nine-year Evidence from Uganda's Youth Opportunities Program. American Economic Review: Insights, 2(3), 287-304.
Jerzak, C.T., Johansson, F.D. and Daoud, A. (2023). Image-based Treatment Effect Heterogeneity. Conference on Causal Learning and Reasoning, 531-552. PMLR.