#' Summarize Continuous and Categorical Variables with Optional Grouping
#'
#' `sum_stat` provides a summary of both continuous and categorical variables in a dataset.
#' Continuous variables can be summarized using mean (SD) or median (IQR), optionally with confidence intervals.
#' Categorical variables are summarized as counts and percentages, optionally with confidence intervals.
#' Summaries can also be generated by a grouping variable, and a narrative interpretation is optionally printed.
#'
#' Factor and character columns are treated as categorical and numeric columns
#' as continuous. Columns of any other type are left out of the table and the
#' narrative.
#'
#' Confidence intervals are computed as follows: a normal approximation for the
#' mean, a percentile bootstrap with 1000 resamples for the median (so the
#' result depends on the random number generator state; call `set.seed()` for
#' reproducible output), and the interval returned by `stats::binom.test()` for
#' proportions.
#'
#' The narrative printed when `report = TRUE` always describes continuous
#' variables by their mean and SD, whatever the value of `statistic`. Rows with
#' a missing value in `by` and groups without any non-missing observation are
#' not mentioned in the narrative.
#'
#' @param data A data.frame or tibble containing the variables to summarize.
#' @param by Optional. A single variable name (as string) to group the summary by.
#' @param statistic Character string indicating how to summarize continuous variables. Options are `"mean_sd"` (default) or `"med_iqr"`.
#' @param percent Character string specifying how percentages should be calculated for categorical variables: `"col"` (column-wise), `"row"` (row-wise), or `"none"` (no percentage). Default is `"col"`. Only has an effect on the denominators when `by` is supplied; when `ci = TRUE` a percentage and its interval are always shown.
#' @param ci Logical. If TRUE, confidence intervals at level `conf` are included in the summary for continuous and categorical variables. Default is FALSE.
#' @param conf Numeric. Confidence level for CI calculation (strictly between 0 and 1). Default is 0.95.
#' @param report Logical. If TRUE, prints a narrative summary of the variables. Default is FALSE.
#'
#' @importFrom stats sd qnorm median quantile binom.test
#' @importFrom dplyr `%>%` group_by mutate summarise ungroup select arrange desc count bind_rows row_number
#' @importFrom tidyr pivot_wider
#' @importFrom tibble tibble
#'
#' @return A `flextable` object displaying the summarized variables, optionally including confidence intervals and group comparisons.
#' @export
#'
#' @examples
#' # Basic summary of iris dataset
#' sum_stat(iris, ci = FALSE, report = TRUE)
#'
#' # Summary of CO2 dataset by 'Treatment' with CI
#' sum_stat(CO2, by = "Treatment", ci = TRUE, report = TRUE, percent = "row")
sum_stat <- function(data,
                     by = NULL,
                     statistic = "mean_sd",   # "mean_sd" or "med_iqr"
                     percent = "col",         # "col", "row", "none"
                     ci = FALSE,
                     conf = 0.95,
                     report = FALSE) {

  # --- Validate arguments ---
  # Unknown choices used to fall through silently (to the median branch or to
  # whole-table percentages); reject them up front instead.
  statistic <- match.arg(statistic, c("mean_sd", "med_iqr"))
  percent <- match.arg(percent, c("col", "row", "none"))
  if (!is.numeric(conf) || length(conf) != 1 || is.na(conf) ||
      conf <= 0 || conf >= 1) {
    stop("`conf` must be a single number strictly between 0 and 1.",
         call. = FALSE)
  }

  data <- tibble::as_tibble(data)
  if (!is.null(by) &&
      (!is.character(by) || length(by) != 1 || !(by %in% names(data)))) {
    stop("`by` must be NULL or the name of a single column in `data`.",
         call. = FALSE)
  }

  is_cat <- function(x) is.factor(x) || is.character(x)

  # Labels shared by the table, the footnote and the narrative
  stat_label <- if (statistic == "mean_sd") "Mean (SD)" else "Median (IQR)"
  conf_label <- paste0(round(conf * 100, 2), "% CI")

  # --- Helper functions ---

  # Mean, SD and normal-approximation confidence limits of the non-missing
  # values of x, as a named numeric vector.
  mean_ci <- function(x, conf) {
    x <- x[!is.na(x)]
    n <- length(x)
    if (n == 0) {
      return(c(mean = NA_real_, sd = NA_real_,
               ci_lower = NA_real_, ci_upper = NA_real_))
    }
    m <- mean(x)
    s <- sd(x)
    half_width <- qnorm((1 + conf) / 2) * s / sqrt(n)
    c(mean = m, sd = s, ci_lower = m - half_width, ci_upper = m + half_width)
  }

  # Median, quartiles and (only when with_ci is TRUE) percentile-bootstrap
  # confidence limits of the non-missing values of x.
  median_ci <- function(x, conf, with_ci = TRUE) {
    x <- x[!is.na(x)]
    n <- length(x)
    if (n == 0) return(list(median = NA, iqr = c(NA, NA), ci = c(NA, NA)))
    ci_val <- c(NA, NA)
    if (with_ci) {
      # Resample by index: sample(x) on a single number >= 1 would draw from
      # 1:x instead of from x itself.
      boots <- replicate(1000, median(x[sample.int(n, replace = TRUE)]))
      ci_val <- quantile(boots, c((1 - conf) / 2, 1 - (1 - conf) / 2))
    }
    list(median = median(x),
         iqr = c(quantile(x, 0.25), quantile(x, 0.75)),
         ci = ci_val)
  }

  # "count (pct% [lower, upper])" for one count out of `total`.
  prop_ci <- function(x, total, conf) {
    if (total == 0) return(paste0(x))
    bt <- binom.test(x, total, conf.level = conf)
    paste0(x, " (", round(x / total * 100, 2), "% [",
           round(bt$conf.int[1] * 100, 2), ", ",
           round(bt$conf.int[2] * 100, 2), "])")
  }

  # Format a vector of counts; the denominator is the sum of the vector, so
  # the grouping active in the calling mutate() decides the percentage basis.
  format_counts <- function(n) {
    total <- sum(n)
    if (ci) {
      vapply(n, function(xi) prop_ci(xi, total, conf), character(1))
    } else if (percent == "none" || total == 0) {
      # A zero denominator would otherwise print "0 (NaN%)"
      paste0(n)
    } else {
      paste0(n, " (", round(n / total * 100, 2), "%)")
    }
  }

  # Table cell text for one continuous variable (or one group of it).
  make_value <- function(x) {
    if (statistic == "mean_sd") {
      vals <- mean_ci(x, conf)
      out <- paste0(round(vals["mean"], 2), " (", round(vals["sd"], 2), ")")
      if (ci) {
        out <- paste0(out, " [", round(vals["ci_lower"], 2), ", ",
                      round(vals["ci_upper"], 2), "]")
      }
    } else {
      vals <- median_ci(x, conf, with_ci = ci)
      out <- paste0(round(vals$median, 2), " (", round(vals$iqr[1], 2), ", ",
                    round(vals$iqr[2], 2), ")")
      if (ci) {
        out <- paste0(out, " [", round(vals$ci[1], 2), ", ",
                      round(vals$ci[2], 2), "]")
      }
    }
    out
  }

  # "mean +/- sd" with an optional CI suffix, vectorised over its arguments.
  fmt_mean_sd <- function(m, s, lower, upper) {
    paste0(round(m, 2), " +/- ", round(s, 2),
           if (ci) paste0(" [", conf_label, ": ", round(lower, 2), ", ",
                          round(upper, 2), "]") else "")
  }

  # --- Summarize each variable ---
  summary_list <- list()
  cont_vars <- c()
  cat_vars <- c()

  for (colname in names(data)) {
    if (!is.null(by) && colname == by) next
    var <- data[[colname]]

    if (is_cat(var)) {
      # Categorical
      cat_vars <- c(cat_vars, colname)
      if (!is.null(by)) {
        tbl <- data %>%
          count(.data[[colname]], .data[[by]]) %>%
          tidyr::complete(.data[[colname]], .data[[by]], fill = list(n = 0))

        if (percent == "col") tbl <- tbl %>% group_by(.data[[by]])
        if (percent == "row") tbl <- tbl %>% group_by(.data[[colname]])

        tbl <- tbl %>%
          mutate(Value = format_counts(n),
                 Variable = colname,
                 Characteristic = as.character(.data[[colname]])) %>%
          ungroup() %>%
          # all_of() is the supported way to select by a string; `.data` is
          # deprecated inside tidyselect contexts.
          select(Variable, Characteristic, dplyr::all_of(by), Value) %>%
          tidyr::pivot_wider(names_from = dplyr::all_of(by),
                             values_from = Value, values_fill = "")
      } else {
        tbl <- data %>%
          count(.data[[colname]]) %>%
          mutate(Value = format_counts(n),
                 Variable = colname,
                 Characteristic = as.character(.data[[colname]])) %>%
          select(Variable, Characteristic, Value)
      }

    } else if (is.numeric(var)) {
      # Continuous
      cont_vars <- c(cont_vars, colname)
      if (!is.null(by)) {
        tbl <- data %>%
          group_by(.data[[by]]) %>%
          summarise(Value = make_value(.data[[colname]]), .groups = "drop") %>%
          mutate(Variable = colname,
                 Characteristic = stat_label) %>%
          pivot_wider(names_from = dplyr::all_of(by), values_from = Value)
      } else {
        tbl <- tibble(
          Variable = colname,
          Characteristic = stat_label,
          Value = make_value(var)
        )
      }

    } else {
      # Neither categorical nor numeric (e.g. logical, Date): not summarised
      next
    }

    summary_list[[colname]] <- tbl
  }

  if (length(summary_list) == 0) {
    stop("`data` has no factor, character or numeric columns to summarise.",
         call. = FALSE)
  }

  # Show each variable name only on its first row. Rows of one variable are
  # contiguous and names are unique per variable, so !duplicated() marks the
  # first row of each variable without mutating a grouping column.
  summary_df <- bind_rows(summary_list)
  summary_df$Variable <- ifelse(!duplicated(summary_df$Variable),
                                summary_df$Variable, "")

  # Footnote
  footnote_text <- paste0("* ", stat_label, "/n(%)",
                          if (ci) paste0(" with ", conf_label) else "")

  # --- Narrative ---
  if (report) {
    interpretation_list <- character(0)

    # Group labels in order of first appearance; rows with missing `by` are
    # not described.
    if (!is.null(by)) {
      by_chr <- as.character(data[[by]])
      group_labels <- unique(by_chr[!is.na(by_chr)])
    }

    # Continuous variables
    for (colname in cont_vars) {
      x <- data[[colname]]

      if (is.null(by)) {
        vals <- mean_ci(x, conf)
        text <- paste0("Mean ", colname, " was ",
                       fmt_mean_sd(vals[["mean"]], vals[["sd"]],
                                   vals[["ci_lower"]], vals[["ci_upper"]]),
                       ".")
        interpretation_list <- c(interpretation_list, text)
        next
      }

      group_stats <- lapply(group_labels, function(g) {
        # which() drops the NA comparisons produced by missing `by` values
        mean_ci(x[which(by_chr == g)], conf)
      })
      pick <- function(field) {
        vapply(group_stats, function(s) s[[field]], numeric(1))
      }
      means <- pick("mean")
      described <- paste0(group_labels, " (",
                          fmt_mean_sd(means, pick("sd"),
                                      pick("ci_lower"), pick("ci_upper")),
                          ")")

      # Only groups with a computable mean can be ranked or compared
      usable <- which(!is.na(means))
      if (length(usable) == 0) next
      ranked <- usable[order(means[usable], decreasing = TRUE)]
      k <- length(ranked)

      if (k == 1) {
        text <- paste0("Mean ", colname,
                       " was summarised for a single group: ",
                       described[ranked], ".")
      } else if (k == 2) {
        gap <- means[ranked[1]] - means[ranked[2]]
        # Relative to the magnitude of the average so that negative means are
        # not always classified as similar; guard the zero denominator.
        scale <- abs(mean(means[ranked]))
        similar <- gap == 0 || (scale > 0 && gap / scale < 0.05)
        if (similar) {
          in_order <- sort(ranked)
          text <- paste0("Mean ", colname, " was similar between ",
                         described[in_order[1]], " and ",
                         described[in_order[2]], ".")
        } else {
          text <- paste0("Mean ", colname, " was higher in ",
                         described[ranked[1]], " than in ",
                         described[ranked[2]], ".")
        }
      } else {
        intermediates <- paste(described[ranked[2:(k - 1)]], collapse = ", ")
        text <- paste0("Mean ", colname, " was highest in ",
                       described[ranked[1]], ", followed by ", intermediates,
                       ", and lowest in ", described[ranked[k]], ".")
      }
      interpretation_list <- c(interpretation_list, text)
    }

    # Categorical variables (overall frequencies, not split by `by`)
    for (colname in cat_vars) {
      freq_tbl <- data %>%
        count(.data[[colname]]) %>%
        mutate(pct = round(n / sum(n) * 100, 2)) %>%
        arrange(desc(n))

      n_categories <- nrow(freq_tbl)
      if (n_categories == 0) next

      described <- paste0(as.character(freq_tbl[[colname]]),
                          " (n=", freq_tbl$n, ", ", freq_tbl$pct, "%)")

      if (length(unique(freq_tbl$n)) == 1) {
        text <- paste0("For ", colname,
                       ", all categories were similar in frequency (n=",
                       freq_tbl$n[1], ", ", freq_tbl$pct[1], "%).")
      } else if (n_categories == 2) {
        if (abs(freq_tbl$pct[1] - freq_tbl$pct[2]) < 5) {
          text <- paste0("For ", colname,
                         ", the two categories were similar in frequency.")
        } else {
          # Rows are sorted by descending count, so row 1 is the larger one
          text <- paste0("For ", colname, ", ", described[1],
                         " was more frequent than ", described[2], ".")
        }
      } else {
        # At least three categories with unequal counts reach this branch
        intermediates <- paste(described[2:(n_categories - 1)],
                               collapse = ", ")
        text <- paste0("For ", colname, ", frequency was highest in ",
                       described[1], ", followed by ", intermediates,
                       ", and lowest in ", described[n_categories], ".")
      }
      interpretation_list <- c(interpretation_list, text)
    }

    cat(paste(interpretation_list, collapse = "\n\n"), "\n")
  }

  # --- Return flextable ---
  total_n <- nrow(data)

  flextable::flextable(summary_df) %>%
    flextable::merge_v(j = "Variable") %>%
    flextable::autofit() %>%
    flextable::bold(part = "header") %>%
    flextable::set_header_labels(
      Characteristic = "Statistics",
      Value = paste0("N = ", total_n, "*")
    ) %>%
    flextable::add_footer_lines(footnote_text)
}


