# Helper function to remove outliers in dwell-time estimations.
#
# Keeps only the elements of `x` that fall inside Tukey's fences,
# [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR], where Q1/Q3 are the default (type 7)
# quartiles from quantile(). Order of the retained values is preserved.
# An empty input gives an empty result; NA values make quantile() error.
remove_outliers <- function(x) {
  quartiles <- unname(quantile(x, probs = c(0.25, 0.75)))
  fence <- 1.5 * (quartiles[2] - quartiles[1])
  inside <- (x >= quartiles[1] - fence) & (x <= quartiles[2] + fence)
  x[inside]
}

# Cluster-based initialisation of hidden semi-Markov model (HSMM) parameters.
#
# The series is clustered with k-means into J groups. Cluster summaries give
# starting values for the observation distribution; runs of identical state
# labels (after absorbing excursions shorter than `min_consecutive`) give the
# empirical dwell times used to start the dwell-time distribution.
#
# Arguments:
#   J               number of hidden states (k-means centres).
#   x               numeric observation vector.
#   obsdist         observation family: "norm", "pois", "weibull", "zip",
#                   "nbinom", "zinb", "exp", "gamma", "lnorm", "ZInormal" or
#                   "ZIgamma".
#   dwelldist       dwell-time family: "pois", "nbinom", "betabinom" or
#                   "hyper".
#   min_consecutive consecutive observations in a new state needed to confirm
#                   a state change; shorter excursions are absorbed into the
#                   current dwell.
#   shift           FALSE for a fixed shift of 1 on every dwell distribution;
#                   otherwise a per-state shift is derived from the shortest
#                   outlier-trimmed dwell time.
#   seed            optional value for set.seed() before clustering (note:
#                   this changes the global RNG state).
#
# Returns a list with elements obspar, dwellpar, Pi (J x J transition matrix,
# zero diagonal, uniform off-diagonal), delta (uniform initial distribution)
# and durations (all extracted dwell times, grouped by state).
#
# Errors with "observation distribution not supported" or
# "dwell distribution not supported" for unknown families.
clusterHSMM <- function(J, x, obsdist, dwelldist, min_consecutive = 2,
                        shift = FALSE, seed = NULL) {
  # Set seed if provided for reproducibility
  if (!is.null(seed)) {
    set.seed(seed)
  }

  obs_init <- .clusterHSMM_obspar(J, x, obsdist)
  dwelltimes <- .clusterHSMM_dwelltimes(obs_init$states, J, min_consecutive)
  dwellpar <- .clusterHSMM_dwellpar(dwelltimes, J, dwelldist, shift)

  # Uniform transition probabilities; no self-transitions in an HSMM because
  # persistence is modelled by the dwell-time distribution.
  Pi <- matrix(1 / J, J, J)
  diag(Pi) <- 0
  Pi <- Pi / rowSums(Pi)

  list(
    obspar = obs_init$obspar,
    dwellpar = dwellpar,
    Pi = Pi,
    delta = rep(1 / J, J),
    durations = unlist(dwelltimes)
  )
}

# Observation-parameter initialisation for clusterHSMM().
# Runs k-means, derives per-cluster starting values for `obsdist`, orders the
# states by a family-specific key for identifiability (mean for norm, pois,
# nbinom, zinb, gamma; scale for weibull; rate for exp; meanlog for lnorm;
# non-zero mean for zip/ZInormal; shape for ZIgamma) and returns
# list(obspar, states), where `states` is the ordered state of every x[i].
.clusterHSMM_obspar <- function(J, x, obsdist) {
  # Initial k-means clustering on the observed data
  clustering <- kmeans(x, centers = J)
  labels <- clustering$cluster
  means <- tapply(x, labels, mean)
  vars <- tapply(x, labels, var)
  # Cluster labels of the non-zero observations (zero-inflated families only)
  nonzero_labels <- NULL

  if (obsdist == "norm") {
    sd0 <- sqrt(vars)
    idx <- order(means)
    obspar <- list(mean = means[idx], sd = sd0[idx])
  } else if (obsdist == "pois") {
    idx <- order(means)
    obspar <- list(lambda = means[idx])
  } else if (obsdist == "weibull") {
    # Empirical method-of-moments approximation via the coefficient of
    # variation
    shape0 <- numeric(J)
    scale0 <- numeric(J)
    for (j in seq_len(J)) {
      cluster_data <- x[labels == j]
      mean_j <- mean(cluster_data)
      cv <- sqrt(var(cluster_data)) / mean_j
      shape0[j] <- (0.9874 / cv)^1.0362
      scale0[j] <- mean_j / gamma(1 + 1 / shape0[j])
    }
    idx <- order(scale0)
    obspar <- list(shape = shape0[idx], scale = scale0[idx])
  } else if (obsdist == "zip") {
    # Cluster only non-zero values. x[x != 0] (unlike x[-which(x == 0)])
    # keeps every value when x contains no zeros.
    x2 <- x[x != 0]
    nonzero_labels <- kmeans(x2, centers = J, nstart = 25)$cluster
    lambda0 <- tapply(x2, nonzero_labels, mean)
    idx <- order(lambda0)
    obspar <- list(lambda = lambda0[idx], pi = rep(0.5, J))
  } else if (obsdist == "nbinom") {
    # Method of moments; abs() keeps size positive for under-dispersion
    size0 <- means^2 / abs(vars - means)
    idx <- order(means)
    obspar <- list(size = size0[idx], mu = means[idx])
  } else if (obsdist == "zinb") {
    size0 <- means^2 / abs(vars - means)
    idx <- order(means)
    obspar <- list(size = size0[idx], mu = means[idx], pi = rep(0.5, J))
  } else if (obsdist == "exp") {
    rate0 <- 1 / means
    idx <- order(rate0)
    obspar <- list(rate = rate0[idx])
  } else if (obsdist == "gamma") {
    shape0 <- means^2 / vars
    rate0 <- means / vars
    idx <- order(means)
    obspar <- list(shape = shape0[idx], rate = rate0[idx])
  } else if (obsdist == "lnorm") {
    # Re-cluster on the log scale; these labels define the state sequence
    logx <- log(x)
    labels <- kmeans(logx, J)$cluster
    meanlog0 <- tapply(logx, labels, mean)
    sdlog0 <- tapply(logx, labels, sd)
    idx <- order(meanlog0)
    obspar <- list(meanlog = meanlog0[idx], sdlog = sdlog0[idx])
  } else if (obsdist == "ZInormal") {
    x2 <- x[x != 0]
    nonzero_labels <- kmeans(x2, centers = J, nstart = 25)$cluster
    means2 <- tapply(x2, nonzero_labels, mean)
    vars2 <- tapply(x2, nonzero_labels, var)
    sd0 <- sqrt(pmax(vars2, 0.01))  # keep the variance strictly positive
    idx <- order(means2)
    obspar <- list(pi = rep(0.5, J), mean = means2[idx], sd = sd0[idx])
  } else if (obsdist == "ZIgamma") {
    x2 <- x[x != 0]
    nonzero_labels <- kmeans(x2, centers = J, nstart = 25)$cluster
    means2 <- tapply(x2, nonzero_labels, mean)
    vars2 <- tapply(x2, nonzero_labels, var)
    # Method of moments with a variance floor, then clamp to valid values
    var_floor <- pmax(vars2, means2 * 0.1)
    shape0 <- pmax((means2^2) / var_floor, 0.1)
    rate0 <- pmax(means2 / var_floor, 0.1)
    idx <- order(shape0)
    obspar <- list(pi = rep(0.5, J), shape = shape0[idx], rate = rate0[idx])
  } else {
    stop("observation distribution not supported")
  }

  # Map each raw cluster label to its position in the identifiability order
  original_to_ordered <- numeric(J)
  original_to_ordered[idx] <- seq_len(J)

  if (is.null(nonzero_labels)) {
    states <- original_to_ordered[labels]
  } else {
    # The ordering refers to the clustering of the non-zero values, so those
    # observations take its labels; each zero inherits the state of the most
    # recent non-zero observation (leading zeros take the first one's state).
    nonzero_states <- original_to_ordered[nonzero_labels]
    states <- nonzero_states[pmax(cumsum(x != 0), 1)]
  }
  list(obspar = obspar, states = states)
}

# Dwell-time extraction for clusterHSMM().
# Walks the state sequence and returns a list of length J whose j-th element
# holds the run lengths spent in state j. A change of state is confirmed only
# after `min_consecutive` consecutive observations in the new state; shorter
# excursions are added to the current dwell.
.clusterHSMM_dwelltimes <- function(states, J, min_consecutive) {
  dwelltimes <- vector("list", J)
  current_state <- states[1]
  current_duration <- 1
  pending_state <- NA
  pending_len <- 0L  # length of the unconfirmed candidate run

  for (i in seq_along(states)[-1]) {
    s <- states[i]
    if (s == current_state) {
      # Back in the current state: the candidate run was noise, absorb it
      current_duration <- current_duration + pending_len + 1
      pending_len <- 0L
      next
    }
    if (pending_len > 0L && s == pending_state) {
      pending_len <- pending_len + 1L
    } else {
      # A different candidate: absorb the old run and start a new one
      current_duration <- current_duration + pending_len
      pending_state <- s
      pending_len <- 1L
    }
    # Checked after every extension (including a run's first observation)
    # so that min_consecutive = 1 confirms a change immediately.
    if (pending_len >= min_consecutive) {
      dwelltimes[[current_state]] <-
        c(dwelltimes[[current_state]], current_duration)
      current_state <- pending_state
      current_duration <- pending_len
      pending_len <- 0L
    }
  }

  # Close the final dwell, absorbing any unconfirmed candidate run
  current_duration <- current_duration + pending_len
  dwelltimes[[current_state]] <-
    c(dwelltimes[[current_state]], current_duration)
  dwelltimes
}

# Dwell-time parameter initialisation for clusterHSMM().
# For each state with dwell times, trims outliers with remove_outliers(),
# picks a shift (1, or adaptive when `shift` is not FALSE) and fits the
# shifted durations by method of moments. States without usable durations
# keep zero parameters (and shift 0 in adaptive mode).
.clusterHSMM_dwellpar <- function(dwelltimes, J, dwelldist, shift) {
  if (!(dwelldist %in% c("pois", "nbinom", "betabinom", "hyper"))) {
    stop("dwell distribution not supported")
  }
  zeros <- numeric(J)
  dwellpar <- switch(dwelldist,
    pois = list(lambda = zeros, shift = zeros),
    nbinom = list(size = zeros, mu = zeros, shift = zeros),
    betabinom = list(size = zeros, alpha = zeros, beta = zeros,
                     shift = zeros),
    hyper = list(m = zeros, n = zeros, k = zeros, shift = zeros)
  )

  adaptive_shift <- shift != FALSE
  if (!adaptive_shift) {
    dwellpar$shift <- rep(1, J)
  }

  for (j in seq_len(J)) {
    if (length(dwelltimes[[j]]) == 0) next
    durations <- remove_outliers(dwelltimes[[j]])
    if (length(durations) == 0) next

    if (adaptive_shift) {
      dwellpar$shift[j] <- if (dwelldist == "pois") {
        # Offset of the trimmed minimum from a very low Poisson quantile
        max(1, abs(min(durations) - qpois(1e-08, mean(durations))))
      } else {
        max(1, min(durations) - 1)
      }
    }
    y <- durations - dwellpar$shift[j]

    if (dwelldist == "pois") {
      dwellpar$lambda[j] <- mean(y)
    } else if (dwelldist == "nbinom") {
      mu <- mean(y)
      dwellpar$mu[j] <- mu
      # abs() keeps size positive for under-dispersed durations, matching
      # the observation-side nbinom initialiser
      dwellpar$size[j] <- mu^2 / abs(var(y) - mu)
    } else if (dwelldist == "betabinom") {
      size <- max(y) + 5  # number of trials, with some headroom
      y_bar <- mean(y) / size
      s2 <- var(y) / (size^2)
      dwellpar$size[j] <- size
      # isTRUE(): var() is NA for a single duration
      if (isTRUE(s2 < y_bar * (1 - y_bar))) {
        rho <- (s2 - y_bar * (1 - y_bar) / size) /
          (y_bar * (1 - y_bar) * (1 - 1 / size))
        dwellpar$alpha[j] <- y_bar * ((1 - rho) / rho)
        dwellpar$beta[j] <- (1 - y_bar) * ((1 - rho) / rho)
      }
    } else {
      # hyper: k draws, m success and n failure states
      mean_y <- mean(y)
      k <- ceiling(max(y))
      N <- ceiling(mean_y * (mean_y * (k - mean_y)) / var(y))
      # Ensure N > k. Zero or undefined variance (identical or single
      # durations) makes N NA/NaN, so fall back to the smallest valid N.
      N <- if (is.finite(N)) max(N, k + 1) else k + 1
      m <- round((N * mean_y) / k)
      m <- if (is.finite(m)) max(k, m) else k  # ensure m >= k
      dwellpar$k[j] <- k
      dwellpar$m[j] <- m
      dwellpar$n[j] <- max(1, N - m)  # ensure n >= 1
    }
  }
  dwellpar
}


