# main functions ==========

#' Stack a ggplot
#'
#' Use `ggstackplot()` to generate a stackplot. If you need more fine control, use `prepare_stackplot()` and `assemble_stackplot()` individually. To explore examples of all the different features, check out the `vignette("explore", "ggstackplot")` or the [online documentation](https://ggstackplot.kopflab.org/articles/explore.html).
#'
#' @details
#' `ggstackplot()` stacks a ggplot template with the provided data and parameters. It returns a plot object generated by [cowplot::plot_grid()].
#'
#' @param data the data frame to plot
#' @param x the x variable(s) to plot, accepts [dplyr::select()] syntax. The order of variables is plotted from left to right (if multiple `x`).
#' @param y the y variable(s) to plot, accepts [dplyr::select()] syntax. The order of variables is plotted from top to bottom (if multiple `y`).
#' @param remove_na whether to remove `NA` values in the x/y plot, setting this to `FALSE` can lead to unintended side-effects for interrupted lines so check your plot carefully if you change this
#' @param color which color to make the plots (also sets the plotwide color and fill aesthetics, override them in individual geoms in the `template` to change this aesthetic), either one value for all variables or one color per variable. Pick `NA` to not set colors (in case you want to use them yourself in the aesthetics).
#' @param palette which color to make the plots defined with an RColorBrewer palette ([RColorBrewer::display.brewer.all()]). You can only use `color` or `palette` parameter, not both.
#' @param both_axes whether to have the stacked axes on both sides (overrides alternate_axes and switch_axes)
#' @param alternate_axes whether to alternate the sides on which the stacked axes are plotted
#' @param switch_axes whether to switch the stacked axes. Not switching means that for vertical stacks the plot at the bottom has the y-axis always on the left side; and for horizontal stacks that the plot on the left has the x-axis on top. Setting `switch_axes = TRUE`, leads to the opposite.
#' @param overlap fractional overlap between adjacent plots. The max of 1 means plots are perfectly overlaid. The min of 0 means there is no overlap. If providing multiple values, must be 1 less than the number of stacked plots (since it's describing the overlap/gap between adjacent plots). By default there is no overlap between plots
#' @param simplify_shared_axis whether to simplify the shared axis to only be on the last plot (+ first plot if a duplicate secondary axis is set)
#' @param shared_axis_size if `simplify_shared_axis` is `TRUE`, this determines the size of the shared axis relative to the size of a single plot
#' @param template a template plot (ggplot object) to use for the stacked plots
#' @param add a list of ggplot component calls to add to specific panel plots, either by panel variable name (named list) or index (unnamed list)
#' @param debug `r lifecycle::badge("experimental")` debug flag to print the stackplot tibble and gtable intermediates
#' @examples
#' # 1 step stackplot (most common use)
#' mtcars |>
#'   ggstackplot(
#'     x = mpg,
#'     y = c(`weight [g]` = wt, qsec, drat, disp),
#'     palette = "Set1",
#'     overlap = c(1, 0, 0.3)
#'   )
#'
#' # 2 step stackplot
#' mtcars |>
#'   prepare_stackplot(
#'     x = mpg,
#'     y = c(`weight [g]` = wt, qsec, drat, disp),
#'     palette = "Set1"
#'   ) |>
#'   assemble_stackplot(overlap = c(1, 0, 0.3))
#'
#' @examplesIf interactive()
#' # many more examples available in the vignette
#' vignette("ggstackplot")
#'
#' @export
#' @returns `ggstackplot()` returns a ggplot with overlaid plot layers
ggstackplot <- function(
    data, x, y, remove_na = TRUE, color = NA, palette = NA,
    both_axes = FALSE, alternate_axes = TRUE, switch_axes = FALSE,
    overlap = 0, simplify_shared_axis = TRUE, shared_axis_size = 0.2,
    template = ggplot() +
      geom_line() +
      geom_point() +
      theme_stackplot(),
    add = list(),
    debug = FALSE) {

  # step 1: prepare the individual stackplot components; `x`, `y` and `add`
  # are forwarded with `{{ }}` exactly as received
  stackplot_parts <- prepare_stackplot(
    data,
    x = {{ x }},
    y = {{ y }},
    remove_na = remove_na,
    color = color,
    palette = palette,
    both_axes = both_axes,
    alternate_axes = alternate_axes,
    switch_axes = switch_axes,
    template = template,
    add = {{ add }},
    debug = debug
  )

  # step 2: assemble the prepared components into the combined plot
  assemble_stackplot(
    stackplot_parts,
    overlap = overlap,
    simplify_shared_axis = simplify_shared_axis,
    shared_axis_size = shared_axis_size,
    debug = debug
  )
}

#' Prepare the stackplot
#'
#' @details
#' `prepare_stackplot()` is usually not called directly but can be used to assemble the parts of a stackplot first and then look at them or edit them individually before combining them with `assemble_stackplot()`. Returns a nested data frame with all stacked variables (.var), their plot configuration, data, plot object, and theme object.
#' @rdname ggstackplot
#' @export
#' @returns `prepare_stackplot()` returns a tibble with all plot components
prepare_stackplot <- function(
    data, x, y, remove_na = TRUE, color = NA, palette = NA,
    both_axes = FALSE, alternate_axes = TRUE, switch_axes = FALSE,
    template = ggplot() +
      geom_line() +
      geom_point() +
      theme_stackplot(),
    add = list(),
    debug = FALSE) {

  # nested tibble with the plotting data and configuration of each
  # stacked variable
  stack_tbl <- create_stackplot_tibble(
    data,
    x = {{ x }},
    y = {{ y }},
    remove_na = remove_na,
    color = color,
    palette = palette,
    both_axes = both_axes,
    alternate_axes = alternate_axes,
    switch_axes = switch_axes
  )

  # one plot (built from the template) and one theme per stacked variable
  stack_tbl <- dplyr::mutate(
    stack_tbl,
    plot = map2(.data$config, .data$data, make_plot, template),
    theme = map(.data$config, make_color_axis_theme)
  )

  # process the user-provided add ons
  prepared_stackplot <- process_add_ons(stack_tbl, add = {{ add }})

  # optionally print the (unnested) configuration for debugging
  if (debug) {
    rlang::inform("\n[DEBUG] stackplot tibble")
    debug_tbl <- dplyr::select(prepared_stackplot, ".var", "config", "data")
    print(tidyr::unnest(debug_tbl, "config"))
  }

  prepared_stackplot
}

# internal function to prepare the data for a ggstackplot
#
# Validates the inputs, resolves the `x` / `y` tidyselect expressions against
# `data`, pivots the data into long format (columns `.var`, `.x`, `.y`) and
# builds the plot configuration (one row per stacked variable).
#
# Arguments:
#   data            data frame to plot
#   x, y            tidyselect expressions; exactly one of them may select
#                   more than one column (this determines the stack direction)
#   remove_na       TRUE/FALSE, drop rows where `.x` or `.y` is NA
#   color           NA, a single color, or one color per stacked variable
#   palette         NA or the name of an RColorBrewer palette (cannot be
#                   combined with `color`)
#   both_axes, alternate_axes, switch_axes
#                   TRUE/FALSE flags used to derive `.axis_switch`
#   call            environment reported in the `cli_abort()` error messages
#
# Returns a tibble with one row per stacked variable (`.var`) and the nested
# columns `config` and `data`.
create_stackplot_tibble <- function(
    data, x, y, remove_na = TRUE, color = NA, palette = NA, both_axes = FALSE, alternate_axes = FALSE, switch_axes = FALSE, call = caller_env()) {

  # do we have a data frame?
  if (missing(data) || !is.data.frame(data)) {
    cli_abort("`data` must be a data frame or tibble.", call = call)
  }

  # do x and y evaluate correctly?
  # (eval_select returns a named vector of column positions, the names are
  # the - potentially renamed - variable names)
  x <- try_fetch(
    tidyselect::eval_select(rlang::enexpr(x), data),
    error = function(cnd) {
      cli_abort(
        "`x` must be a valid tidyselect expression.",
        parent = cnd, call = call
      )
    }
  )
  y <- try_fetch(
    tidyselect::eval_select(rlang::enexpr(y), data),
    error = function(cnd) {
      cli_abort(
        "`y` must be a valid tidyselect expression.",
        parent = cnd, call = call
      )
    }
  )

  # do we have at least 1 x and 1 y?
  if (length(x) < 1 || length(y) < 1) {
    cli_abort(c(
      "insufficient number of columns",
      "x" = if (length(x) < 1) "no `x` column selected",
      "x" = if (length(y) < 1) "no `y` column selected"
    ), call = call)
  }
  # do we have both multiple x AND y?
  if (length(x) > 1 && length(y) > 1) {
    # report how many columns each selection matched (both are > 1 here)
    cli_abort(c(
      "too many columns, only x OR y can select multiple columns",
      "x" = "`x` selects {length(x)} columns",
      "x" = "`y` selects {length(y)} columns"
    ), call = call)
  }

  # do we have valid remove_na, both_axes, alternate_axes, and switch_axes (the booleans)
  stopifnot(
    "`remove_na` must be TRUE or FALSE" = is_bool(remove_na),
    "`both_axes` must be TRUE or FALSE" = is_bool(both_axes),
    "`alternate_axes` must be TRUE or FALSE" = is_bool(alternate_axes),
    "`switch_axes` must be TRUE or FALSE" = is_bool(switch_axes)
  )

  # determine direction (multiple x = horizontal stack, otherwise vertical)
  direction <- if (length(x) > 1) "horizontal" else "vertical"

  # prep data (pivot based on direction): the stacked variables are pivoted
  # into `.var` + `.x`/`.y`, the single shared variable is copied into the
  # other of `.x`/`.y`
  data_long <-
    if (direction == "horizontal") {
      data |>
        dplyr::rename(dplyr::all_of(x), dplyr::all_of(y)) |>
        tidyr::pivot_longer(cols = dplyr::all_of(names(x)), names_to = ".var", values_to = ".x") |>
        dplyr::mutate(.y = !!sym(names(y)[1]))
    } else {
      data |>
        dplyr::rename(dplyr::all_of(x), dplyr::all_of(y)) |>
        tidyr::pivot_longer(cols = dplyr::all_of(names(y)), names_to = ".var", values_to = ".y") |>
        dplyr::mutate(.x = !!sym(names(x)[1]))
    }

  # remove na
  if (remove_na) data_long <- data_long |> dplyr::filter(!is.na(.data$.x), !is.na(.data$.y))

  # prep config (one row per stacked variable since only x OR y is multiple)
  config <- dplyr::tibble(
    .xvar = factor_in_order(names(x)),
    .yvar = factor_in_order(names(y))
  ) |>
    dplyr::arrange(.data$.xvar, .data$.yvar)

  # do we have a valid length for color or palette?
  stopifnot("can only set either `color` or `palette`, not both" = is.na(color) | is.na(palette))
  if (!(is.character(color) || all(is.na(color))) || !length(color) %in% c(1L, nrow(config))) {
    cli_abort(sprintf("`color` must be either a single color or one for each variable (%d)", nrow(config)), call = call)
  }
  if (!all(is.na(palette))) {
    # palette argument provided: take the first nrow(config) colors of the
    # palette at its maximum size
    if (is_scalar_character(palette) && palette %in% rownames(RColorBrewer::brewer.pal.info) && RColorBrewer::brewer.pal.info[palette, 1] >= nrow(config)) {
      color <- RColorBrewer::brewer.pal(RColorBrewer::brewer.pal.info[palette, 1], palette)[seq_len(nrow(config))]
    } else {
      sprintf("`palette` must be a string identifying a valid RColorBrewer palette with at least %d colors. Use `RColorBrewer::display.brewer.all()` to see all available palettes.", nrow(config)) |>
        cli_abort(call = call)
    }
  }

  # finish config
  config <- config |>
    dplyr::mutate(
      .color = !!color,
      .axis_switch =
        if (both_axes) NA else
          calculate_axis_switch(
            # Note: the reverse_factor and reverse = TRUE for 'vertical'
            # plot are both needed to properly invert the order AND keep
            # the first plot in the lower left (unless switch = TRUE)
            var =
              if (!!direction == "vertical")
                reverse_factor(.data$.yvar)
              else .data$.xvar,
            alternate = {{ alternate_axes }},
            switch = {{ switch_axes }},
            reverse = !!direction == "vertical"
          ),
      # range of the shared (non-stacked) variable across all panels
      .shared_axis_min =
        if (!!direction == "horizontal") min(data_long$.y, na.rm = TRUE)
        else min(data_long$.x, na.rm = TRUE),
      .shared_axis_max =
        if (!!direction == "horizontal") max(data_long$.y, na.rm = TRUE)
        else max(data_long$.x, na.rm = TRUE),
      # flag the first and last panel of the stack
      .first =
        (direction == "horizontal" & as.integer(.data$.xvar) == 1L) |
          (direction == "vertical" & as.integer(.data$.yvar) == 1L),
      .last =
        (direction == "horizontal" & as.integer(.data$.xvar) == length(levels(.data$.xvar))) |
          (direction == "vertical" & as.integer(.data$.yvar) == length(levels(.data$.yvar))),
      .var = if (direction == "horizontal") .data$.xvar else .data$.yvar,
      .direction = !!direction
    )

  # complete prepped data: nest config and data by stacked variable and join
  return(
    config |>
      tidyr::nest(config = -".var") |>
      dplyr::left_join(
        tidyr::nest(data_long, data = -".var"),
        by = ".var"
      )
  )
}

#' Combine the stack plot
#'
#' @details
#' `assemble_stackplot()` is usually not called directly but can be used to manually combine a stackplot tibble (typically created by `prepare_stackplot()`). Returns a plot object generated by [cowplot::plot_grid()].
#'
#' @param prepared_stackplot a nested data frame, the output from [prepare_stackplot()]
#' @rdname ggstackplot
#' @export
#' @returns `assemble_stackplot()` returns a ggplot with overlaid plot layers
assemble_stackplot <- function(prepared_stackplot, overlap = 0, simplify_shared_axis = TRUE, shared_axis_size = 0.15, debug = FALSE) {

  # turn the prepared stackplot into positioned gtables
  gtables <- create_stackplot_gtables(
    prepared_stackplot,
    overlap = overlap,
    simplify_shared_axis = simplify_shared_axis,
    shared_axis_size = shared_axis_size
  )

  # optionally print the gtables tibble (without the graphical objects)
  if (debug) {
    rlang::inform("\n[DEBUG] stackplot gtables")
    print(dplyr::select(gtables, -"gtable", -"plot_w_theme"))
  }

  # combine everything into the final plot
  combine_gtables(gtables)
}

# internal function to create a list of gtables for the combined plot
#
# Validates `prepared_stackplot` and `overlap`, combines each plot with its
# theme (and add ons), converts the plots to gtables and calculates the
# relative position and size of each gtable in the combined plot. If
# `simplify_shared_axis` is TRUE, the shared primary / secondary axis (if
# present) are added as their own rows with size `shared_axis_size`.
#
# Arguments:
#   prepared_stackplot   data frame with columns `.var`, `config`, `data`,
#                        `plot`, `theme`
#   overlap              numeric between 0 and 1 (no NAs), either a single
#                        value or one value per pair of adjacent plots
#   simplify_shared_axis whether to pull out the shared axis
#   shared_axis_size     size of the shared axis relative to a single plot
#   call                 environment reported in the `cli_abort()` messages
#
# Returns the gtables tibble with the calculated `x`, `y`, `width` and
# `height` columns.
create_stackplot_gtables <- function(prepared_stackplot, overlap, simplify_shared_axis, shared_axis_size, call = caller_env()) {

  # do we have a data frame?
  req_cols <- c(".var", "config", "data", "plot", "theme")
  if (missing(prepared_stackplot) || !is.data.frame(prepared_stackplot) ||
      !all(req_cols %in% names(prepared_stackplot))) {
    cli_abort(
      "{.var prepared_stackplot} must be a data frame or tibble with columns
      {.emph {req_cols}}", call = call
    )
  }

  # do we have a valid overlap value?
  # (the anyNA check must come before the range checks, otherwise an NA in an
  # otherwise valid overlap turns the whole condition into NA and `if` fails
  # with an uninformative error)
  if (missing(overlap) || !is.numeric(overlap) || anyNA(overlap) ||
      !all(overlap >= 0) || !all(overlap <= 1) ||
      !length(overlap) %in% c(1L, nrow(prepared_stackplot) - 1L)) {
    cli_abort(
      c("{.var overlap} must be either a single numeric value (between 0 and 1)
      or a vector with {nrow(prepared_stackplot) - 1L} numbers, one for the
      overlap of each sequential plot",
      "x" = "{.var overlap} is a {.obj_type_friendly {overlap}}"),
      call = call)
  }

  # combine plots and themes and assemble the gtables
  gtables <- prepared_stackplot |>
    combine_plot_theme_add(simplify_shared_axis = simplify_shared_axis, include_adds = TRUE) |>
    tidyr::unnest("config") |>
    dplyr::select(".var", ".direction", "plot_w_theme")

  # make sure horizontal panels are in the correct order
  # (reverse since horizontal positioning is inverted relative to vertical
  # given the combined plot coordinate system starts in lower left corner)
  if (gtables$.direction[1] == "horizontal") {
    gtables <- gtables |> dplyr::arrange(dplyr::desc(dplyr::row_number()))
    overlap <- rev(overlap)
  }

  gtables <- gtables |>
    # could think about relative sizing here with size_adjust but that doesn't seem like a feature we need
    dplyr::mutate(
      size = 1,
      size_adjust = 0,
      pos_adjust = 0,
      # the first plot never overlaps with a previous one
      overlap =
        if (length(!!overlap) == 1L) c(0, rep(!!overlap, dplyr::n() - 1L))
        else c(0, !!overlap),
      gtable = map(.data$plot_w_theme, ggplot2::ggplotGrob)
    )

  # shared axis simplification?
  if (simplify_shared_axis) {
    # x axes (could get these from any of the pre-final plots)
    shared_axis_plot <- prepared_stackplot[1,] |>
      combine_plot_theme_add(simplify_shared_axis = FALSE, include_adds = FALSE)

    # primary axis present?
    primary_axis_components <-
      if (gtables$.direction[1] == "horizontal") c("axis-l", "ylab-l")
      else c("axis-b", "xlab-b")
    primary_axis <- get_plot_component_grobs(
      shared_axis_plot$plot_w_theme[[1]],
      .data$name %in% primary_axis_components
    )
    has_primary_axis <- !all(is_zero_grob(primary_axis))

    # secondary axis present?
    secondary_axis_components <-
      if (gtables$.direction[1] == "horizontal") c("axis-r", "ylab-r")
      else c("axis-t", "xlab-t")
    secondary_axis <- get_plot_component_grobs(
      shared_axis_plot$plot_w_theme[[1]],
      .data$name %in% secondary_axis_components
    )
    has_secondary_axis <- !all(is_zero_grob(secondary_axis))
    # extra adjustment only applies if both shared axes are present
    shared_axis_adjust <- 1 + (has_primary_axis & has_secondary_axis) * shared_axis_size

    # account for primary axis (appended after the plots)
    if (has_primary_axis) {
      gtables <- dplyr::bind_rows(
        gtables,
        dplyr::tibble(
          .var = "primary",
          .direction = gtables$.direction[1],
          size = shared_axis_size,
          size_adjust = shared_axis_adjust,
          pos_adjust = 0,
          overlap = 0,
          # note: `pos` is recalculated for all rows in the calculations below
          pos = .data$size - .data$overlap,
          gtable = list(primary_axis)
        )
      )
    }

    # account for secondary axis (prepended before the plots)
    if (has_secondary_axis) {
      gtables <- dplyr::bind_rows(
        dplyr::tibble(
          .var = "secondary",
          .direction = gtables$.direction[1],
          size = shared_axis_size,
          size_adjust = shared_axis_adjust,
          pos_adjust = shared_axis_adjust,
          overlap = 0,
          # note: `pos` is recalculated for all rows in the calculations below
          pos = .data$size - .data$overlap,
          gtable = list(secondary_axis)
        ),
        gtables
      )
    }
  }

  # calculations: cumulative position of each gtable (accounting for
  # overlaps), then position and size relative to the total size
  gtables <- gtables |>
    align_gtables() |>
    dplyr::mutate(
      pos = cumsum(.data$size) - cumsum(.data$overlap),
      total_size = sum(.data$size) - sum(.data$overlap),
      rel_pos = 1 - (.data$pos + .data$pos_adjust) / .data$total_size,
      rel_size = (.data$size + .data$size_adjust) / .data$total_size,
      x = if (.data$.direction[1] == "horizontal") .data$rel_pos else 0,
      y = if (.data$.direction[1] == "horizontal") 0 else .data$rel_pos,
      width = if (.data$.direction[1] == "horizontal") .data$rel_size else 1,
      height = if (.data$.direction[1] == "horizontal") 1 else .data$rel_size
    )

  return(gtables)
}
