## ----setup, include = FALSE---------------------------------------------------
# Vignette-wide knitr chunk defaults: collapse source and output into a single
# block, prefix printed output with "#>", and set `fig.dim` to c(6, 4).
do.call(
  knitr::opts_chunk$set,
  list(collapse = TRUE, comment = "#>", fig.dim = c(6, 4))
)

suppressPackageStartupMessages({
    library(BART)
    library(tidytreatment)
    library(dplyr)
    library(tidybayes)
    library(ggplot2)
  })
  
  # Load the pre-computed simulated data and fitted model so the vignette does
  # not have to re-run the (commented-out) simulation and BART fitting chunks.
  sim <- suhillsim1
  te_model <- bartmodel1

  # Pre-compute the posterior treatment effects used by the plotting chunks:
  # once for every observation and once restricted to the treated subset.
  # `sim$data` is spelled out in full in both calls; the abbreviated `sim$dat`
  # only resolves through `$` partial matching on lists.
  # NOTE(review): the displayed `post-treatment` chunk passes scale = "prob"
  # but these calls do not -- confirm the default scale gives the same result.
  posterior_treat_eff <-
    treatment_effects(te_model, treatment = "z", newdata = sim$data)
  posterior_treat_eff_on_treated <-
    treatment_effects(te_model, treatment = "z", newdata = sim$data,
                      subset = "treated")
  

## ----load-data-print, echo = TRUE, eval = FALSE-------------------------------
# 
# # load packages
# library(BART)
# library(tidytreatment)
# library(dplyr)
# library(tidybayes)
# library(ggplot2)
# 
# # set seed so vignette is reproducible
# set.seed(101)
# 
# # simulate data
# sim <- simulate_su_hill_data(n = 100, treatment_linear = FALSE,  omega = 0, add_categorical = TRUE,
#                              coef_categorical_treatment = c(0,0,1),
#                              coef_categorical_nontreatment = c(-1,0,-1)
# )
# 
# # the outcome is binary in this example: dichotomise y at 2.7
# 
# sim$data$y <- as.integer(sim$data$y > 2.7)
# 

## ----data-summary, echo = TRUE, eval = TRUE-----------------------------------

# Number of non-treated vs treated subjects
table(sim$data$z)

# Working copy of the simulated data, used by the remaining chunks
dat <- sim$data

# First rows of a selection of columns: y, z, c1 and covariates x1 to x3
head(select(dat, y, z, c1, x1:x3))


## ----run-bart, echo = TRUE, eval = FALSE--------------------------------------
# 
# # STEP 1 VS Model: Regress y ~ covariates
# var_select_bart <- pbart(x.train = select(dat,-y,-z),
#                          y.train = pull(dat, y),
#                          sparse = TRUE,
#                          nskip = 2000,
#                          ndpost = 5000)
# 
# # STEP 2: Variable selection
#   # Select most important vars from y ~ covariates model
#   # Note: This is an overly simple selection mechanism.
#   # See package {bartMan} vignettes and https://doi.org/10.52933/jdssv.v4i1.79
#   # for discussion of variable importance methods.
# covar_ranking <- covariate_importance(var_select_bart)
# var_select <- covar_ranking %>%
#   filter(avg_inclusion >= quantile(avg_inclusion, 0.5)) %>%
#   pull(variable)
# 
# # change categorical variables to just one variable
# var_select <- unique(gsub("c1[1-3]$","c1", var_select))
# 
# var_select
# 
# # STEP 3 PS Model: Regress z ~ selected covariates
#   # BART::pbart is for probit regression
# prop_bart <- pbart(
#   x.train = select(dat, all_of(var_select)),
#   y.train = pull(dat, z),
#   nskip = 2000,
#   ndpost = 5000
# )
# 
# # store propensity score in data
# dat$prop_score <-  prop_bart$prob.train.mean
# 
# # STEP 4 TE Model: Regress y ~ z + covariates + propensity score
# te_model <- pbart(
#   x.train = select(dat,-y),
#   y.train = pull(dat, y),
#   nskip = 10000L,
#   ndpost = 200L, #*
#   keepevery = 100L #*
# )
# 
# #* The posterior samples are kept small to manage size on CRAN
# 

## ----tidy-bart-fit, echo=TRUE, cache=FALSE------------------------------------

# Tidy posterior draws of the fitted values, stored in column `fit` and
# requested on the probability scale.
posterior_fitted <- epred_draws(
  te_model,
  value = "fit",
  scale = "prob",
  include_newdata = FALSE
)
# include_newdata = FALSE avoids returning the newdata alongside the fitted
# values, as it is so large. The `.row` variable still tells us which row of
# the newdata each fitted value came from (if we don't include the data in
# the result).
# NOTE(review): the original comment states that the newdata argument must be
# specified for this option in BART models, yet none is passed here -- confirm.

posterior_fitted


## ----tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE-----------------------
# 
# # Function to tidy predicted draws as well; note that this adds random normal noise by default
# posterior_pred <- predicted_draws(te_model, include_newdata = FALSE)
# 

## ----plot-tidy-bart, echo=TRUE, cache=FALSE-----------------------------------

# Per-observation lookup of the treatment indicator `z` (converted to a
# factor) and the covariate `c1`, keyed by row index so that it can be joined
# to the posterior draws through `.row`.
treatment_var_and_c1 <-
  dat %>%
  select(z, c1) %>%
  # seq_len(n()) rather than 1:n(): 1:0 evaluates to c(1, 0) on empty data
  mutate(.row = seq_len(n()), z = as.factor(z))

# Half-eye plot of the posterior fitted values by treatment, with one facet
# per level of `c1`.
posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() +
  stat_halfeye(aes(x = z, y = fit)) +
  facet_wrap(~c1, labeller = as_labeller(function(x) paste("c1 =", x))) +
  xlab("Treatment (z)") +
  ylab("Posterior predicted value") +
  theme_bw() +
  ggtitle("Effect of treatment with 'c1' on posterior fitted values")


## ----post-treatment, eval = FALSE---------------------------------------------
# 
# # sample based (using data from fit) conditional treatment effects, posterior draws
# posterior_treat_eff <-
#   treatment_effects(te_model, treatment = "z", scale = "prob", newdata = dat)
# 

## ----cates-hist, echo=TRUE, cache=FALSE---------------------------------------

# Distribution of the treatment effect `cte` across all posterior draws
ggplot(posterior_treat_eff) +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  ggtitle("Histogram of treatment effect (all draws)") +
  theme_bw()

# Distribution of the median treatment effect per subject.
# summarise() is called without an explicit group_by(), so this relies on the
# grouping already carried by `posterior_treat_eff`.
ggplot(summarise(posterior_treat_eff, cte_hat = median(cte))) +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") +
  ggtitle("Histogram of treatment effect (median for each subject)") +
  theme_bw()


## ----att-ate, eval=FALSE------------------------------------------------------
# # get the ATE and ATT directly:
# 
# posterior_ate <- tidy_ate(te_model, treatment = "z", scale = "prob", newdata = dat)
# posterior_att <- tidy_att(te_model, treatment = "z", scale = "prob", newdata = dat)
# 

## ----ate-trace-setup, eval = TRUE, echo = FALSE-------------------------------

# Posterior draws of the ATE: the mean of `cte` within each
# (.chain, .iteration, .draw) combination, returned ungrouped.
posterior_ate <- summarise(
  group_by(posterior_treat_eff, .chain, .iteration, .draw),
  ate = mean(cte),
  .groups = "drop"
)


## ----ate-trace, eval=TRUE, echo=TRUE------------------------------------------

# Line plot of the ATE against the draw index
ggplot(posterior_ate, aes(x = .draw, y = ate)) +
  geom_line() +
  ggtitle("Trace plot of ATE") +
  theme_bw()


## ----post-te-treated, echo=TRUE, eval=FALSE-----------------------------------
# 
# # sample based (using data from fit) conditional treatment effects, posterior draws
# posterior_treat_eff_on_treated <-
#   treatment_effects(te_model, treatment = "z", newdata = dat, subset = "treated")
# 

## ----cates-hist-treated, echo=TRUE, cache=FALSE-------------------------------

# Distribution of the treatment effect `cte` over all draws, treated subset only
ggplot(posterior_treat_eff_on_treated) +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  ggtitle("Histogram of treatment effect (all draws from treated subjects)") +
  theme_bw()


## ----cates-stack-plot, echo=TRUE, cache=FALSE---------------------------------

# Stacked point-and-interval plot of the treatment effect: `z` is dropped,
# the draws are summarised with point_interval() (default settings), and the
# summaries are ordered by `cte` so the intervals stack from smallest to
# largest. The y-axis text is removed because `.orow` is only a plotting
# position.
posterior_treat_eff %>%
  select(-z) %>%
  point_interval() %>%
  arrange(cte) %>%
  # seq_len(n()) rather than 1:n(): 1:0 evaluates to c(1, 0) on empty data
  mutate(.orow = seq_len(n())) %>%
  ggplot() +
  geom_interval(aes(x = .orow, y = cte, ymin = .lower, ymax = .upper)) +
  geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.5) +
  ylab("Median posterior CATE for each subject (95% CI)") +
  theme_bw() +
  coord_flip() +
  scale_colour_brewer() +
  theme(axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        legend.position = "none")


## ----cates-line-plot, echo=TRUE, cache=FALSE----------------------------------

# Half-eye plot of the treatment effect for each level of `c1`.
# `c1` is attached to the posterior draws by row index (`.row`).
posterior_treat_eff %>%
  # seq_along() rather than 1:length(): safe when dat$c1 has length zero
  left_join(tibble(c1 = dat$c1, .row = seq_along(dat$c1)), by = ".row") %>%
  group_by(c1) %>%
  ggplot() +
  stat_halfeye(aes(x = c1, y = cte), alpha = 0.7) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Treatment effect by `c1`")



## ----common-support, echo=TRUE, results='hide', cache=FALSE-------------------

# NOTE: common support is calculated on the linear scale (not the probability
# scale for probit/logit)

# Compute common support directly with the "chisq" method and a 0.05 cutoff.
# The 'modeldata' argument must be specified for BART models.
csupp_chisq <- has_common_support(
  te_model,
  treatment = "z",
  modeldata = dat,
  method = "chisq",
  cutoff = 0.05
)

# Rows flagged as lacking common support under the "chisq" method
filter(csupp_chisq, !common_support)

# Same calculation with the "sd" method
csupp_sd <- has_common_support(
  te_model,
  treatment = "z",
  modeldata = dat,
  method = "sd"
)

# Rows flagged as lacking common support under the "sd" method
filter(csupp_sd, !common_support)

# Treatment effects on those who were treated, keeping only the estimates
# with common support (method "sd")
posterior_treat_eff_on_treated <- treatment_effects(
  te_model,
  treatment = "z",
  subset = "treated",
  newdata = dat,
  common_support_method = "sd"
)


## ----interaction-investigator, echo=TRUE, cache=FALSE-------------------------

  # Importance of covariates in combination with the treatment variable 'z'
  treatment_interactions <- covariate_with_treatment_importance(
    te_model,
    treatment = "z"
  )

  # Bar chart of `avg_inclusion` per variable (treatment interactions)
  ggplot(treatment_interactions) +
    geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
    ggtitle("Important variables interacting with treatment ('z')") +
    ylab("Inclusion counts") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

  # Importance of every covariate in the model
  variable_importance <- covariate_importance(te_model)

  # Bar chart of `avg_inclusion` per variable (overall)
  ggplot(variable_importance) +
    geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
    ggtitle("Important variables overall") +
    ylab("Inclusion counts") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
  

## ----sigma-trace, echo=TRUE, cache=FALSE--------------------------------------

# The variance draws include the skipped MCMC samples, so only draws beyond
# 10000 are plotted (the commented-out model fit uses nskip = 10000L).
# NOTE(review): "siqsq" looks like a typo for "sigsq"; it is used consistently
# as the output column name here, so it is left unchanged -- confirm.
te_model %>%
  variance_draws(value = "siqsq") %>%
  filter(.draw > 10000) %>%
  ggplot(aes(x = .draw, y = siqsq)) +
  geom_line() +
  ggtitle("Trace plot of model variance post warm-up") +
  theme_bw()


## ----convergence-bart, echo=TRUE, cache=FALSE---------------------------------

# Posterior residual draws against the observed response `y`, without the
# newdata attached to the result.
res <- residual_draws(
  te_model,
  response = pull(dat, y),
  include_newdata = FALSE
)

# Residuals vs observations: point and 0.95-width interval of `.residual` for
# each observation. The interval columns for `y` are dropped as they are not
# plotted.
res %>%
  point_interval(.residual, y, .width = 0.95) %>%
  select(-y.lower, -y.upper) %>%
  ggplot() +
  geom_pointinterval(
    aes(
      x = y,
      y = .residual,
      ymin = .residual.lower,
      ymax = .residual.upper
    ),
    alpha = 0.2
  ) +
  scale_fill_brewer() +
  ggtitle("Residuals vs observations") +
  theme_bw()

# Observations vs fitted: mean fitted value against the observed `y`, with a
# linear smooth. summarise() relies on the grouping already carried by `res`.
ggplot(
  summarise(res, .fitted = mean(.fitted), y = first(y)),
  aes(x = y, y = .fitted)
) +
  geom_point() +
  geom_smooth(method = "lm") +
  ggtitle("Observations vs fitted") +
  theme_bw()

# Q-Q plot of the mean residuals (again using the grouping carried by `res`)
ggplot(summarise(res, .residual = mean(.residual)), aes(sample = .residual)) +
  geom_qq() +
  geom_qq_line() +
  ggtitle("Q-Q plot of residuals") +
  theme_bw()


