#' @title Data Preparation Functions
#' @name data_preparation
#' @description Functions for preparing and transforming price data for analysis.
NULL

#' Prepare Panel Data from Wide Format Matrices
#'
#' Converts wide-format price matrices (one 'Year' column plus one column per
#' sector) into long-format panel data suitable for panel regression analysis.
#'
#' Sector columns are all columns other than 'Year', so 'Year' does not need
#' to be the first column. Rows of \code{production_prices} are aligned to the
#' rows of \code{direct_prices} through their 'Year' values, not by position.
#'
#' @param direct_prices Data frame with direct prices (labor value-based).
#'   Must contain a 'Year' column; every other column is a sector.
#' @param production_prices Data frame with prices of production.
#'   Must contain a 'Year' column that covers every year found in
#'   \code{direct_prices}, and a column for every sector found in
#'   \code{direct_prices}. Additional years or sectors are ignored.
#' @param log_transform Logical. Apply natural log transformation. Default TRUE.
#'
#' @return A data frame in panel (long) format, sorted by sector and then
#'   year, with columns:
#' \describe{
#'   \item{year}{Year of observation}
#'   \item{sector}{Sector identifier}
#'   \item{direct_price}{Direct price value}
#'   \item{production_price}{Price of production value}
#'   \item{log_direct}{Log of direct price (if log_transform = TRUE)}
#'   \item{log_production}{Log of production price (if log_transform = TRUE)}
#'   \item{sector_id}{Numeric sector identifier (factor code of the sector
#'     name)}
#'   \item{time}{Time index (year minus minimum year plus 1)}
#' }
#'
#' @examples
#' set.seed(123)
#' years <- 2000:2010
#' sectors <- c("Agriculture", "Manufacturing", "Services")
#'
#' direct <- data.frame(
#'   Year = years,
#'   Agriculture = 100 + cumsum(rnorm(11)),
#'   Manufacturing = 120 + cumsum(rnorm(11)),
#'   Services = 90 + cumsum(rnorm(11))
#' )
#'
#' production <- data.frame(
#'   Year = years,
#'   Agriculture = 102 + cumsum(rnorm(11)),
#'   Manufacturing = 118 + cumsum(rnorm(11)),
#'   Services = 92 + cumsum(rnorm(11))
#' )
#'
#' panel <- prepare_panel_data(direct, production)
#' head(panel)
#'
#' @export
prepare_panel_data <- function(direct_prices, production_prices,
                                log_transform = TRUE) {

    if (!is.data.frame(direct_prices) || !is.data.frame(production_prices)) {
        stop("Both inputs must be data frames.")
    }

    if (!("Year" %in% names(direct_prices))) {
        stop("direct_prices must have a 'Year' column.")
    }

    if (!("Year" %in% names(production_prices))) {
        stop("production_prices must have a 'Year' column.")
    }

    # `[[` avoids the partial matching that `$` performs on data frames.
    years <- direct_prices[["Year"]]

    # Select sectors by name instead of dropping the first column, so a
    # 'Year' column in any position is handled correctly.
    sectors <- setdiff(names(direct_prices), "Year")
    sectors_prod <- setdiff(names(production_prices), "Year")

    if (length(sectors) == 0L) {
        stop("direct_prices must have at least one sector column.")
    }
    if (length(years) == 0L) {
        stop("direct_prices has no observations.")
    }

    missing <- setdiff(sectors, sectors_prod)
    if (length(missing) > 0L) {
        stop(sprintf(
            "Sectors in direct_prices not found in production_prices: %s",
            paste(missing, collapse = ", ")
        ))
    }

    # Align production rows to direct rows by year. When both Year columns
    # are identical the positional mapping is used as-is; otherwise each
    # direct year is looked up in production_prices.
    prod_years <- production_prices[["Year"]]
    if (identical(as.vector(years), as.vector(prod_years))) {
        prod_rows <- seq_along(years)
    } else {
        prod_rows <- match(years, prod_years)
        if (anyNA(prod_rows)) {
            stop(sprintf(
                "Years in direct_prices not found in production_prices: %s",
                paste(unique(years[is.na(prod_rows)]), collapse = ", ")
            ))
        }
    }

    panel_data <- expand.grid(
        year = years,
        sector = sectors,
        stringsAsFactors = FALSE
    )

    # Stack one sector after another; this matches the row order produced by
    # expand.grid(), where `year` varies fastest within each sector.
    stack_sectors <- function(prices, rows) {
        unlist(
            lapply(sectors, function(s) as.numeric(prices[[s]])[rows]),
            use.names = FALSE
        )
    }

    panel_data$direct_price <- stack_sectors(direct_prices, seq_along(years))
    panel_data$production_price <- stack_sectors(production_prices, prod_rows)

    if (log_transform) {

        if (any(panel_data$direct_price <= 0, na.rm = TRUE)) {
            warning("Non-positive values in direct_price; log will produce NaN.")
        }
        if (any(panel_data$production_price <= 0, na.rm = TRUE)) {
            warning("Non-positive values in production_price; log will produce NaN.")
        }

        panel_data$log_direct <- log(panel_data$direct_price)
        panel_data$log_production <- log(panel_data$production_price)
    }

    panel_data$sector_id <- as.numeric(factor(panel_data$sector))
    panel_data$time <- panel_data$year - min(panel_data$year) + 1L

    panel_data <- panel_data[order(panel_data$sector, panel_data$year), ]
    rownames(panel_data) <- NULL

    panel_data
}


#' Validate Panel Data Structure
#'
#' Confirms that an object is a non-empty data frame carrying the columns
#' expected of panel data. Missing structure is reported as an error; a panel
#' with very few distinct sectors or years only triggers a warning.
#'
#' @param panel_data Data frame to validate.
#' @param require_log Logical. When TRUE (the default) the columns
#'   'log_direct' and 'log_production' are required in addition to 'year'
#'   and 'sector'.
#'
#' @return TRUE invisibly if valid, otherwise stops with informative error.
#'   Warns when there are fewer than 2 distinct sectors or fewer than 3
#'   distinct years.
#'
#' @examples
#' set.seed(123)
#' panel <- data.frame(
#'   year = rep(2000:2002, 3),
#'   sector = rep(c("A", "B", "C"), each = 3),
#'   log_direct = rnorm(9),
#'   log_production = rnorm(9)
#' )
#' validate_panel_data(panel)
#'
#' @export
validate_panel_data <- function(panel_data, require_log = TRUE) {

    if (!is.data.frame(panel_data)) {
        stop("panel_data must be a data frame.")
    }

    # Identifier columns are always required; log columns only on request.
    log_cols <- if (require_log) {
        c("log_direct", "log_production")
    } else {
        character(0)
    }
    needed <- c("year", "sector", log_cols)

    absent <- needed[!(needed %in% names(panel_data))]
    if (length(absent) > 0L) {
        stop(sprintf(
            "Missing required columns: %s",
            paste(absent, collapse = ", ")
        ))
    }

    if (nrow(panel_data) == 0L) {
        stop("panel_data has no observations.")
    }

    # Number of distinct values held in one column of the panel.
    count_distinct <- function(column) {
        length(unique(panel_data[[column]]))
    }
    sector_count <- count_distinct("sector")
    year_count <- count_distinct("year")

    # Small panels are still accepted, but the caller is warned.
    if (sector_count < 2L) {
        warning("Panel has fewer than 2 sectors; some methods may not work.")
    }

    if (year_count < 3L) {
        warning("Panel has fewer than 3 time periods; some methods may not work.")
    }

    invisible(TRUE)
}


#' Create Mundlak-Transformed Panel Data
#'
#' Adds sector-level means and within-deviations for Mundlak/CRE estimation.
#'
#' Sector means are computed with missing values of \code{x_var} removed.
#' Rows whose \code{sector} is \code{NA} cannot be assigned to a sector and
#' receive \code{NA} in both new columns.
#'
#' @param panel_data Data frame with panel data. Must contain a 'sector'
#'   column and the column named by \code{x_var}.
#' @param x_var Character string. Name of the explanatory variable.
#'   Default "log_direct".
#'
#' @return Panel data with additional columns:
#' \describe{
#'   \item{x_mean_sector}{Sector-level mean of x_var}
#'   \item{x_within}{Within-sector deviation (x - sector mean)}
#' }
#'
#' @examples
#' set.seed(123)
#' panel <- data.frame(
#'   year = rep(2000:2002, 3),
#'   sector = rep(c("A", "B", "C"), each = 3),
#'   log_direct = rnorm(9, mean = 5),
#'   log_production = rnorm(9, mean = 5)
#' )
#' panel_mundlak <- create_mundlak_data(panel)
#' head(panel_mundlak)
#'
#' @export
create_mundlak_data <- function(panel_data, x_var = "log_direct") {

    if (!(x_var %in% names(panel_data))) {
        stop(sprintf("Variable '%s' not found in panel_data.", x_var))
    }

    if (!("sector" %in% names(panel_data))) {
        stop("panel_data must have a 'sector' column.")
    }

    x_vals <- panel_data[[x_var]]
    sector_key <- panel_data[["sector"]]

    # Map every row to the index of its sector. Rows with a missing sector
    # map to NA, which keeps NA out of any subscripted assignment (an NA in a
    # logical subscript makes `x[idx] <- values` fail).
    known_sectors <- unique(sector_key[!is.na(sector_key)])
    group <- match(sector_key, known_sectors)

    # One mean per sector, in the order of `known_sectors`; split() drops
    # the rows whose group is NA.
    sector_means <- vapply(
        split(x_vals, factor(group, levels = seq_along(known_sectors))),
        function(v) mean(v, na.rm = TRUE),
        numeric(1L)
    )

    # Broadcast the sector means back onto the rows (NA group -> NA mean).
    row_means <- unname(sector_means[group])

    panel_data$x_mean_sector <- row_means
    panel_data$x_within <- x_vals - row_means

    panel_data
}


#' Create Time-Series Aggregated Data
#'
#' Aggregates panel data to time series by computing means across sectors.
#' Missing values are dropped before each mean is taken.
#'
#' @param panel_data Data frame with panel data. Must contain a 'year'
#'   column and every column named in \code{vars}.
#' @param vars Character vector of variables to aggregate.
#'   Default c("log_direct", "log_production").
#'
#' @return Data frame with one row per distinct year (sorted ascending)
#'   containing a 'year' column and, for each entry of \code{vars}, a numeric
#'   column named \code{<var>_mean}.
#'
#' @examples
#' set.seed(123)
#' panel <- data.frame(
#'   year = rep(2000:2005, 3),
#'   sector = rep(c("A", "B", "C"), each = 6),
#'   log_direct = rnorm(18, mean = 5),
#'   log_production = rnorm(18, mean = 5)
#' )
#' ts_agg <- aggregate_to_timeseries(panel)
#' head(ts_agg)
#'
#' @export
aggregate_to_timeseries <- function(panel_data,
                                     vars = c("log_direct", "log_production")) {

    if (!("year" %in% names(panel_data))) {
        stop("panel_data must have a 'year' column.")
    }

    missing_vars <- setdiff(vars, names(panel_data))
    if (length(missing_vars) > 0L) {
        stop(sprintf(
            "Variables not found in panel_data: %s",
            paste(missing_vars, collapse = ", ")
        ))
    }

    year_col <- panel_data[["year"]]
    years <- sort(unique(year_col))

    result <- data.frame(year = years)

    for (v in vars) {
        values <- panel_data[[v]]
        # vapply() always yields a numeric vector, including for an empty
        # panel, where sapply() would return an empty list instead.
        result[[paste0(v, "_mean")]] <- vapply(
            years,
            function(y) mean(values[year_col == y], na.rm = TRUE),
            numeric(1L)
        )
    }

    result
}


#' Prepare Log-Transformed Matrices
#'
#' Extracts the columns shared by both price data frames (minus any excluded
#' columns), coerces them to numeric matrices and applies the natural log.
#' Rows are paired by position, so both inputs must have the same number of
#' rows.
#'
#' @param direct_prices Data frame with direct prices.
#' @param production_prices Data frame with prices of production.
#' @param exclude_cols Character vector of columns to exclude.
#'   Default c("Year").
#'
#' @return A list containing:
#' \describe{
#'   \item{X_log}{Log-transformed matrix of direct prices}
#'   \item{Y_log}{Log-transformed matrix of production prices}
#'   \item{complete_cases}{Logical vector, TRUE for rows in which every value
#'     of both X_log and Y_log is finite}
#'   \item{X_clean}{Subset of X_log with complete cases}
#'   \item{Y_clean}{Subset of Y_log with complete cases}
#'   \item{column_names}{Character vector of the columns used, in the order
#'     they appear in direct_prices}
#' }
#'
#' @examples
#' set.seed(123)
#' direct <- data.frame(
#'   Year = 2000:2005,
#'   A = runif(6, 100, 200),
#'   B = runif(6, 100, 200)
#' )
#' production <- data.frame(
#'   Year = 2000:2005,
#'   A = runif(6, 100, 200),
#'   B = runif(6, 100, 200)
#' )
#' matrices <- prepare_log_matrices(direct, production)
#' str(matrices)
#'
#' @export
prepare_log_matrices <- function(direct_prices, production_prices,
                                  exclude_cols = c("Year")) {

    if (!is.data.frame(direct_prices) || !is.data.frame(production_prices)) {
        stop("Both inputs must be data frames.")
    }

    cols_to_use <- setdiff(names(direct_prices), exclude_cols)
    cols_to_use <- intersect(cols_to_use, names(production_prices))

    if (length(cols_to_use) == 0L) {
        stop("No common numeric columns found between the two data frames.")
    }

    # Rows are paired positionally below, so a length mismatch is reported
    # here rather than surfacing as an opaque cbind() failure.
    if (nrow(direct_prices) != nrow(production_prices)) {
        stop(sprintf(
            paste0(
                "direct_prices and production_prices must have the same ",
                "number of rows (%d vs %d)."
            ),
            nrow(direct_prices), nrow(production_prices)
        ))
    }

    # Select the shared columns, force numeric storage, then take logs.
    log_matrix <- function(prices) {
        m <- as.matrix(prices[, cols_to_use, drop = FALSE])
        mode(m) <- "numeric"
        log(m)
    }

    X_log <- log_matrix(direct_prices)
    Y_log <- log_matrix(production_prices)

    # is.finite() is FALSE for NA, NaN and +/-Inf, so one pass is enough to
    # flag rows that are complete and finite in both matrices.
    complete_cases <- rowSums(!is.finite(cbind(X_log, Y_log))) == 0L

    list(
        X_log = X_log,
        Y_log = Y_log,
        complete_cases = complete_cases,
        X_clean = X_log[complete_cases, , drop = FALSE],
        Y_clean = Y_log[complete_cases, , drop = FALSE],
        column_names = cols_to_use
    )
}
