#' Fit exponential models to incidence data
#'
#' The function \code{fit} fits one or two exponential models to incidence
#' data, of the form: \eqn{log(y) = r * t + b} \cr where 'y' is the incidence,
#' 't' is time (in days), 'r' is the growth rate, and 'b' is the origin. The
#' function \code{fit} will fit one model by default, but will fit two models on
#' either side of a splitting date (typically the peak of the epidemic) if the
#' argument \code{split} is provided. The function \code{fit_optim_split} can be
#' used to find the optimal 'splitting' date, defined as the one for which the
#' best average adjusted R2 of the two models is obtained. Fits can be plotted
#' on their own using \code{plot}, or added to an existing incidence plot by the
#' piping-friendly function \code{add_incidence_fit}.
#'
#' @export
#'
#' @rdname fit
#'
#' @return For \code{fit}, a list with the class \code{incidence_fit} (for a
#' single model), or a list containing two \code{incidence_fit} objects (when
#' fitting two models). \code{incidence_fit} objects contain:
#'
#' \itemize{
#'  \item lm: the fitted linear model
#'
#'  \item info: a list containing various information extracted from the model
#' (detailed further)
#'
#'  \item origin: the date corresponding to day '0'
#' }
#'
#' The \code{$info} item is a list containing:
#'
#' \itemize{
#'  \item r: the growth rate
#'
#'  \item r.conf: the confidence interval of 'r'
#'
#'  \item pred: a \code{data.frame} containing predictions of the model,
#' including the true dates (\code{dates}), their numeric version used in the
#' model (\code{dates.x}), the predicted value (\code{fit}), and the lower
#' (\code{lwr}) and upper (\code{upr}) bounds of the associated confidence
#' interval.
#'
#'  \item doubling: the predicted doubling time in days; exists only if 'r' is
#' positive
#'
#'  \item doubling.conf: the confidence interval of the doubling time
#'
#'  \item halving: the predicted halving time in days; exists only if 'r' is
#' negative
#'
#'  \item halving.conf: the confidence interval of the halving time
#' }
#'
#' For \code{fit_optim_split}, a list containing:
#' \itemize{
#'
#'  \item df: a \code{data.frame} of dates that were used in the optimization
#' procedure, and the corresponding average adjusted R2 of the resulting models.
#'
#'  \item split: the optimal splitting date
#'
#'  \item fit: the resulting \code{incidence_fit} objects
#'
#'  \item plot: a plot showing the content of \code{df} (ggplot2 object)
#' }
#'
#' @author Thibaut Jombart \email{thibautjombart@@gmail.com}
#'
#' @seealso the \code{\link{incidence}} function to generate the 'incidence'
#' objects.
#'
#' @param x An incidence object, generated by the function
#' \code{\link{incidence}}. For the plotting function, an \code{incidence_fit}
#' object.
#'
#' @param split An optional time point identifying the separation between the
#' two models. If NULL, a single model is fitted. If provided, two models would
#' be fitted on the time periods on either side of the split.
#'
#' @param level The confidence interval to be used for predictions; defaults to
#' 95\%.
#'
#' @param quiet A logical indicating if warnings from \code{fit} should be
#' hidden; FALSE by default for \code{fit}, TRUE by default for
#' \code{fit_optim_split}. Warnings typically indicate some zero incidence,
#' which are removed before performing the log-linear regression.
#'
#' @examples
#'
#' if (require(outbreaks)) {
#'  dat <- ebola_sim$linelist$date_of_onset
#'
#'  ## EXAMPLE WITH A SINGLE MODEL
#'
#'  ## compute weekly incidence
#'  i.7 <- incidence(dat, interval=7)
#'  plot(i.7)
#'  plot(i.7[1:20])
#'
#'  ## fit a model on the first 20 weeks
#'  f <- fit(i.7[1:20])
#'  f
#'  names(f)
#'  head(f$info$pred)
#'
#'  ## plot model alone (not recommended)
#'  plot(f)
#'
#'  ## plot data and model (recommended)
#'  plot(i.7, fit = f)
#'  plot(i.7[1:25], fit = f)
#'
#' ## piping versions
#' if (require(magrittr)) {
#'   plot(i.7) %>% add_incidence_fit(f)
#'
#'
#'   ## EXAMPLE WITH 2 PHASES
#'   ## specifying the peak manually
#'   f2 <- fit(i.7, split = as.Date("2014-10-15"))
#'   f2
#'   plot(i.7) %>% add_incidence_fit(f2)
#'
#'   ## finding the best 'peak' date
#'   f3 <- fit_optim_split(i.7)
#'   f3
#'   plot(i.7) %>% add_incidence_fit(f3$fit)
#' }
#' }
#'


## The model fitted is a simple linear regression on the log-incidence.

## Non-trivial bits involve:

## 1) Fitting several models
## I.e. in case there is an increasing and a decreasing phase, we fit one
##  model for each phase separately.

## 2) log(0)
## No satisfying solutions so far; for now dates with zero incidence are
## removed before fitting

## 3) Several groups
## In this case, the number of models does not change, but models automatically
## include groups with interaction, whether or not it is significant.

## 4) Values of dates used as 'x'

## To retain generality, we need to use numbers (not Date or POSIXct) as 'x'
## axis for the model.  Therefore, all dates are expressed as numbers of days
## since the first case (aka 'day 0' or 'origin'), picking the middle of each
## time interval. We also keep track of the origin, so that actual dates can be
## reconstructed during the plotting. Each 'fit' object has its own origin.

fit <- function(x, split = NULL, level = 0.95, quiet = FALSE){
  n.groups <- ncol(x$counts)

  ## log(0) is undefined: a date is dropped as soon as any group has a zero
  ## count on that date
  has.cases <- apply(x$counts, 1, min) > 0
  if (!quiet && any(!has.cases)) {
    warning(sprintf("%d dates with incidence of 0 ignored for fitting",
                    sum(!has.cases)))
  }
  x <- x[has.cases]

  ## A regression needs at least two points. With a single remaining date no
  ## model can be fitted (an object with no date at all is expected to have
  ## failed earlier, when the incidence object was created).
  if (x$timespan == 1) {
    stop("Only 1 date with non-zero incidence. Cannot fit model to 1 data point.")
  }

  ## dates are converted to days since the first date of the fitted period,
  ## shifted to the middle of each time interval
  mid.offset <- x$interval/2

  ## model with split (2 models): the splitting date belongs to both periods
  if (!is.null(split)) {
    x1 <- x[x$dates <= split]
    x2 <- x[x$dates >= split]
    df1 <- as.data.frame(x1, long=TRUE)
    df2 <- as.data.frame(x2, long=TRUE)
    df1$dates.x <- as.numeric(df1$dates - min(df1$dates)) + mid.offset
    df2$dates.x <- as.numeric(df2$dates - min(df2$dates)) + mid.offset

    if (n.groups == 1) {
      lm1 <- stats::lm(log(counts) ~  dates.x, data = df1)
      lm2 <- stats::lm(log(counts) ~  dates.x, data = df2)
    } else {
      ## several groups: one slope per group through the interaction term
      lm1 <- stats::lm(log(counts) ~  dates.x * groups, data = df1)
      lm2 <- stats::lm(log(counts) ~  dates.x * groups, data = df2)
    }

    return(list(before = extract_info(lm1, x1, level),
                after = extract_info(lm2, x2, level)))
  }

  ## model without split (1 model)
  df <- as.data.frame(x, long=TRUE)
  df$dates.x <- as.numeric(df$dates - min(df$dates)) + mid.offset

  if (n.groups == 1) {
    lm1 <- stats::lm(log(counts) ~ dates.x, data = df)
  } else {
    lm1 <- stats::lm(log(counts) ~ dates.x * groups, data = df)
  }

  extract_info(lm1, x, level)
}





#' @export
#' @rdname fit
#'
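#' @param window The size, in days, of the time window centred on the peak
#' (date of highest incidence in the first group) within which candidate
#' splitting dates are tried; defaults to a quarter of the time span.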
#'
#' @param plot A logical indicating whether a plot should be added to the
#' output, showing the mean R2 for various splits.
#'

fit_optim_split <- function(x, window = x$timespan/4, plot = TRUE,
                            quiet = TRUE){
  ## The search is centred on the date of highest incidence.
  ## !! only the first column of counts is used: this assumes a single group
  peak <- x$dates[which.max(x$counts[,1])]
  half.window <- window / 2
  lower <- peak - half.window
  upper <- peak + half.window

  ## candidate splits are the observed dates falling inside the window
  in.window <- x$dates >= lower & x$dates <= upper
  if (sum(in.window) < 1) {
    stop("No date left to try after defining splits to try.")
  }
  candidates <- x$dates[in.window]

  ## score of one candidate: adjusted R2 averaged over the two fitted models
  mean_adj_r2 <- function(split) {
    models <- fit(x, split=split, quiet = quiet)
    r2 <- vapply(models, function(m) summary(m$lm)$adj.r.squared, double(1))
    mean(r2)
  }
  scores <- vapply(candidates, mean_adj_r2, double(1))

  ## shape output: all scores, the best date, and the fits for that date
  ## (warnings of this final fit are always silenced)
  df <- data.frame(dates = candidates, mean.R2 = scores)
  best.split <- candidates[which.max(scores)]
  best.fit <- suppressWarnings(fit(x, split = best.split))
  out <- list(df = df,
              split = best.split,
              fit = best.fit)

  if (!plot) {
    return(out)
  }

  ## optional diagnostic plot: mean adjusted R2 against candidate date
  out$plot <- ggplot2::ggplot(
    df, ggplot2::aes_string(x = "dates", y = "mean.R2")) +
    ggplot2::geom_point() + ggplot2::geom_line() +
    ggplot2::geom_text(ggplot2::aes_string(label="dates"),
                       hjust=-.1, angle=35) +
    ggplot2::ylim(min=min(scores)-.1, max=1)

  out
}





## Non-exported function extracting info and predictions from a lm object
## - reg is a lm object
## - x is an incidence object
## - level is a confidence level (no default here; the caller must supply it)

extract_info <- function(reg, x, level){
  ## Builds an 'incidence_fit' object from a fitted log-linear model.
  ## - reg: a lm object whose model frame has a 'dates.x' column (and a
  ##   'groups' column when several groups are modelled); NULL is returned
  ##   unchanged
  ## - x: the object the model was fitted on; only x$dates is read here
  ## - level: confidence level used for the intervals of 'r' and predictions
  ## Returns a list of class 'incidence_fit' with items 'lm', 'info' and
  ## 'origin'.
  if (is.null(reg)) {
    return(NULL)
  }

  ## extract growth rates (r)
  ## here we need to keep all coefficients when there are interactions: the
  ## main 'dates.x' slope plus one 'dates.x:groups...' term per extra group
  to.keep <- grep("^dates.x.*$", names(stats::coef(reg)), value=TRUE)
  r <- stats::coef(reg)[to.keep]
  use.groups <- length(r) > 1
  if (use.groups) {
    names(r) <- reg$xlevels[[1]] # names = levels if groups
  } else {
    names(r) <- NULL # no names otherwise
  }
  r.conf <- stats::confint(reg, to.keep, level)
  rownames(r.conf) <- names(r)
  if (use.groups) {
    ## interaction terms are offsets relative to the first group, so the
    ## first group's values are added to every other group
    r[-1] <- r[-1] + r[1]
    ## The first row must be added column-wise (lower to lower, upper to
    ## upper). A plain 'matrix + vector' recycles the vector down the columns,
    ## which mixes lower and upper bounds as soon as there are 3+ groups.
    r.conf[-1, ] <- sweep(r.conf[-1, , drop = FALSE], 2, r.conf[1, ], "+")
  }


  ## need to pass new data spanning all dates and groups here
  if (use.groups) {
    new.data <- expand.grid(sort(unique(reg$model$dates.x)),
                            levels(reg$model$groups))
    names(new.data) <- c("dates.x", "groups")
  } else {
    new.data <- data.frame(dates.x = sort(unique(reg$model$dates.x)))
  }
  ## the model is fitted on the log scale: back-transform predictions and
  ## their confidence bounds
  pred <- exp(stats::predict(reg, newdata = new.data, interval = "confidence",
                             level = level))
  ## keep track of dates and groups for plotting
  pred <- cbind.data.frame(new.data, pred)
  info <- list(r = r, r.conf = r.conf,
               pred = pred)

  if (r[1] > 0 ) { # note: choice of doubling vs halving only based on 1st group
    info$doubling <- log(2) / r
    info$doubling.conf <- log(2) / r.conf
    ## dividing by r.conf puts the larger value first: swap the columns back
    ## while keeping the original column names
    o.names <- colnames(info$doubling.conf)
    info$doubling.conf <-info$doubling.conf[, rev(seq_along(o.names)),
                                            drop=FALSE]
    colnames(info$doubling.conf) <- o.names
  } else {
    info$halving <- log(0.5) / r
    info$halving.conf <- log(0.5) / r.conf
  }

  ## We need to store the date corresponding to 'day 0', as this will be used
  ## to create actual dates afterwards (as opposed to mere numbers of days).
  origin <- min(x$dates)

  ## Dates are reconstructed from info$pred$dates.x and origin.  Note that
  ## this is approximate, as dates are forced to be integers. A better option
  ## would be to convert the dates to numbers, but ggplot2 is no longer
  ## consistent when mixing up Date and decimal numbers (it works only in some
  ## cases / geom).
  dates <- origin + pred$dates.x
  info$pred <- cbind.data.frame(dates, info$pred)
  out <- list(lm = reg, info = info, origin = origin)
  class(out) <- "incidence_fit"
  out
}






#' @export
#' @rdname fit
#' @param ... further arguments passed to other methods (not used)

print.incidence_fit <- function(x, ...) {
  info <- x$info

  ## write a header line, then print the corresponding item
  show_item <- function(header, value) {
    cat(header)
    print(value)
  }

  cat("<incidence_fit object>\n\n")
  cat("$lm: regression of log-incidence over time\n\n")
  cat("$info: list containing the following items:\n")

  show_item("  $r (daily growth rate):\n", info$r)
  show_item("\n  $r.conf (confidence interval):\n", info$r.conf)

  ## doubling or halving time is shown depending on the sign of the first
  ## growth rate
  if (info$r[1] > 0) {
    show_item("\n  $doubling (doubling time in days):\n", info$doubling)
    show_item("\n  $doubling.conf (confidence interval):\n",
              info$doubling.conf)
  } else {
    show_item("\n  $halving (halving time in days):\n", info$halving)
    show_item("\n  $halving.conf (confidence interval):\n",
              info$halving.conf)
  }

  ## predictions are only summarised by their dimensions
  pred <- info$pred
  cat(sprintf(
    "\n  $pred: data.frame of incidence predictions (%d rows, %d columns)\n",
    nrow(pred), ncol(pred)))

  ## return the object invisibly, as print methods conventionally do
  invisible(x)
}





## This function will take an existing 'incidence' plot object ('p') and add lines from an
## 'incidence_fit' object ('x')

#' @export
#' @rdname fit
#'
#' @param p An existing incidence plot.
#'
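#' @param col_pal A color palette function, defaulting to
#' \code{incidence_pal1}; it is called with the number of groups to obtain one
#' color per group.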
add_incidence_fit <- function(p, x, col_pal = incidence_pal1){

  pred <- x$info$pred

  ## 'x' may be a plain list of fits: every 'incidence_fit' element is added
  ## to the plot in turn, and any other element is skipped
  if (is.list(x) && !inherits(x, "incidence_fit")) {
    add_one <- function(plt, e) {
      if (!inherits(e, "incidence_fit")) {
        return(plt)
      }
      add_incidence_fit(plt, e, col_pal)
    }
    return(Reduce(add_one, x, p))
  }

  ## single fit: solid line for the prediction, dashed lines for the lower
  ## and upper bounds of its confidence interval
  out <- suppressMessages({
    with.fit <- p + ggplot2::geom_line(
      data = pred,
      ggplot2::aes_string(x = "dates", y = "fit"), linetype = 1)
    with.lwr <- with.fit + ggplot2::geom_line(
      data = pred,
      ggplot2::aes_string(x = "dates", y = "lwr"), linetype = 2)
    with.lwr + ggplot2::geom_line(
      data = pred,
      ggplot2::aes_string(x = "dates", y = "upr"), linetype = 2)
  })

  ## with groups, lines are colored by group using one palette color per level
  if ("groups" %in% names(pred)) {
    n.groups <- nlevels(pred$groups)
    out <- out + ggplot2::aes_string(color = "groups") +
      ggplot2::scale_color_manual(values = col_pal(n.groups))
  }

  out
}





#' @export
#' @rdname fit

plot.incidence_fit <- function(x, ...){
  ## the fit is drawn on an empty plot; further arguments go to
  ## add_incidence_fit()
  axis.labels <- ggplot2::labs(x = "", y = "Predicted incidence")
  add_incidence_fit(ggplot2::ggplot(), x, ...) + axis.labels
}
