One advantage of the “StratifiedMedicine” package is the flexibility to input user-created functions/models, which makes testing and experimentation faster. First, let’s simulate the continuous data again.
```r
library(StratifiedMedicine)
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X # 50 covariates, 46 are noise variables, X1 and X2 are truly predictive
A = dat_ctns$A # binary treatment, 1:1 randomized
```
Next, before we illustrate how to implement user-specific models in PRISM, let’s highlight the key outputs at each step.
| Model | Required outputs | Description |
|---|---|---|
| filter | filter.vars | Variables that pass the filter |
| ple | list(mod, pred.fun) | Model fit(s) and prediction function |
| submod | list(mod, pred.fun) | Model fit(s) and prediction function |
| param | param.dat | Treatment effect estimates |
For the filter model (“filter_train”), the only required output is a vector of the names of the variables that pass the filter (for example, covariates with non-zero coefficients in an elastic net model). For the patient-level estimate model (“ple_train”) and the subgroup model (“submod_train”), the required outputs are the model fit(s) and an associated prediction function. The prediction function can also be swapped for pre-computed predictions (details below). Lastly, for parameter estimation (“param”), the only required output is “param.dat”, a data frame containing point estimates, standard errors (SEs), and confidence intervals (CIs).
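While developing custom functions, it can help to confirm that a return value carries the outputs listed in the table above before handing it to PRISM. The helper below, `check_user_output`, is hypothetical and not part of StratifiedMedicine; it is a sketch that encodes only the requirements stated in this vignette (for example, that param.dat has “est”, “SE”, “LCL”, and “UCL” columns), not any additional checks the package may perform internally.

```r
# Hypothetical helper (not part of StratifiedMedicine): checks that a user
# function returns the outputs this vignette lists as required for each step.
check_user_output <- function(res, step = c("filter", "ple", "submod", "param")) {
  step <- match.arg(step)
  ok <- switch(step,
    # filter: a vector of variable names (possibly empty)
    filter = is.list(res) && !is.null(res[["filter.vars"]]),
    # ple: fitted model(s) plus a prediction function or precomputed predictions
    ple = is.list(res) && !is.null(res[["mod"]]) &&
      (is.function(res[["pred.fun"]]) || !is.null(res[["mu_train"]])),
    # submod: fitted model(s) plus a prediction function or precomputed subgroups
    submod = is.list(res) && !is.null(res[["mod"]]) &&
      (is.function(res[["pred.fun"]]) || !is.null(res[["Subgrps.train"]])),
    # param: a data frame with estimate, SE, and confidence-limit columns
    param = is.data.frame(res) &&
      all(c("est", "SE", "LCL", "UCL") %in% names(res))
  )
  isTRUE(ok)
}

# Usage
check_user_output(list(filter.vars = c("X1", "X2")), "filter") # TRUE
check_user_output(data.frame(est = 1, SE = 0.5), "param")      # FALSE: no LCL/UCL
```

Using `[[` rather than `$` avoids R's partial name matching, so an element named, say, “model” is not mistaken for “mod”.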
The template filter function is:
```r
filter_template = function(Y, A, X, ...){
  # Step 1: Fit Filter Model #
  mod <- NULL # replace with your model call
  # Step 2: Extract variables that pass the filter #
  filter.vars <- NULL # replace with code that depends on the mod fit
  # Return model fit and filtered variables #
  res = list(mod=mod, filter.vars=filter.vars)
  return( res )
}
```
Note that the filter takes the observed data (Y, A, X) as required inputs and outputs an object called “filter.vars”, which must contain the names of the variables that pass the filtering step. For example, consider the lasso:
```r
filter_lasso = function(Y, A, X, lambda="lambda.min", family="gaussian", ...){
  require(glmnet)
  ## Convert covariates to a numeric model matrix ##
  X = model.matrix(~. -1, data = X )
  ##### Lasso (elastic net with alpha = 1) ##
  set.seed(6134)
  if (family=="survival") { family = "cox" }
  mod <- cv.glmnet(x = X, y = Y, nlambda = 100, alpha=1, family=family)
  ### Extract filtered variables based on lambda ###
  VI <- coef(mod, s = lambda)[,1]
  VI = VI[names(VI) != "(Intercept)"] # drop the intercept (Cox fits have none)
  filter.vars = names(VI[VI!=0])
  return( list(filter.vars=filter.vars) )
}
```
Although not required, “filter_lasso” also includes a “lambda” argument that changes which variables remain after filtering: “lambda.min” typically keeps more variables and “lambda.1se” fewer. This argument can be adjusted through the “hyper” argument in “filter_train” and the “filter.hyper” argument in “PRISM”.
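The lasso is only one possible filter. As a minimal sketch of the same input/output contract using base R only, the hypothetical `filter_marginal` below swaps in a different technique: it keeps covariates whose marginal linear association with Y has an F-test p-value below `pval.cut` (an assumed argument name, playing the role that “lambda” plays for the lasso). Because it ignores A, it screens for prognostic signal and can miss covariates that act only through a treatment interaction, so treat it as an illustration of the required format rather than a recommended filter.

```r
# Hypothetical filter (not part of StratifiedMedicine): marginal screening with
# one linear model per covariate. Follows the (Y, A, X, ...) signature and
# returns "filter.vars", as required for a user-supplied filter.
filter_marginal <- function(Y, A, X, pval.cut = 0.05, ...) {
  pvals <- vapply(names(X), function(v) {
    fit <- lm(Y ~ x, data = data.frame(Y = Y, x = X[[v]]))
    anova(fit)[["Pr(>F)"]][1] # F-test p-value for the covariate (also handles factors)
  }, numeric(1))
  # which() drops NA p-values (e.g., constant covariates)
  filter.vars <- names(which(pvals < pval.cut))
  return(list(mod = pvals, filter.vars = filter.vars))
}

# Usage on small simulated data (separate names so Y, A, X above are untouched)
set.seed(1)
sim_X <- data.frame(X1 = rnorm(200), X2 = rnorm(200), X3 = rnorm(200))
sim_A <- rbinom(200, 1, 0.5)
sim_Y <- 2 * sim_X$X1 + rnorm(200) # only X1 is related to the outcome
filter_marginal(sim_Y, sim_A, sim_X)$filter.vars
```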
The template ple function is:
```r
ple_template <- function(Y, A, X, Xtest, ...){
  # Step 1: Fit PLE Model #
  # for example: Estimate E(Y|A=1,X), E(Y|A=0,X), E(Y|A=1,X)-E(Y|A=0,X)
  mod <- NULL # replace with your ple model call
  # mod = list(mod0=mod0, mod1=mod1) # If multiple fitted models, combine into list
  # Step 2: Predictions
  # Option 1: Create a Prediction Function #
  pred.fun <- function(mod, X, ...){
    mu_hat <- NULL # replace with a data frame of predictions
    return(mu_hat)
  }
  # Option 2: Directly Output Predictions (here, we still use pred.fun) #
  mu_train <- pred.fun(mod, X)
  mu_test <- pred.fun(mod, Xtest)
  # Return model fits and pred.fun (or just mu_train/mu_test) #
  res <- list(mod=mod, pred.fun=pred.fun, mu_train=mu_train, mu_test=mu_test)
  return( res )
}
```
For “ple_train”, the only required arguments are the observed data (Y, X). The only required outputs are mod (the fitted model(s)) and either a prediction function or pre-computed predictions for the training/test sets (mu_train, mu_test). If training/test predictions are provided, they are used instead of the prediction function. However, certain features of “StratifiedMedicine”, such as partial dependence plots (“plot_dependence”), cannot be used without a prediction function. In the example below, we set up a simple wrapper for random forest predictions. This base-learner can then be combined with the default meta-learners (meta=“T-learner”, “X-learner”, “S-learner”) to obtain predictions for specific exposures or treatment differences.
```r
ple_ranger_mtry = function(Y, X, mtry=5, ...){
  require(ranger)
  train = data.frame(Y=Y, X)
  # mtry cannot exceed the number of covariates
  mod <- ranger(Y ~ ., data = train, seed=1, mtry = min(mtry, ncol(X)))
  mod = list(mod=mod)
  pred.fun <- function(mod, X, ...){
    mu_hat <- predict(mod$mod, X)$predictions
    return(mu_hat)
  }
  res = list(mod=mod, pred.fun=pred.fun)
  return(res)
}
```
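To see how a base-learner in this format can be combined with a meta-learner, the sketch below defines a hypothetical base-R analogue of ple_ranger_mtry, `ple_lm`, and a hypothetical `t_learner_sketch` that applies the T-learner idea by hand: fit the base-learner separately within each treatment arm, predict for everyone, and subtract the arm-specific predictions. It assumes a 0/1 treatment coding and is a simplified illustration of the idea behind meta=“T-learner”, not the package’s internal implementation.

```r
# Hypothetical base-learner (not part of StratifiedMedicine): same (Y, X, ...)
# signature and list(mod, pred.fun) output as ple_ranger_mtry, but using lm().
ple_lm <- function(Y, X, ...) {
  train <- data.frame(Y = Y, X)
  mod <- list(mod = lm(Y ~ ., data = train))
  pred.fun <- function(mod, X, ...) {
    as.numeric(predict(mod$mod, newdata = data.frame(X)))
  }
  list(mod = mod, pred.fun = pred.fun)
}

# Hypothetical T-learner sketch: one base-learner fit per arm (A coded 0/1),
# then the difference of the two arms' predictions for every patient.
t_learner_sketch <- function(Y, A, X, learner = ple_lm) {
  fit1 <- learner(Y[A == 1], X[A == 1, , drop = FALSE])
  fit0 <- learner(Y[A == 0], X[A == 0, , drop = FALSE])
  mu1 <- fit1$pred.fun(fit1$mod, X)
  mu0 <- fit0$pred.fun(fit0$mod, X)
  data.frame(mu1 = mu1, mu0 = mu0, diff = mu1 - mu0)
}

# Usage on simulated data where the true treatment difference is 1 + X1
set.seed(2)
n_sim <- 400
sim_X <- data.frame(X1 = rnorm(n_sim), X2 = rnorm(n_sim))
sim_A <- rbinom(n_sim, 1, 0.5)
sim_Y <- sim_X$X2 + sim_A * (1 + sim_X$X1) + rnorm(n_sim, sd = 0.5)
ple_fit <- t_learner_sketch(sim_Y, sim_A, sim_X)
head(ple_fit)
```

Because each arm gets its own model, the T-learner is simple but can be noisy when one arm is small; that trade-off is part of what the other meta-learners aim to address.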
The template submod function is:
```r
submod_template <- function(Y, A, X, Xtest, mu_train, ...){
  # Step 1: Fit subgroup model #
  mod <- NULL # replace with your model call
  # Step 2: Predictions #
  # Option 1: Create Prediction Function #
  pred.fun <- function(mod, X=NULL, ...){
    Subgrps <- NULL # replace with code that predicts subgroup assignment
    return( list(Subgrps=Subgrps) )
  }
  # Option 2: Output Subgroups for train/test (here we use pred.fun)
  Subgrps.train = pred.fun(mod, X)
  Subgrps.test = pred.fun(mod, Xtest)
  # Return fit and pred.fun (or just Subgrps.train/Subgrps.test)
  res <- list(mod=mod, pred.fun=pred.fun, Subgrps.train=Subgrps.train,
              Subgrps.test=Subgrps.test)
  return(res)
}
```
For the “submod” model, the only required arguments are the observed data (Y, A, X); “mu_train” (based on “ple_train” predictions) can also be passed through. The only required outputs are mod (the fitted model(s)) and either a prediction function or pre-computed subgroup assignments for the training/test sets (Subgrps.train, Subgrps.test). In the example below, consider a modified version of “submod_lmtree” that searches for predictive effects only; by default, “submod_lmtree” searches for prognostic and/or predictive effects.
```r
submod_lmtree_pred = function(Y, A, X, mu_train, ...){
  require(partykit)
  ## Fit Model ##
  # parm=2: only the treatment (A) coefficient enters the instability tests,
  # so splits target treatment interactions (predictive effects)
  mod <- lmtree(Y~A | ., data = X, parm=2)
  pred.fun <- function(mod, X=NULL, type="subgrp"){
    Subgrps <- as.numeric( predict(mod, type="node", newdata = X) )
    return( list(Subgrps=Subgrps) )
  }
  ## Return Results ##
  return(list(mod=mod, pred.fun=pred.fun))
}
```
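For comparison, the hypothetical `submod_median_split` below follows the same submod contract without partykit. It is a deliberately crude sketch, not a substitute for lmtree: among numeric covariates, it picks the one with the smallest p-value for the treatment-by-covariate interaction in `lm(Y ~ A * x)`, and splits at that covariate’s median into subgroups 1 and 2. It assumes A is coded numerically as 0/1, so that lm() names the interaction coefficient “A:x”.

```r
# Hypothetical subgroup model (not part of StratifiedMedicine): one median split
# on the numeric covariate with the strongest treatment-by-covariate interaction.
# Assumes A is numeric 0/1, so lm() names the interaction coefficient "A:x".
submod_median_split <- function(Y, A, X, mu_train, ...) {
  num.vars <- names(X)[vapply(X, is.numeric, logical(1))]
  pvals <- vapply(num.vars, function(v) {
    fit <- lm(Y ~ A * x, data = data.frame(Y = Y, A = A, x = X[[v]]))
    summary(fit)$coefficients["A:x", "Pr(>|t|)"]
  }, numeric(1))
  split.var <- names(which.min(pvals))
  mod <- list(split.var = split.var, cut = median(X[[split.var]]), pvals = pvals)
  # Prediction function: subgroup 1 at or below the median, subgroup 2 above it
  pred.fun <- function(mod, X = NULL, ...) {
    Subgrps <- ifelse(X[[mod$split.var]] <= mod$cut, 1, 2)
    return(list(Subgrps = Subgrps))
  }
  return(list(mod = mod, pred.fun = pred.fun))
}

# Usage on simulated data where only patients with X2 > 0 benefit
set.seed(3)
n_sim <- 400
sim_X <- data.frame(X1 = rnorm(n_sim), X2 = rnorm(n_sim), X3 = rnorm(n_sim))
sim_A <- rbinom(n_sim, 1, 0.5)
sim_Y <- 2 * sim_A * (sim_X$X2 > 0) + rnorm(n_sim)
fit <- submod_median_split(sim_Y, sim_A, sim_X)
fit$mod$split.var
table(fit$pred.fun(fit$mod, sim_X)$Subgrps)
```

Unlike lmtree, this sketch makes a single split and performs no multiplicity adjustment across the candidate covariates.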
The template param function is:
```r
param_template <- function(Y, A, X, mu_hat, alpha,...){
  # Key Outputs: Subgroup specific and overall parameter estimates
  mod <- NULL # replace with your parameter model call
  # Extract estimates/variability (n, est, SE, LCL, UCL, pval) from mod and combine #
  param.dat <- data.frame(n=n, estimand="mu_1-mu_0",
                          est=est, SE=SE, LCL=LCL, UCL=UCL, pval=pval)
  return(param.dat)
}
```
For the parameter model, the key arguments are (Y, A) (observed outcome/treatment) and alpha (nominal type I error for the CIs). Other inputs can include mu_hat (ple predictions) and the covariate space (X), if needed for parameter estimation. The only required output is “param.dat”, which contains parameter estimates and variability metrics. For all PRISM functionality to work, param.dat should contain columns named “est” (parameter estimate), “SE” (standard error), and “LCL”/“UCL” (lower and upper confidence limits). Including an “estimand” column for labeling purposes is recommended. In the example below, robust linear regression (M-estimation) models are fit for each subgroup and overall.
```r
### Robust linear Regression: E(Y|A=1) - E(Y|A=0) ###
param_rlm = function(Y, A, alpha, ...){
  require(MASS)
  indata = data.frame(Y=Y,A=A)
  n = dim(indata)[1]
  rlm.mod = tryCatch( rlm(Y ~ A , data=indata),
                      error = function(e) NULL )
  # If the fit fails, return missing estimates instead of erroring in summary()
  if (is.null(rlm.mod)) {
    return(data.frame(N=n, estimand="mu_1-mu_0", est=NA, SE=NA,
                      LCL=NA, UCL=NA, pval=NA))
  }
  est = summary(rlm.mod)$coefficients[2,1]
  SE = summary(rlm.mod)$coefficients[2,2]
  LCL = est-qt(1-alpha/2, n-1)*SE
  UCL = est+qt(1-alpha/2, n-1)*SE
  pval = 2*pt(-abs(est/SE), df=n-1)
  param.dat <- data.frame(N= n, estimand = "mu_1-mu_0",
                          est=est, SE=SE, LCL=LCL, UCL=UCL, pval=pval)
  return(param.dat)
}
```
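As a base-R counterpart to param_rlm, the hypothetical `param_ttest` below estimates E(Y|A=1) - E(Y|A=0) with a Welch two-sample t-test (a different technique from the M-estimation above) and returns the columns PRISM needs, assuming A is coded 0/1. The hypothetical `param_by_subgroup` then sketches the “each subgroup and overall” pattern by applying a param function to the full sample and to every subgroup; the “overall” row label is an assumption, and how PRISM organizes this internally is not shown here.

```r
# Hypothetical param function (not part of StratifiedMedicine):
# Welch t-test for E(Y|A=1) - E(Y|A=0), assuming A is coded 0/1.
param_ttest <- function(Y, A, alpha = 0.05, ...) {
  n <- length(Y)
  tt <- tryCatch(t.test(Y[A == 1], Y[A == 0], conf.level = 1 - alpha),
                 error = function(e) NULL)
  if (is.null(tt)) { # e.g., an arm with fewer than 2 observations
    return(data.frame(N = n, estimand = "mu_1-mu_0", est = NA, SE = NA,
                      LCL = NA, UCL = NA, pval = NA))
  }
  data.frame(N = n, estimand = "mu_1-mu_0",
             est = unname(tt$estimate[1] - tt$estimate[2]),
             SE = unname(tt$stderr),
             LCL = tt$conf.int[1], UCL = tt$conf.int[2], pval = tt$p.value)
}

# Hypothetical helper: apply a param function overall and within each subgroup
param_by_subgroup <- function(Y, A, Subgrps, alpha = 0.05, param.fun = param_ttest) {
  overall <- cbind(Subgrps = "overall", param.fun(Y, A, alpha))
  by.grp <- lapply(sort(unique(Subgrps)), function(s) {
    keep <- Subgrps == s
    cbind(Subgrps = as.character(s), param.fun(Y[keep], A[keep], alpha))
  })
  do.call(rbind, c(list(overall), by.grp))
}

# Usage on simulated data with a treatment effect only in subgroup 2
set.seed(4)
n_sim <- 300
sim_A <- rbinom(n_sim, 1, 0.5)
sim_S <- sample(1:2, n_sim, replace = TRUE)
sim_Y <- 1.5 * sim_A * (sim_S == 2) + rnorm(n_sim)
param_by_subgroup(sim_Y, sim_A, sim_S)
```

Any function with the same (Y, A, alpha, ...) signature, such as param_rlm, could be passed as `param.fun` instead.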
Finally, let’s plug these user-specified functions into each step individually, and then combine all of the components through PRISM. Note that meta=“X-learner” is used to obtain the patient-level treatment effect estimates.
```r
# Individual Steps #
step1 <- filter_train(Y, A, X, filter="filter_lasso")
X.star <- X[, colnames(X) %in% step1$filter.vars, drop=FALSE]
step2 <- ple_train(Y, A, X.star, ple = "ple_ranger_mtry", meta="X-learner")
plot_ple(step2)
step3 <- submod_train(Y, A, X.star, submod = "submod_lmtree_pred", param="param_rlm")
plot_tree(step3)
# Combine all through PRISM #
res_user1 = PRISM(Y=Y, A=A, X=X, family="gaussian", filter="filter_lasso",
                  ple = "ple_ranger_mtry", meta="X-learner",
                  submod = "submod_lmtree_pred",
                  param="param_rlm")
res_user1$filter.vars
#> [1] "X1" "X2" "X3" "X5" "X7" "X8" "X10" "X12" "X16" "X18" "X24" "X26"
#> [13] "X31" "X40" "X46" "X50"
plot(res_user1, type="PLE:waterfall")
```