#' Generate the GRN panel of the shiny app
#' @description These are the UI and server components of the GRN (gene
#' regulatory network) panel of the shiny app. The panel lets the user infer
#' up to four networks (using GENIE3) for a chosen set of target genes, each on
#' a different subset of samples, and compare them. It is generated by
#' including 'GRN' in the panels.default argument of
#' \code{\link{generateShinyApp}}.
#' @inheritParams DEpanel
#' @return The UI and server components of the shiny module, which can be used
#' within the UI and server definitions of a shiny app.
#' @name GRNpanel
NULL

#' @rdname GRNpanel
#' @export
GRNpanelUI <- function(id, metadata, show = TRUE){
  # Builds the 'GRN inference' tab: a sidebar with inference/plot settings and
  # a main area with up to four network plots plus an upset plot comparing
  # them. Returns NULL when the panel is switched off via `show`.
  if (!show) {
    return(NULL)
  }
  ns <- NS(id)
  
  # The sample selectors are pre-filled with the groups of the last metadata
  # column, which is also the default of the 'condition' selector below.
  groups <- unique(metadata[[ncol(metadata)]])
  
  # One sample selector per network. Selectors 2-4 are wrapped in a
  # conditionalPanel so they only show when enough networks are requested.
  # The wrapper gets its own "_panel" id: reusing the selectInput's id would
  # put two elements with the same id in the page.
  sample_selectors <- lapply(seq_len(4), function(i) {
    selector <- selectInput(
      ns(paste0("samples", i)), paste0("Samples for GRN #", i, ":"), groups,
      selected = groups, multiple = TRUE
    )
    if (i == 1) {
      selector
    } else {
      conditionalPanel(
        id = ns(paste0("samples", i, "_panel")),
        ns = ns,
        condition = paste0("input.n_networks >= ", i),
        selector
      )
    }
  })
  
  tabPanel(
    'GRN inference',
    sidebarLayout(
      
      sidebarPanel(
        selectInput(ns('n_networks'), 'Number of networks:', 1:4),
        
        selectInput(ns('condition'), 'Metadata column to use:', colnames(metadata)[-1], 
                    selected = colnames(metadata)[ncol(metadata)]),
        sample_selectors,
        
        selectInput(ns("targets"), "Target genes:", multiple = TRUE, choices = character(0)),
        # Enabled by the server once targets and samples are selected.
        shinyjs::disabled(actionButton(ns('goGRN'), label = 'Start GRN inference')),
        
        numericInput(ns("plotConnections"), "Connections to plot:", 5, 0, 100),
        textInput(ns('plotFileName'), 'File name for plot download', value = 'GRNplot.html'),
        selectInput(ns('plotId'), 'Select plot to download:', 1:4),
        downloadButton(ns('download'), 'Download Plot')
      ),
      
      mainPanel(
        # Row 1: networks 1 and 2 (the second only when >= 2 are requested).
        fluidRow(
          column(6, visNetwork::visNetworkOutput(ns('plot1'))),
          column(
            6, 
            conditionalPanel(
              id = ns('plot2col'),
              ns = ns,
              condition = "input.n_networks >= 2",
              visNetwork::visNetworkOutput(ns('plot2'))
            )
          )
        ),
        # Row 2: networks 3 and 4, shown only when >= 3 are requested. This is
        # a sibling of row 1 rather than being nested inside it.
        conditionalPanel(
          id = ns('plotrow'),
          ns = ns,
          condition = "input.n_networks >= 3",
          fluidRow(
            column(6, visNetwork::visNetworkOutput(ns('plot3'))),
            column(
              6, 
              conditionalPanel(
                id = ns('plot4col'),
                ns = ns,
                condition = "input.n_networks >= 4",
                visNetwork::visNetworkOutput(ns('plot4'))
              )
            )
          )
        ),
        # The upset plot compares networks, so it needs at least two.
        conditionalPanel(
          id = ns('includeUpset'),
          ns = ns,
          condition = "input.n_networks > 1",
          plotOutput(ns('plotUpset'))
        )
      )
    )
  )
}

#' @rdname GRNpanel
#' @export
GRNpanelServer <- function(id, expression.matrix, metadata, anno){
  
  # Each requirement is a separate argument: wrapping them in one `{}` block
  # would make stopifnot() check only the last expression.
  stopifnot(
    is.reactive(expression.matrix),
    is.reactive(metadata),
    !is.reactive(anno)
  )
  
  moduleServer(id, function(input, output, session){
    
    # Must match the number of network slots offered by GRNpanelUI.
    max_networks <- 4
    
    # Remove genes of constant expression; they carry no signal for inference.
    # which() drops rows whose min/max is NA, and drop = FALSE keeps a matrix
    # even when only one gene survives.
    expression.matrix.sub <- reactive({
      mat <- expression.matrix()
      keep <- which(matrixStats::rowMins(mat) != matrixStats::rowMaxs(mat))
      mat[keep, , drop = FALSE]
    })
    observe({
      updateSelectizeInput(
        session, "targets", server = TRUE,
        choices = anno$NAME[anno$ENSEMBL %in% rownames(expression.matrix.sub())]
      )
    })
    
    # Keep the sample selectors in sync with the chosen metadata column. The UI
    # only pre-fills them with the groups of the last column.
    observe({
      req(input[["condition"]])
      groups <- unique(metadata()[[input[["condition"]]]])
      for (i in seq_len(max_networks)) {
        updateSelectInput(session, paste0("samples", i),
                          choices = groups, selected = groups)
      }
    })
    
    # The start button is enabled only when at least one target gene is
    # chosen and every requested network has at least one sample group.
    observe({
      req(input[["n_networks"]])
      n <- as.integer(input[["n_networks"]])
      has_samples <- vapply(
        seq_len(n),
        function(i) length(input[[paste0("samples", i)]]) > 0,
        logical(1)
      )
      if (length(input[["targets"]]) > 0 && all(has_samples)) {
        shinyjs::enable("goGRN")
      } else {
        shinyjs::disable("goGRN")
      }
    }) %>%
      bindEvent(input[["targets"]], input[["n_networks"]], input[["samples1"]],
                input[["samples2"]], input[["samples3"]], input[["samples4"]])
    
    # Number of networks captured when the button was pressed, as an integer
    # (the raw input is a string, which would otherwise compare lexicographically).
    n_networks <- reactive(as.integer(input[["n_networks"]])) %>%
      bindEvent(input[["goGRN"]])
    observe(updateSelectInput(session, "plotId", choices = seq_len(n_networks())))
    
    # Infer network i with GENIE3 from the settings captured at button press.
    # Networks beyond the requested number are NULL. The button is disabled
    # only while inference runs, and on.exit re-enables it even if
    # infer_GRN() fails.
    infer_network <- function(i) {
      if (n_networks() < i) {
        return(NULL)
      }
      shinyjs::disable("goGRN")
      on.exit(shinyjs::enable("goGRN"), add = TRUE)
      infer_GRN(
        expression.matrix = expression.matrix.sub(), 
        metadata = metadata(), 
        anno = anno, 
        targets = input[["targets"]], 
        condition = input[["condition"]], 
        samples = input[[paste0("samples", i)]], 
        inference_method = "GENIE3"
      )
    }
    
    # One cached inference result per network slot, recomputed on button press.
    GRNresults <- lapply(seq_len(max_networks), function(i) {
      reactive(infer_network(i)) %>% bindEvent(input[["goGRN"]])
    })
    
    # Results of the requested networks only, in slot order.
    weightMatList <- reactive({
      lapply(GRNresults[seq_len(n_networks())], function(result) result())
    })
    
    recurring_regulators <- reactive({
      find_regulators_with_recurring_edges(weightMatList(), input[["plotConnections"]])
    })
    
    GRNplots <- lapply(seq_len(max_networks), function(i) {
      reactive(plot_GRN(
        weightMat = GRNresults[[i]](), 
        anno = anno, 
        plotConnections = input[["plotConnections"]], 
        plot_position_grid = i, 
        n_networks = n_networks(),
        recurring_regulators = recurring_regulators()
      ))
    })
    
    upsetPlot <- reactive(plot_upset(weightMatList(), input[["plotConnections"]])) 
    
    invisible(lapply(seq_len(max_networks), function(i) {
      output[[paste0("plot", i)]] <- visNetwork::renderVisNetwork(GRNplots[[i]]())
    }))
    output[['plotUpset']] <- renderPlot(upsetPlot())
    
    # Save the network selected in 'plotId' as a standalone HTML widget.
    output[['download']] <- downloadHandler(
      filename = function() {input[['plotFileName']]},
      content = function(file) {
        GRNplot <- GRNplots[[as.integer(input[["plotId"]])]]()
        visNetwork::visSave(GRNplot, file)
      }
    )
    
  })
}