# Copyright © 2022 - 2025 Rnaught contributors
#
# This file is part of Rnaught.
#
# Rnaught is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# Rnaught is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License along
# with Rnaught. If not, see <https://www.gnu.org/licenses/>.


#' Sequential Bayes (seqB)
#'
#' This function implements a sequential Bayesian estimation method of R0 due to
#' Bettencourt and Ribeiro (PloS One, 2008). See details for important
#' implementation notes.
#'
#' The method sets a uniform prior distribution on R0 with possible values
#' between `0` and `kappa`, discretized to a fine grid. The distribution of R0
#' is then updated sequentially, with one update for each new case count
#' observation. The final estimate of R0 is the mean of the (last) posterior
#' distribution. The prior distribution is the initial belief of the
#' distribution of R0, which is the uninformative uniform distribution with
#' values between `0` and `kappa`. Users can change the value of `kappa` only
#' (i.e., the prior distribution cannot be changed from the uniform). As more
#' case counts are observed, the influence of the prior distribution should
#' lessen on the final estimate.
#'
#' This method is based on an approximation of the SIR model, which is most
#' valid at the beginning of an epidemic. The method assumes that the mean of
#' the serial distribution (sometimes called the serial interval) is known. The
#' final estimate can be quite sensitive to this value, so sensitivity testing
#' is strongly recommended. Users should be careful about units of time (e.g.,
#' are counts observed daily or weekly?) when implementing.
#'
#' Our code has been modified to provide an estimate even if case counts equal
#' to zero are present in some time intervals. This is done by grouping the
#' counts over such periods of time. Without grouping, and in the presence of
#' zero counts, no estimate can be provided.
#'
#' @param cases Vector of case counts. The vector must only contain non-negative
#'   integers, and have at least two positive integers.
#' @param mu Mean of the serial distribution. This must be a positive number.
#'   The value should match the case counts in time units. For example, if case
#'   counts are weekly and the serial distribution has a mean of seven days,
#'   then `mu` should be set to `1`. If case counts are daily and the serial
#'   distribution has a mean of seven days, then `mu` should be set to `7`.
#' @param kappa Largest possible value of the uniform prior (defaults to `20`).
#'   This must be a number greater than or equal to `1`. It describes the prior
#'   belief on the ranges of R0, and should be set to a higher value if R0 is
#'   believed to be larger.
#' @param post Whether to return the posterior distribution of R0 instead of the
#'   estimate of R0 (defaults to `FALSE`). This must be a value identical to
#'   `TRUE` or `FALSE`.
#'
#' @return If `post` is identical to `TRUE`, a list containing the following
#'   components is returned:
#'   * `supp` - the support of the posterior distribution of R0
#'   * `pmf` - the probability mass function of the posterior distribution of R0
#'
#'   Otherwise, if `post` is identical to `FALSE`, only the estimate of R0 is
#'   returned. Note that the estimate is equal to `sum(supp * pmf)` (i.e., the
#'   posterior mean).
#'
#' @references Bettencourt and Ribeiro (PloS One, 2008)
#'   \doi{10.1371/journal.pone.0002185}
#'
#' @seealso `vignette("seq_bayes_post", package = "Rnaught")` for examples of
#'   using the posterior distribution.
#'
#' @export
#'
#' @examples
#' # Weekly data.
#' cases <- c(1, 4, 10, 5, 3, 4, 19, 3, 3, 14, 4)
#'
#' # Obtain R0 when the serial distribution has a mean of five days.
#' seq_bayes(cases, mu = 5 / 7)
#'
#' # Obtain R0 when the serial distribution has a mean of three days.
#' seq_bayes(cases, mu = 3 / 7)
#'
#' # Obtain R0 when the serial distribution has a mean of seven days, and R0 is
#' # believed to be at most 4.
#' estimate <- seq_bayes(cases, mu = 1, kappa = 4)
#'
#' # Same as above, but return the posterior distribution of R0 instead of the
#' # estimate.
#' posterior <- seq_bayes(cases, mu = 1, kappa = 4, post = TRUE)
#'
#' # Display the support and probability mass function of the posterior.
#' posterior$supp
#' posterior$pmf
#'
#' # Note that the following always holds:
#' estimate == sum(posterior$supp * posterior$pmf)
seq_bayes <- function(cases, mu, kappa = 20, post = FALSE) {
  # Argument validation. `validate_cases()` and `is_real()` are helpers defined
  # elsewhere in the package.
  validate_cases(cases, min_length = 2, min_count = 0)
  if (!is_real(mu) || mu <= 0) {
    stop("The serial interval (`mu`) must be a number greater than 0.",
      call. = FALSE
    )
  }
  if (!is_real(kappa) || kappa < 1) {
    stop(
      paste("The largest value of the uniform prior (`kappa`)",
        "must be a number greater than or equal to 1."
      ), call. = FALSE
    )
  }
  if (!identical(post, TRUE) && !identical(post, FALSE)) {
    stop("The posterior flag (`post`) must be set to `TRUE` or `FALSE`.",
      call. = FALSE
    )
  }

  # Keep only the time points with a positive count. The gaps left by the
  # dropped zero counts are accounted for by `tau` below.
  times <- which(cases > 0)
  if (length(times) < 2) {
    stop("Case counts must contain at least two positive integers.",
      call. = FALSE
    )
  }
  cases <- cases[times]

  # Grid of candidate R0 values and the elapsed time between consecutive
  # retained observations.
  support <- seq(0, kappa, 0.01)
  tau <- diff(times)

  # Unnormalised log-posterior, one entry per grid point. A uniform prior is a
  # constant on the log scale, and constants cancel on normalisation, so it
  # starts at zero. Its length is taken from `support` itself rather than
  # recomputed as `kappa / 0.01 + 1`, which can disagree with `seq()` by one
  # element when the division lands just below an integer.
  log_posterior <- numeric(length(support))

  for (i in seq_along(tau)) {
    lambda <- tau[i] / mu * (support - 1) + log(cases[i])
    log_like <- cases[i + 1] * lambda - exp(lambda)

    # Accumulate on the log scale and re-centre so the largest entry is 0.
    # This prevents both overflow and underflow in the final `exp()`; the
    # shift is a constant and does not affect the normalised posterior.
    log_posterior <- log_posterior + log_like
    log_posterior <- log_posterior - max(log_posterior)
  }

  # Back to the probability scale. The largest entry is exp(0) = 1, so the
  # normalising sum is at least 1.
  posterior <- exp(log_posterior)
  posterior <- posterior / sum(posterior)

  if (!post) {
    # Point estimate: the posterior mean.
    return(sum(support * posterior))
  }
  list(supp = support, pmf = posterior)
}
