package org.ddahl.rscala

import java.net._
import java.io._
import scala.tools.nsc.interpreter.IMain
import scala.Console.{baosOut, baosErr, psOut, psErr, withOut, withErr}
import scala.annotation.tailrec

import Protocol._

/** Request loop serving an R client over the sockets created by [[ScalaSockets]].
  *
  * Each request is an integer code from [[Protocol]] read from `in`. Status codes, values and
  * captured console output are written back on `out`. Scala code received from R is run in
  * `repl`, into which the [[RClient]] (as `R`) and the interpreter itself (as `$intp`) are bound
  * when the server is constructed.
  *
  * @param repl          interpreter used to evaluate code and to look up / bind values
  * @param portsFilename passed to [[ScalaSockets]] (presumably the file used to exchange port numbers — confirm)
  * @param debugger      receives diagnostic messages whenever `debugger.debug` is true
  */
class ScalaServer private (repl: IMain, portsFilename: String, debugger: Debugger) {

  private val sockets = new ScalaSockets(portsFilename,debugger)
  import sockets.{in, out, socketIn, socketOut}
  private val R = org.ddahl.rscala.RClient(in,out,debugger)
  repl.bind("R","org.ddahl.rscala.RClient",R)
  repl.bind("$intp","scala.tools.nsc.interpreter.IMain",repl)
  // Start with empty capture buffers so nothing produced during setup is sent with the first reply.
  baosOut.reset
  baosErr.reset

  /** Result of the most recent INVOKE; fetched by GET "?" (which clears it) or GET_REFERENCE "?". */
  private var functionResult: Any = null
  /** Objects handed to R by reference; the reference name is "." followed by the buffer index. */
  private val cacheMap = new scala.collection.mutable.ArrayBuffer[Any]()
  private val cacheMapExtractor = """\.(\d+)""".r
  /** Arguments accumulated by SET with identifier "." for the next INVOKE; emptied only by CLEAR. */
  private val functionParameters = new scala.collection.mutable.ArrayBuffer[Any]()
  /** Function values created by DEF, mapped to the reflected `apply` method INVOKE calls on them. */
  private val functionMap = new scala.collection.mutable.HashMap[Any,java.lang.reflect.Method]()

  /** Records the `apply` method of function value `f` in `functionMap` (or null if none is usable). */
  private def storeFunction(f: Any): Unit = {
    // NOTE(review): with two `apply` methods (presumably a specialized one plus its bridge) the
    // first one returned by getMethods is used; the JDK does not specify that order.
    val m = f.getClass.getMethods.filter(_.getName == "apply") match {
      case Array() => null
      case Array(method) => method
      case Array(method1,method2) => method1
      case _ => null
    }
    functionMap(f) = m
  }

  /** Calls `f` through its stored `apply` method with the accumulated parameters.
    *
    * The result is stored in `functionResult`. Throws if `f` was never stored or its method is null.
    */
  private def callFunction(f: Any): Any = {
    if ( debugger.debug ) {
      debugger.msg("Function is: "+f)
      debugger.msg("There are "+functionParameters.length+" parameters.")
      debugger.msg("<<"+functionParameters.mkString(">>\n<<")+">>")
    }
    functionResult = functionMap(f).invoke(f, functionParameters.map(_.asInstanceOf[AnyRef]): _*)
  }

  private def readString() = Helper.readString(in)
  private def writeString(string: String): Unit = Helper.writeString(out,string)

  /** Renders the stack trace of `e` as a string for inclusion in the output sent back to R. */
  private def stackTraceOf(e: Throwable): String = {
    val errors = new StringWriter()
    e.printStackTrace(new PrintWriter(errors))
    errors.toString
  }

  /** Appends captured stdout then stderr to `buffer`, clears both captures, and sends `buffer` as one string. */
  private def writeAll(buffer: StringBuffer): Unit = {
    buffer.append(baosOut.toString)
    buffer.append(baosErr.toString)
    baosOut.reset
    baosErr.reset
    writeString(buffer.toString)
  }

  /** Stores value `v` of Scala type `t`.
    *
    * With identifier "." the value is appended to the pending function parameters. Otherwise it is
    * bound in the interpreter under `identifier`.
    */
  private def setAVM(identifier: String, t: String, v: Any): Unit = {
    if ( debugger.debug ) debugger.msg("Value is "+v)
    if ( identifier == "." ) functionParameters.append(v)
    else repl.bind(identifier,t,v)
  }

  private val sep = sys.props("line.separator")

  /** Interprets `snippet` in the REPL.
    *
    * NOTE(review): the value returned by `repl.interpret` is discarded, so a snippet that fails to
    * compile presumably does not throw here and is reported only through the captured output — verify.
    */
  private def eval(snippet: String): Unit = {
    repl.interpret(snippet)
    //snippet.split(sep).foreach(repl.eval)  // Breaking it line-by-line helps hotspot compiler.
  }

  /** Handles GC: reads a count, then that many cache indices, and drops each referenced cached object.
    *
    * Every index is consumed even when it is out of range, so the stream stays in sync with the
    * client. Unknown indices (e.g. ones issued before a RESET cleared the cache) are ignored instead
    * of throwing out of the request loop.
    */
  private def doGC(): Unit = {
    if ( debugger.debug ) debugger.msg("Garbage collection")
    val length = in.readInt()
    if ( debugger.debug ) debugger.msg("... of length: "+length)
    for ( _ <- 0 until length ) {
      val index = in.readInt()
      if ( index >= 0 && index < cacheMap.length ) cacheMap(index) = null
      else if ( debugger.debug ) debugger.msg("... ignoring unknown index: "+index)
    }
  }

  /** Handles EVAL: interprets a snippet with console output captured.
    *
    * Then replies OK or ERROR, followed by the captured output (plus the stack trace on error).
    * `R.exit()` is called on both paths.
    */
  private def doEval(): Unit = {
    val snippet = readString()
    val buffer = new StringBuffer()
    try {
      withOut(psOut) {
        withErr(psErr) {
          eval(snippet)
        }
      }
      R.exit()
      if ( debugger.debug ) debugger.msg("Eval is okay")
      out.writeInt(OK)
    } catch {
      case e: Throwable =>
        R.exit()
        out.writeInt(ERROR)
        buffer.append(stackTraceOf(e))
    }
    writeAll(buffer)
    out.flush()
  }

  /** Handles DEF: compiles `(signature) => { body }` as a function value and registers it for INVOKE.
    *
    * Replies with a status code after each stage: signature parsed, (unconditional) OK, snippet
    * evaluated, function stored. A failing stage writes ERROR instead. After the final OK it sends
    * the REPL name of the function, the parameter count, the parameter names, then the parameter
    * types. Captured output (and any stack trace) is always sent last.
    */
  private def doDef(): Unit = {
    val signature = readString().trim
    val body = readString()
    if ( debugger.debug ) debugger.msg("Got signature and body.")
    val buffer = new StringBuffer()
    try {
      // "a: T1, b: T2" splits on ':' into "a", " T1, b", " T2"; the text after the last comma of
      // each piece is the next parameter's name. A ':' inside a type would break this parse.
      val params = if ( signature != "" ) {
        val y = signature.split(":").map(z => {
          val l = z.lastIndexOf(",")
          if ( l >= 0 ) Array(z.substring(0,l).trim,z.substring(l+1).trim)
          else Array(z.trim)
        }).flatten
        val l = y.length/2
        val r = new Array[String](l)
        val s = new Array[String](l)
        for ( i <- 0 until l ) {
          r(i) = y(2*i)
          s(i) = y(2*i+1)
        }
        (r,s)
      } else (Array[String](),Array[String]())
      if ( debugger.debug ) debugger.msg("Parsed signature.")
      out.writeInt(OK)
      try {
        val paramNames = params._1
        val paramTypes = params._2
        out.writeInt(OK)
        try {
          withOut(psOut) {
            withErr(psErr) {
              eval(s"($signature) => { $body }")
            }
          }
          out.writeInt(OK)
          // The interpreter names the anonymous function result; fetch that name.
          val functionName = repl.mostRecentVar
          if ( debugger.debug ) debugger.msg("Name of function is: <"+functionName+">")
          try {
            storeFunction(repl.valueOfTerm(functionName).get)
            if ( debugger.debug ) debugger.msg("Stored function.")
            out.writeInt(OK)
            writeString(functionName)
            if ( debugger.debug ) debugger.msg("Everything is okay in 'def'.")
            out.writeInt(paramNames.length)
            if ( debugger.debug ) debugger.msg("There are "+paramNames.length+" parameters.")
            paramNames.foreach(writeString)
            paramTypes.foreach(writeString)
            if ( debugger.debug ) debugger.msg("Done.")
          } catch {
            case e: Throwable =>
              out.writeInt(ERROR)
              buffer.append(stackTraceOf(e))
          }
        } catch {
          case e: Throwable =>
            out.writeInt(ERROR)
            buffer.append(stackTraceOf(e))
        }
      } catch {
        case e: Throwable =>
          out.writeInt(ERROR)
          buffer.append(stackTraceOf(e))
      }
    } catch {
      case e: Throwable =>
        out.writeInt(ERROR)
        buffer.append(stackTraceOf(e))
    }
    writeAll(buffer)
    out.flush()
  }

  /** Handles INVOKE: calls the DEF-registered function named by the request with the pending parameters.
    *
    * Replies OK or ERROR, followed by the captured output. The result is kept in `functionResult`.
    */
  private def doInvoke(): Unit = {
    val functionName = readString()
    val buffer = new StringBuffer()
    try {
      val f = repl.valueOfTerm(functionName).get
      withOut(psOut) {
        withErr(psErr) {
          callFunction(f)
        }
      }
      R.exit()
      if ( debugger.debug ) debugger.msg("Invoke is okay")
      out.writeInt(OK)
    } catch {
      case e: Throwable =>
        R.exit()
        out.writeInt(ERROR)
        buffer.append(stackTraceOf(e))
    }
    writeAll(buffer)
    out.flush()
  }

  /** Handles SET: reads a value from R and binds it to an identifier or a pending function parameter.
    *
    * The identifier is "." for a function parameter, otherwise "name" or "name: Type". Supported
    * payloads are null, a reference (to a cache entry ".N" or an interpreter term), and atomic /
    * vector / matrix values of Int, Double, Boolean or String. Only the REFERENCE branch writes a
    * status code. Every completed branch ends by sending the captured output.
    */
  private def doSet(): Unit = {
    val identifier = readString()
    val buffer = new StringBuffer()
    in.readInt() match {
      case NULLTYPE =>
        if ( debugger.debug ) debugger.msg("Setting null.")
        if ( identifier == "." ) functionParameters.append(null)
        else repl.bind(identifier,"Any",null)
      case REFERENCE =>
        if ( debugger.debug ) debugger.msg("Setting reference.")
        val originalIdentifier = readString()
        try {
          if ( identifier == "." ) originalIdentifier match {
            case cacheMapExtractor(i) =>
              functionParameters.append(cacheMap(i.toInt))
            case _ =>
              functionParameters.append(repl.valueOfTerm(originalIdentifier).get)
          } else {
            // A MatchError here (more than one ':') is reported to R as ERROR by the catch below.
            val r = identifier.split(":") match {
              case Array(n)   => (n.trim,"")
              case Array(n,t) => (n.trim,t.trim)
            }
            originalIdentifier match {
              case cacheMapExtractor(i) =>
                val t = if ( r._2 == "" ) "Any" else r._2
                repl.bind(r._1,t,cacheMap(i.toInt))
              case _ =>
                val vt = if ( r._2 == "" ) r._1 else r._1 + ": " + r._2
                withOut(psOut) {
                  withErr(psErr) {
                    eval(s"val ${vt} = ${originalIdentifier}")
                  }
                }
            }
          }
          out.writeInt(OK)
        } catch {
          case e: Throwable =>
            if ( debugger.debug ) debugger.msg("Caught exception: "+e)
            out.writeInt(ERROR)
            buffer.append(stackTraceOf(e))
        }
      case ATOMIC =>
        if ( debugger.debug ) debugger.msg("Setting atomic.")
        val (v,t) = in.readInt() match {
          case INTEGER => (in.readInt(),"Int")
          case DOUBLE => (in.readDouble(),"Double")
          case BOOLEAN => (( in.readInt() != 0 ),"Boolean")
          case STRING => (readString(),"String")
          case _ => throw new RuntimeException("Protocol error")
        }
        setAVM(identifier,t,v)
      case VECTOR =>
        if ( debugger.debug ) debugger.msg("Setting vector...")
        val length = in.readInt()
        if ( debugger.debug ) debugger.msg("... of length: "+length)
        val (v,t) = in.readInt() match {
          case INTEGER => (Array.fill(length) { in.readInt() },"Array[Int]")
          case DOUBLE => (Array.fill(length) { in.readDouble() },"Array[Double]")
          case BOOLEAN => (Array.fill(length) { ( in.readInt() != 0 ) },"Array[Boolean]")
          case STRING => (Array.fill(length) { readString() },"Array[String]")
          case _ => throw new RuntimeException("Protocol error")
        }
        setAVM(identifier,t,v)
      case MATRIX =>
        if ( debugger.debug ) debugger.msg("Setting matrix...")
        val nrow = in.readInt()
        val ncol = in.readInt()
        if ( debugger.debug ) debugger.msg("... of dimensions: "+nrow+","+ncol)
        // Elements arrive row by row.
        val (v,t) = in.readInt() match {
          case INTEGER => (Array.fill(nrow) { Array.fill(ncol) { in.readInt() } },"Array[Array[Int]]")
          case DOUBLE => (Array.fill(nrow) { Array.fill(ncol) { in.readDouble() } },"Array[Array[Double]]")
          case BOOLEAN => (Array.fill(nrow) { Array.fill(ncol) { ( in.readInt() != 0 ) } },"Array[Array[Boolean]]")
          case STRING => (Array.fill(nrow) { Array.fill(ncol) { readString() } },"Array[Array[String]]")
          case _ => throw new RuntimeException("Protocol error")
        }
        setAVM(identifier,t,v)
      case _ => throw new RuntimeException("Protocol error")
    }
    writeAll(buffer)
    out.flush()
  }

  /** Handles GET: sends the value of an identifier to R by copy.
    *
    * Identifiers: "." is the most recent interpreter variable. "?" is the last INVOKE result, which
    * is cleared once read. ".N" is a cache entry. Anything else is an interpreter term.
    *
    * Replies:
    *  - UNDEFINED_IDENTIFIER if the identifier is unknown.
    *  - NULLTYPE for null or Unit.
    *  - ATOMIC / VECTOR / MATRIX followed by the data.
    *  - UNSUPPORTED_STRUCTURE for any other type or a ragged matrix.
    *  - ERROR followed by the captured output if the lookup throws.
    */
  private def doGet(): Unit = {
    val identifier = readString()
    if ( debugger.debug ) debugger.msg("Trying to get value of: "+identifier)
    val option = try {
      identifier match {
        case "." => 
          repl.valueOfTerm(repl.mostRecentVar)
        case "?" => 
          val s = Some(functionResult)
          functionResult = null
          s
        case cacheMapExtractor(i) =>
          Some(cacheMap(i.toInt))
        case _ =>
          repl.valueOfTerm(identifier)
      }
    } catch {
      case e: Throwable =>
        if ( debugger.debug ) debugger.msg("Caught exception: "+e)
        out.writeInt(ERROR)
        val buffer = new StringBuffer()
        buffer.append(stackTraceOf(e))
        writeAll(buffer)
        out.flush()
        return
    }
    if ( option.isEmpty ) {
      out.writeInt(UNDEFINED_IDENTIFIER)
      out.flush()
      return
    }
    val v = option.get
    if ( debugger.debug ) debugger.msg("Getting: "+identifier)
    if ( v == null || v.isInstanceOf[Unit] ) {
      if ( debugger.debug ) debugger.msg("... which is null")
      out.writeInt(NULLTYPE)
      out.flush()
      return
    }
    val c = v.getClass
    if ( debugger.debug ) debugger.msg("... whose class is: "+c)
    if ( debugger.debug ) debugger.msg("... and whose value is: "+v)
    if ( c.isArray ) {
      // Dispatch on the JVM array class name; matrices are written row by row.
      c.getName match {
        case "[I" =>
          val vv = v.asInstanceOf[Array[Int]]
          out.writeInt(VECTOR)
          out.writeInt(vv.length)
          out.writeInt(INTEGER)
          for ( i <- 0 until vv.length ) out.writeInt(vv(i))
        case "[D" =>
          val vv = v.asInstanceOf[Array[Double]]
          out.writeInt(VECTOR)
          out.writeInt(vv.length)
          out.writeInt(DOUBLE)
          for ( i <- 0 until vv.length ) out.writeDouble(vv(i))
        case "[Z" =>
          val vv = v.asInstanceOf[Array[Boolean]]
          out.writeInt(VECTOR)
          out.writeInt(vv.length)
          out.writeInt(BOOLEAN)
          for ( i <- 0 until vv.length ) out.writeInt(if ( vv(i) ) 1 else 0)
        case "[Ljava.lang.String;" =>
          val vv = v.asInstanceOf[Array[String]]
          out.writeInt(VECTOR)
          out.writeInt(vv.length)
          out.writeInt(STRING)
          for ( i <- 0 until vv.length ) writeString(vv(i))
        case "[[I" =>
          val vv = v.asInstanceOf[Array[Array[Int]]]
          if ( Helper.isMatrix(vv) ) {
            out.writeInt(MATRIX)
            out.writeInt(vv.length)
            if ( vv.length > 0 ) out.writeInt(vv(0).length)
            else out.writeInt(0)
            out.writeInt(INTEGER)
            for ( i <- 0 until vv.length ) {
              val vvv = vv(i)
              for ( j <- 0 until vvv.length ) {
                out.writeInt(vvv(j))
              }
            }
          } else out.writeInt(UNSUPPORTED_STRUCTURE)  // previously nothing was written, leaving the client without a reply
        case "[[D" =>
          val vv = v.asInstanceOf[Array[Array[Double]]]
          if ( Helper.isMatrix(vv) ) {
            out.writeInt(MATRIX)
            out.writeInt(vv.length)
            if ( vv.length > 0 ) out.writeInt(vv(0).length)
            else out.writeInt(0)
            out.writeInt(DOUBLE)
            for ( i <- 0 until vv.length ) {
              val vvv = vv(i)
              for ( j <- 0 until vvv.length ) {
                out.writeDouble(vvv(j))
              }
            }
          } else out.writeInt(UNSUPPORTED_STRUCTURE)
        case "[[Z" =>
          val vv = v.asInstanceOf[Array[Array[Boolean]]]
          if ( Helper.isMatrix(vv) ) {
            out.writeInt(MATRIX)
            out.writeInt(vv.length)
            if ( vv.length > 0 ) out.writeInt(vv(0).length)
            else out.writeInt(0)
            out.writeInt(BOOLEAN)
            for ( i <- 0 until vv.length ) {
              val vvv = vv(i)
              for ( j <- 0 until vvv.length ) {
                out.writeInt(if ( vvv(j) ) 1 else 0)
              }
            }
          } else out.writeInt(UNSUPPORTED_STRUCTURE)
        case "[[Ljava.lang.String;" =>
          val vv = v.asInstanceOf[Array[Array[String]]]
          if ( Helper.isMatrix(vv) ) {
            out.writeInt(MATRIX)
            out.writeInt(vv.length)
            if ( vv.length > 0 ) out.writeInt(vv(0).length)
            else out.writeInt(0)
            out.writeInt(STRING)
            for ( i <- 0 until vv.length ) {
              val vvv = vv(i)
              for ( j <- 0 until vvv.length ) {
                writeString(vvv(j))
              }
            }
          } else out.writeInt(UNSUPPORTED_STRUCTURE)
        case _ =>
          if ( debugger.debug ) debugger.msg("Bowing out: "+identifier)
          out.writeInt(UNSUPPORTED_STRUCTURE)
      }
    } else {
      c.getName match {
        case "java.lang.Integer" =>
          out.writeInt(ATOMIC)
          out.writeInt(INTEGER)
          out.writeInt(v.asInstanceOf[Int])
        case "java.lang.Double" =>
          out.writeInt(ATOMIC)
          out.writeInt(DOUBLE)
          out.writeDouble(v.asInstanceOf[Double])
        case "java.lang.Boolean" =>
          out.writeInt(ATOMIC)
          out.writeInt(BOOLEAN)
          out.writeInt(if (v.asInstanceOf[Boolean]) 1 else 0)
        case "java.lang.String" =>
          out.writeInt(ATOMIC)
          out.writeInt(STRING)
          writeString(v.asInstanceOf[String])
        case _ =>
          if ( debugger.debug ) debugger.msg("Bowing out: "+identifier)
          out.writeInt(UNSUPPORTED_STRUCTURE)
      }
    }
    out.flush()
  }

  /** Handles GET_REFERENCE: sends R a reference (name plus type name) instead of a copy.
    *
    * "?" caches the last INVOKE result and replies with its new ".N" name and class name. ".N"
    * echoes the name with the cached object's class name. Other identifiers (or "." for the most
    * recent variable) reply with the name and the interpreter's type for it. An unknown name
    * replies UNDEFINED_IDENTIFIER.
    */
  private def doGetReference(): Unit = {
    val identifier = readString()
    if ( debugger.debug ) debugger.msg("Trying to get reference for: "+identifier)
    identifier match {
      case "?" =>
        // NOTE(review): if functionResult is null, getClass throws after OK was already written,
        // and the exception escapes run().
        cacheMap.append(functionResult)
        out.writeInt(OK)
        writeString("." + (cacheMap.length-1))
        writeString(functionResult.getClass.getName)
      case cacheMapExtractor(i) =>
        // NOTE(review): an entry nulled by GC (or an index past a RESET) throws here and escapes run().
        out.writeInt(OK)
        writeString(identifier)
        val r = cacheMap(i.toInt)
        writeString(r.getClass.getName)
      case _ =>
        val id = if ( identifier == "." ) repl.mostRecentVar else identifier
        val option = try {
          repl.valueOfTerm(id)
        } catch {
          case e: Throwable =>
            if ( debugger.debug ) debugger.msg("Caught exception: "+e)
            out.writeInt(ERROR)
            val buffer = new StringBuffer()
            buffer.append(stackTraceOf(e))
            writeAll(buffer)
            out.flush()
            return
        }
        if ( option.isEmpty ) out.writeInt(UNDEFINED_IDENTIFIER)
        else {
          out.writeInt(OK)
          writeString(id)
          writeString(repl.typeOfTerm(id).toString)
        }
    }
    out.flush()
  }

  /** Serves requests until EXIT is received or reading the next request code fails.
    *
    * EXIT closes both streams and both sockets. An unrecognised request code throws a MatchError
    * out of this method, as does any exception not handled by the individual handlers.
    */
  @tailrec
  final def run(): Unit = {
    if ( debugger.debug ) debugger.msg("Top of the loop waiting for a command.")
    val request = try {
      in.readInt()
    } catch {
      case _: Throwable =>
        return
    }
    if ( debugger.debug ) debugger.msg("Received request: "+request)
    request match {
      case EXIT =>
        if ( debugger.debug ) debugger.msg("Exiting")
        in.close()
        out.close()
        socketIn.close()
        socketOut.close()
        return
      case RESET =>
        if ( debugger.debug ) debugger.msg("Resetting")
        cacheMap.clear()
      case GC =>
        doGC()
      case DEBUG =>
        val newDebug = ( in.readInt() != 0 )
        if ( debugger.debug != newDebug ) debugger.msg("Debugging is now "+newDebug)
        debugger.debug = newDebug
      case EVAL =>
        doEval()
      case SET =>
        doSet()
      case GET =>
        doGet()
      case GET_REFERENCE =>
        doGetReference()
      case DEF =>
        doDef()
      case CLEAR =>
        functionParameters.clear
      case INVOKE =>
        doInvoke()
    }
    run()
  }

}

/** Factory for [[ScalaServer]], whose primary constructor is private. */
object ScalaServer {

  /** Builds a server that evaluates requests in `repl`.
    *
    * `debug` is passed to a new [[Debugger]], presumably as its initial debug flag. Constructing
    * the server creates its sockets and binds `R` and `$intp` into `repl`.
    */
  def apply(repl: IMain, portsFilename: String, debug: Boolean = false): ScalaServer = {
    val debugger = new Debugger(debug)
    new ScalaServer(repl, portsFilename, debugger)
  }

}

