import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.python.tools import freeze_graph
import time
import re
import os
import tempfile
import shutil
import codecs
import sys


from utils import XESAttributeTypes
from TFGraphBuilder import TFGraphBuilder
from TrainingProgressBox import TrainingProgressBox


class TFGraphTrainer:
    def __init__(self, tkRoot, graphConfig, XESData):
        """Remember the collaborators and reset all per-run training state.

        tkRoot      -- stored unchanged as ``self.tkRoot``
        graphConfig -- configuration object, stored as ``self.graphConfig``
        XESData     -- data source, stored as ``self.data``
        """
        # Objects handed in by the caller
        self.tkRoot, self.graphConfig, self.data = tkRoot, graphConfig, XESData

        # Objects that only exist while training is running
        self.tfGraph = self.session = None
        self.summaryWriter = self.progressBox = None

        # Progress counters
        self.foldNum = self.epochNum = self.currentstep = 0

    def runEpoch(self, train, priorsteps):
        """Run one epoch over the training or validation data of the current fold.

        Batches are requested from ``self.data.getFeedDict`` until it raises
        StopIteration.  The RNN state returned for one batch is fed in as the
        initial state of the next batch.

        train      -- True: also run the training ops and log the merged
                      summaries; False: evaluate only
        priorsteps -- offset added to this epoch's step counter when logging
                      summaries

        Returns ``(precision, loss, steps)``: two dicts keyed by target name
        holding the per-batch mean of meanCorrectPredictions (times 100) and
        of the loss, plus the number of steps processed in this epoch.  When
        the data source yields no batch at all, both means are 0.0.
        """
        graph = self.tfGraph
        config = self.graphConfig
        shared = config.sharedRNN
        currentstep = 0

        if train:
            message = "Training   fold {0:2d}, epoch {1:3d}".format(self.foldNum, self.epochNum)
        else:
            message = "Validation fold {0:2d}, epoch {1:3d}".format(self.foldNum, self.epochNum)
        self.progressBox.updateMessage(message)

        # Running sums of the per-batch values, keyed by target name
        precisionKeys = sorted(graph.meanCorrectPredictions)
        lossKeys = sorted(graph.loss)
        sumPrecision = dict.fromkeys(lossKeys, 0.0)
        sumLoss = dict.fromkeys(lossKeys, 0.0)

        # Set the learning rate for this epoch.  Variable.load() writes the
        # value without adding ops to the graph; calling tf.assign() here
        # would add a new assign op on every call.
        # NOTE(review): assumes graph.learningRate is a tf.Variable -- confirm
        # against TFGraphBuilder.
        epochLearningRate = config.baseLearningRate * config.lrDecay ** max(self.epochNum - config.numEpochsFullLR, 0.0)
        graph.learningRate.load(epochLearningRate, self.session)

        # Fetch the initial state(s).  In shared mode the state is fetched
        # bare (not wrapped in a list) so that it has the same structure as
        # the state fed back in after every batch.
        stateKeys = []
        if shared:
            stateValue = self.session.run(fetches=graph.initialState)
        else:
            initKeys = sorted(graph.initialState)
            stateValue = dict(zip(initKeys, self.session.run(fetches=[graph.initialState[key] for key in initKeys])))
            stateKeys = sorted(graph.states)

        # Build the fetch list once; its layout is
        # [summaries, precisions..., losses..., (trainOps...), ..., state(s)]
        # NOTE(review): correct_predictions, oneHotTargets and softMax are
        # fetched but their values are not used here; they are kept so the set
        # of evaluated tensors is unchanged.
        fetches = [graph.mergedSummaries]
        fetches += [graph.meanCorrectPredictions[key] for key in precisionKeys]
        fetches += [graph.loss[key] for key in lossKeys]
        if train:
            fetches += [graph.trainOp[key] for key in sorted(graph.trainOp)]
        fetches += [graph.correct_predictions[key] for key in sorted(graph.correct_predictions)]
        fetches += [graph.oneHotTargets[key] for key in sorted(graph.oneHotTargets)]
        fetches += [graph.softMax[key] for key in sorted(graph.softMax)]
        if shared:
            fetches.append(graph.states)
            stateStart = len(fetches) - 1
        else:
            fetches += [graph.states[key] for key in stateKeys]
            stateStart = len(fetches) - len(stateKeys)
        precisionStart = 1
        lossStart = precisionStart + len(precisionKeys)

        startTime = time.time()
        while True:
            # Only getFeedDict signals "no more data", so only it is guarded
            try:
                feeddict = self.data.getFeedDict(batchSize=config.batchSize,
                                                 numUnrollSteps=config.numUnrollSteps,
                                                 startstep=currentstep,
                                                 train=train,
                                                 tfGraph=graph)
            except StopIteration:
                break

            # add the current state(s) into the feeddict
            if shared:
                feeddict[graph.initialState] = stateValue
            else:
                for (attrname, stateTensor) in graph.initialState.items():
                    feeddict[stateTensor] = stateValue[attrname]

            runResults = self.session.run(fetches=fetches, feed_dict=feeddict)

            if train:
                self.summaryWriter.add_summary(runResults[0], priorsteps + currentstep)
            for offset, targName in enumerate(precisionKeys, precisionStart):
                sumPrecision[targName] += runResults[offset]
            for offset, targName in enumerate(lossKeys, lossStart):
                sumLoss[targName] += runResults[offset]

            # carry the final state(s) over into the next batch
            if shared:
                stateValue = runResults[stateStart]
            else:
                for offset, targName in enumerate(stateKeys, stateStart):
                    stateValue[targName] = runResults[offset]

            currentstep += config.numUnrollSteps
            self.progressBox.makeTick()

        # No more data: turn the sums into per-batch means
        elapsed = time.time() - startTime
        stepsPerSecond = currentstep * config.batchSize / elapsed if elapsed > 0 else 0.0

        numBatches = float(currentstep) / config.numUnrollSteps
        # With zero batches all sums are 0.0; dividing by 1 keeps them at 0.0
        # instead of raising ZeroDivisionError
        divisor = numBatches if numBatches > 0 else 1.0
        meanPrecision = dict((name, 100 * val / divisor) for (name, val) in sumPrecision.items())
        meanLoss = dict((name, val / divisor) for (name, val) in sumLoss.items())

        self.progressBox.update(message,
                                stepsPerSecond,
                                epochLearningRate,
                                meanPrecision,
                                meanLoss,
                                )
        return meanPrecision, meanLoss, currentstep

    def initResultsFile(self):

        resultFile = open(os.path.join(self.graphConfig.logDirPath, os.path.basename(self.graphConfig.xesPath)+"TrainingValidationResultFile.csv"), 'w')

        resultFile.write("XESFile, %s\n" % self.graphConfig.xesPath)
        resultFile.write("numFolds, %0.d\n" % self.graphConfig.numFolds)
        resultFile.write("Batchsize, %0.d\n" % self.graphConfig.batchSize)
        resultFile.write("numUnrollSteps, %0.d\n" % self.graphConfig.numUnrollSteps)
        resultFile.write("droputProb, %.3f\n" % self.graphConfig.dropoutProb)
        resultFile.write("numLayers, %0.d\n" % self.graphConfig.numLayers)
        resultFile.write("maxGradNorm, %.3f\n" % self.graphConfig.maxGradNorm)
        resultFile.write("initScale, %.3f\n" % self.graphConfig.initScale)
        resultFile.write("numEpochsFullLR, %0.d\n" % self.graphConfig.numEpochsFullLR)
        resultFile.write("numEpochs, %0.d\n" % self.graphConfig.numEpochs)
        resultFile.write("baseLearningRate, %.3f\n" % self.graphConfig.baseLearningRate)
        resultFile.write("lrDecay, %.3f\n" % self.graphConfig.lrDecay)
        resultFile.write("forgetBias, %.3f\n" % self.graphConfig.forgetBias)
        resultFile.write("numLossFunc, %s\n" % self.graphConfig.numLossFunc)
        resultFile.write("rnnActivationFunc, %s\n" % self.graphConfig.rnnActivationFunc)
        resultFile.write("dateScale, %s\n" % self.graphConfig.dateScale)
        resultFile.write("Optimizer, %s\n" % self.graphConfig.optimizer)
        resultFile.write("OptimPar1, %s\n" % self.graphConfig.optimPar1)
        resultFile.write("OptimPar2, %s\n" % self.graphConfig.optimPar2)
        if self.graphConfig.usePeepHoles:
            resultFile.write("Peepholes\n")
        else:
            resultFile.write("NoPeepholes\n")
        if self.graphConfig.autoRNNSize:
            resultFile.write("AutoRNNSize\n")
        else:
            resultFile.write("RNNSize, %0.d\n" % self.graphConfig.rnnSize)
        if self.graphConfig.sharedRNN:
            resultFile.write("SharedRNN\n")

        resultFile.write("\nFold, Epoch, Target, TrainSteps, TrainPrecision, TrainLoss, ValidSteps, ValidPrecision, ValidLoss\n")
        resultFile.flush()
        os.fsync(resultFile.fileno())
        return resultFile

    def writeEmbeddingMetaData(self):
        for (attrnames, attrtype, embedSize, val2id, id2val, _, eventattr) in self.graphConfig.preds:
            if attrtype == XESAttributeTypes.CATEGORICAL and (attrnames, eventattr) in self.graphConfig.predictors:
                attrnamesString = re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(attrnames))+"T" if eventattr else re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(attrnames))+"F"
                metaDataFile = codecs.open(os.path.join(self.graphConfig.logDirPath, attrnamesString+".EmbeddingMeta.txt"), encoding="utf-8", mode="w")
                metaDataFile.write("\n".join([id2val[k] for k in sorted(id2val)]))
                metaDataFile.close()

    def writeEmbeddings(self, globalsteps):
        """Checkpoint the embedding variables and register them with the projector.

        Saves the values of ``self.tfGraph.embeddings`` to a checkpoint in the
        "TensorBoard-Fold<n>" sub-directory of ``graphConfig.logDirPath`` and
        writes a projector configuration, through ``self.summaryWriter``, that
        points each embedding at its ".EmbeddingMeta.txt" file.

        globalsteps -- passed to the saver as the checkpoint's global step
        """
        foldTag = "Fold{0}".format(self.foldNum)

        # Write a snapshot/checkpoint of the embeddings.
        # list(): on Python 3, dict.values() is a view, which the Saver does
        # not accept as var_list.
        graphSaver = tf.train.Saver(var_list=list(self.tfGraph.embeddings.values()))
        graphSaver.save(sess=self.session,
                        save_path=os.path.join(self.graphConfig.logDirPath, "TensorBoard-" + foldTag, "EmbeddingCheckpoint." + foldTag),
                        global_step=globalsteps)

        # Write out the embedding configuration for TensorBoard
        config = projector.ProjectorConfig()
        # items() instead of iteritems() so this also runs on Python 3
        for ((attrnames, eventattr), embeddingTensor) in self.tfGraph.embeddings.items():
            # NOTE(review): attrnames is concatenated as a string here, so the
            # embedding keys are presumably already the sanitized name strings
            # used for the metadata file names -- confirm against TFGraphBuilder.
            attrnamesString = attrnames + ("T" if eventattr else "F")
            embedding = config.embeddings.add()
            embedding.tensor_name = embeddingTensor.name
            embedding.metadata_path = os.path.join(self.graphConfig.logDirPath, attrnamesString + ".EmbeddingMeta.txt")
        # Saves a configuration file that TensorBoard will read during startup.
        projector.visualize_embeddings(self.summaryWriter, config)

    def saveModel(self, globalsteps):
        """Write the trained model of the current fold as a frozen graph.

        The variables are checkpointed and the graph definition is written
        into a temporary directory; ``freeze_graph`` then combines both into
        "TrainedModelFold<n>.ProtoBuf" in ``graphConfig.logDirPath``, keeping
        the final-state and output nodes of every target as output nodes.
        The temporary directory is removed afterwards, also when saving or
        freezing fails.

        globalsteps -- passed to the saver as the checkpoint's global step
        """
        varList = self.tfGraph.g.get_collection("variables") + self.tfGraph.g.get_collection("saveable_objects")
        tmpDir = tempfile.mkdtemp()
        try:
            graphSaver = tf.train.Saver(var_list=varList)
            saveFilenameStem = "ModelSaverCheckpointFold{0}".format(self.foldNum)
            checkpointPath = graphSaver.save(sess=self.session,
                                             save_path=os.path.join(tmpDir, saveFilenameStem),
                                             meta_graph_suffix='metagraph',
                                             latest_filename=saveFilenameStem,
                                             global_step=globalsteps)
            saverDef = graphSaver.as_saver_def()

            inputGraphFilename = "inputGraph.ProtoBuf"
            outputGraphFilename = "TrainedModelFold{0}.ProtoBuf".format(self.foldNum)
            tf.train.write_graph(self.session.graph, tmpDir, inputGraphFilename)

            # Collect the names of the nodes that must survive freezing
            outputNodes = []
            for (attrnames, attrtype, _, _, _, _, eventAttrib) in self.graphConfig.preds:
                if (attrnames, eventAttrib) not in self.graphConfig.targets:
                    continue
                attrnames = re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(attrnames))
                if not self.graphConfig.sharedRNN:
                    outputNodes.append("LSTM/"+attrnames+"/FinalStateTensor")
                else:
                    outputNodes.append("LSTM/FinalStateTensor")
                if attrtype == XESAttributeTypes.DATE or attrtype == XESAttributeTypes.NUMERIC:
                    outputNodes.append("target/"+attrnames+"/outputConcat")
                if attrtype == XESAttributeTypes.CATEGORICAL:
                    outputNodes.append("target/"+attrnames+"/softMax")

            freeze_graph.freeze_graph(input_graph=os.path.join(tmpDir, inputGraphFilename),
                                      input_saver="",
                                      input_binary=False,
                                      input_checkpoint=checkpointPath,
                                      output_node_names=",".join(outputNodes),
                                      restore_op_name=saverDef.restore_op_name,
                                      filename_tensor_name=saverDef.filename_tensor_name,
                                      output_graph=os.path.join(self.graphConfig.logDirPath, outputGraphFilename),
                                      clear_devices=False,
                                      initializer_nodes="")
        finally:
            # Always clean up, so a failed save does not leave the checkpoint
            # files behind in the temp directory
            shutil.rmtree(tmpDir, ignore_errors=True)

    def train(self):
        """Train and validate one model per data fold and write all outputs.

        Creates the progress box, splits the data into
        ``graphConfig.numFolds`` folds, writes the results-file header and the
        embedding metadata, and then trains every fold in turn.  The results
        file is closed and the progress box destroyed when training ends,
        whether it finished normally or raised.
        """
        self.progressBox = TrainingProgressBox(targNames=[re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(attrnames)) for (attrnames, _) in self.graphConfig.targets], xespath=self.graphConfig.xesPath)
        resultsFile = None
        try:
            self.progressBox.updateMessage("Creating Data Folds")
            self.data.createDataFolds(self.graphConfig.numFolds)
            graphBuilder = TFGraphBuilder(self.graphConfig, self.progressBox)

            self.progressBox.updateMessage("Writing Meta Data")
            resultsFile = self.initResultsFile()
            self.writeEmbeddingMetaData()

            for self.foldNum in range(0, self.graphConfig.numFolds):
                self._trainFold(graphBuilder, resultsFile)
        finally:
            try:
                if resultsFile is not None:
                    resultsFile.close()
            finally:
                self.progressBox.destroy()

    def _trainFold(self, graphBuilder, resultsFile):
        """Build, train, validate and export the model for fold ``self.foldNum``.

        graphBuilder -- object whose ``build()`` returns a fresh graph
        resultsFile  -- open file that receives one result row per target for
                        every validated epoch
        """
        self.progressBox.updateMessage("Building Graph, Fold {0}".format(self.foldNum))
        # Build a new graph, start a session with it and initialize all variables
        self.tfGraph = graphBuilder.build()
        self.session = tf.Session(graph=self.tfGraph.g)
        self.session.run(fetches=self.tfGraph.initOp)

        self.progressBox.updateMessage("Creating Training and Validation Data, Fold {0}".format(self.foldNum))
        self.data.createTrainValidData(validFold=self.foldNum)

        # Number of **training** steps across all epochs, used for logging to
        # TensorBoard (validation steps are not logged)
        globalsteps = 0
        self.summaryWriter = tf.summary.FileWriter(logdir=os.path.join(self.graphConfig.logDirPath, "TensorBoard-Fold{0}".format(self.foldNum)), graph=self.tfGraph.g)
        try:
            lastEpoch = self.graphConfig.numEpochs - 1
            for self.epochNum in range(0, self.graphConfig.numEpochs):
                trainCorrectPredictionsArray, trainLossArray, trainSteps = self.runEpoch(train=True, priorsteps=globalsteps)
                globalsteps += trainSteps

                # validate every epoch if requested, and always on the last epoch
                if not (self.graphConfig.validateEpoch or self.epochNum == lastEpoch):
                    continue
                validateCorrectPredictionsArray, validateLossArray, validSteps = self.runEpoch(train=False, priorsteps=globalsteps)
                for targName in sorted(trainCorrectPredictionsArray):
                    resultsFile.write("%d, %d, %s, %d, %.3f, %.6f, %d, %.3f, %.6f\n" % (self.foldNum,
                                                                                        self.epochNum,
                                                                                        targName,
                                                                                        trainSteps,
                                                                                        trainCorrectPredictionsArray[targName],
                                                                                        trainLossArray[targName],
                                                                                        validSteps,
                                                                                        validateCorrectPredictionsArray[targName],
                                                                                        validateLossArray[targName]))
                    # push every row to disk right away
                    resultsFile.flush()
                    os.fsync(resultsFile.fileno())

            # Write a snapshot/checkpoint of the embeddings in the graph
            self.progressBox.updateMessage("Writing Embeddings, Fold {0}".format(self.foldNum))
            self.writeEmbeddings(globalsteps)
            # Write the entire model, but stripped for predictions (remove any learning related stuff)
            self.progressBox.updateMessage("Writing Trained Model, Fold {0}".format(self.foldNum))
            try:
                self.saveModel(globalsteps)
            except Exception as e:
                # Best effort: a failed export must not abort the remaining folds
                sys.stderr.write("Error saving model\n" + str(e))
        finally:
            # close the summary writer even when the fold failed
            self.summaryWriter.close()
