#' @title Function to calculate the two-sample generalized log-rank statistic for composite endpoint under sequential monitoring.
#' @description
#' Computes a two-sample generalized log-rank test statistic for composite endpoints consisting of recurrent events and
#' a terminal event, using data observed up to a given calendar time. Event times are converted from the calendar-time
#' scale to the event-time scale (time since enrollment). The statistic integrates the difference between the estimated
#' mean frequency increments of the two groups against a risk-set based weight function.
#' @details
#' \code{data} must contain the columns \code{id}, \code{group} (coded \code{1} and \code{2}), \code{status},
#' \code{event_time_cal}, \code{e}, \code{event} and \code{death}. Rows with a missing \code{status} are dropped.
#' Within each group, all records of a subject must be adjacent, and each subject must have exactly one record with
#' \code{status} equal to \code{0} or \code{1} (the subject's last observation: death or censoring); otherwise an
#' error is raised.
#'
#' The weight function is
#' \eqn{K(t) = (n_1 + n_2)/(n_1 n_2) \cdot \bar{Y}_1(t)\bar{Y}_2(t)/(\bar{Y}_1(t) + \bar{Y}_2(t))},
#' where \eqn{\bar{Y}_a(t)} is the number of subjects in group \eqn{a} whose last observation time is
#' \eqn{\ge t}. \eqn{K(t)} is set to zero wherever neither group has anyone at risk.
#' @param data A data frame generated by \code{TwoSample.generate.sequential()} (optionally after applying \code{Apply.calendar.censoring.2()})
#' containing simulated two-sample composite endpoint data.
#' @param tau Positive numeric value specifying the upper bound of event time for the integration. Default is \code{3}.
#' Each group must have at least one observed time \code{<= tau}.
#'
#' @returns A list with components:
#' \itemize{
#' \item \code{Q}: Value of the generalized log-rank statistic integrated over \code{[0, tau]}.
#' \item \code{var}: Estimated asymptotic variance of \code{Q}.
#' \item \code{const}: Scaling constant \eqn{1/\sqrt{n_1 n_2/(n_1 + n_2)}} used in the variance estimation.
#' }
#' @export
#' @importFrom dplyr %>% group_by filter mutate count select slice ungroup
#' @importFrom tibble as_tibble
#' @importFrom bdsmatrix bdsBlock
#' @importFrom rlang .data
#'
#' @examples
#' # Two-sample generalized log-rank test: null hypothesis
#' df <- TwoSample.generate.sequential(sizevec = c(200, 200),
#' beta.trt = 0, calendar = 5, recruitment = 3,
#' random.censor.rate = 0.05, seed = 2026)
#' TwoSample.Estimator.LR.sequential(data = df, tau = 3)
#' # Two-sample generalized log-rank test: alternative hypothesis
#' df2 <- TwoSample.generate.sequential(sizevec = c(200, 200),
#' beta.trt = 0.8, calendar = 5, recruitment = 3,
#' random.censor.rate = 0.05, seed = 2026)
#' TwoSample.Estimator.LR.sequential(data = df2, tau = 3)
TwoSample.Estimator.LR.sequential <- function(data, tau = 3){

  if (!is.numeric(tau) || length(tau) != 1L || is.na(tau) || tau <= 0) {
    stop("`tau` must be a single positive number.", call. = FALSE)
  }

  # Drop rows without a status and move every time from the calendar scale to
  # the event-time scale (time since enrollment `e`). as.data.frame() strips
  # tibble / grouped_df classes so the base subsetting below is uniform.
  original.data <- as.data.frame(data)
  original.data <- original.data[!is.na(original.data$status), , drop = FALSE]
  original.data$true_event_time <- original.data$event_time_cal - original.data$e

  # Pooled, sorted grid of all event-scale times (recurrent, death, censoring).
  sorted.all.time <- sort(original.data$true_event_time)
  n.grid <- length(sorted.all.time)

  ns <- c(NA, NA)                                # group sizes (number of subjects)
  Ybars <- matrix(NA_real_, 2, n.grid)           # at-risk counts on the pooled grid
  dmuhats <- matrix(NA_real_, 2, n.grid)         # mean-frequency increments on the pooled grid
  dPsihats <- vector(mode = "list", length = 2)  # influence increments (subject x own time <= tau)
  time.idx <- vector(mode = "list", length = 2)  # pooled-grid position of each own-grid time
  truncate.idxs <- c(NA, NA)                     # last own-grid index with time <= tau

  # Row-wise cumulative sum that keeps the subject-by-time matrix shape
  # (apply() would collapse a single-column input to a vector).
  row.cumsum <- function(M) {
    if (ncol(M) <= 1L) {
      return(M)
    }
    t(apply(M, 1, cumsum))
  }

  for (a in 1:2) {
    dat <- original.data[which(original.data$group == a), , drop = FALSE]

    # Subjects are indexed by contiguous runs of `id`; all per-subject
    # matrices below rely on each subject's records being adjacent.
    id.runs <- rle(as.vector(dat$id))
    if (anyDuplicated(id.runs$values)) {
      stop(sprintf("Records of each subject in group %d must be contiguous.", a),
           call. = FALSE)
    }
    n <- length(id.runs$values)
    if (n == 0L) {
      stop(sprintf("Group %d has no subjects.", a), call. = FALSE)
    }
    ns[a] <- n
    subject <- rep(seq_len(n), id.runs$lengths)

    # All records of the group sorted by event-scale time.
    ord <- order(dat$true_event_time)
    sorted.time <- dat$true_event_time[ord]
    sorted.event <- dat$event[ord]   # dN(t): recurrent or death indicator
    sorted.death <- dat$death[ord]
    sorted.subject <- subject[ord]
    L <- length(sorted.time)

    # Where each of this group's times sits on the pooled grid.
    time.idx[[a]] <- match(sorted.time, sorted.all.time)

    # Status 0 or 1 marks each subject's last observation (death or censoring);
    # these vectors are in subject (run) order.
    is.last <- dat$status == 1 | dat$status == 0
    last.time <- dat$true_event_time[is.last]
    delta <- dat$death[is.last]
    last.id <- dat$id[is.last]
    if (length(last.id) != n || anyDuplicated(last.id)) {
      stop(sprintf(paste("Each subject in group %d must have exactly one record",
                         "with status 0 or 1."), a), call. = FALSE)
    }

    # At-risk indicators Y[i, j] = I(last.time_i >= sorted.time_j).
    Y <- 1 * outer(last.time, sorted.time, ">=")
    Ybar <- colSums(Y)
    pihat <- Ybar / n

    # Kaplan-Meier estimate for death at each sorted time.
    # NOTE(review): this is S(t) (it includes the drop at t), so a death's own
    # jump in dmuhat is weighted by post-death survival; confirm whether S(t-)
    # was intended for the composite endpoint.
    KMhat <- cumprod(1 - sorted.death / Ybar)
    # Nelson-Aalen increments for all dN(t) = 1 (recurrent and death).
    dRhat <- sorted.event / Ybar

    # Mean frequency estimator, then its increments on the pooled grid
    # (mu is right-continuous, so a right-continuous step function is correct).
    dmuhat <- KMhat * dRhat
    muhat <- cumsum(dmuhat)
    imuhat <- stats::stepfun(sorted.time, c(0, muhat))
    dmuhats[a, ] <- diff(c(0, imuhat(sorted.all.time)))

    # Number at risk on the pooled grid: Y(t) = #{last.time >= t}, which is
    # left-continuous and drops to 0 after the group's last observation.
    # A right-continuous extension would keep counting subjects who already
    # left at the previous own-grid time.
    Ybars[a, ] <- n - findInterval(sorted.all.time, sort(last.time), left.open = TRUE)

    ##### Variance estimator: per-subject influence increments (subject x time) #####

    # Death counting process, its increments, and the death martingale.
    ND <- 1 * outer(last.time, sorted.time, "<=") * delta
    dND <- ND - cbind(0, ND[, -L, drop = FALSE])
    dlambdaDhat <- colSums(dND / rep(Ybar, each = n))
    dMDhat <- dND - Y * rep(dlambdaDhat, each = n)

    # dN[i, j] = 1 when sorted record j belongs to subject i and is an event
    # (time * event equals the record's time), mirroring the original masking.
    owner <- outer(seq_len(n), sorted.subject, "==")
    dN <- 1 * ((owner * rep(sorted.time * sorted.event, each = n)) ==
                 rep(sorted.time, each = n))
    dMhat <- dN - Y * rep(dRhat, each = n)

    dpartI <- dMhat * rep(KMhat, each = n) / rep(pihat, each = n)
    dpartII <- dMDhat * rep(muhat, each = n) / rep(pihat, each = n)
    dpartIII <- row.cumsum(dMDhat / rep(pihat, each = n)) * rep(dmuhat, each = n)
    # dpartIV is evaluated by the same expression as dpartII, so the two cancel
    # in dPsihat; kept to mirror the term-by-term decomposition.
    # NOTE(review): confirm this matches the intended influence function.
    dpartIV <- dpartII

    dPsihat <- dpartI - dpartII - dpartIII + dpartIV

    # Integrate from zero to tau on this group's own grid.
    keep <- which(sorted.time <= tau)
    if (length(keep) == 0L) {
      stop(sprintf("No observed times in group %d are <= tau.", a), call. = FALSE)
    }
    truncate.idxs[a] <- max(keep)
    dPsihats[[a]] <- dPsihat[, seq_len(truncate.idxs[a]), drop = FALSE]
  }

  # Weight function, set to zero wherever neither group has anyone at risk
  # (avoids 0/0 elementwise).
  Ysum <- Ybars[1, ] + Ybars[2, ]
  at.risk <- Ysum > 0
  Khat <- numeric(n.grid)
  Khat[at.risk] <- (sum(ns)/prod(ns))*Ybars[1, at.risk]*Ybars[2, at.risk]/Ysum[at.risk]

  # Integrate from 0 to tau on the pooled grid (non-empty: each group has a
  # time <= tau, checked above).
  idx <- seq_len(max(which(sorted.all.time <= tau)))
  Q <- sum(Khat[idx]*(dmuhats[1, idx] - dmuhats[2, idx]))

  # Per-subject weighted influence sum_t K(t) dPsi_i(t).
  weighted.influence <- function(dPsi, K) {
    rowSums(dPsi * rep(K, each = nrow(dPsi)))
  }

  # Asymptotic variance
  Khat.grp1 <- Khat[time.idx[[1]]][seq_len(truncate.idxs[1])]
  var.partI <- sum(weighted.influence(dPsihats[[1]], Khat.grp1)^2)*ns[2]/(sum(ns)*ns[1])

  Khat.grp2 <- Khat[time.idx[[2]]][seq_len(truncate.idxs[2])]
  var.partII <- sum(weighted.influence(dPsihats[[2]], Khat.grp2)^2)*ns[1]/(sum(ns)*ns[2])

  const <- 1/sqrt(ns[1]*ns[2]/sum(ns))
  var <- (var.partI + var.partII)*const^2

  list(Q = Q, var = var, const = const)
}
