import tensorflow as tf
import re
import numpy

from XESParserPrediction import XESParserPrediction
from XESPredictionData import XESPredictionData
from TFGraphBuilder import TFGraph
from utils import XESAttributeTypes


class TFGraphPredictor:

    def readData(self, xesPath):
        """Parse the XES log at ``xesPath`` and return it as finalized prediction data.

        The parser is configured with the batch size and number of unroll steps
        taken from ``self.graphConfig``; the resulting ``XESPredictionData`` is
        finalized against that same configuration before being returned.
        """
        config = self.graphConfig

        xesParser = XESParserPrediction(xesPath, config.batchSize, config.numUnrollSteps)
        xesParser.read()

        predictionData = XESPredictionData(xesParser)
        predictionData.finalizeData(config)
        return predictionData

    def readGraph(self, graphProtoBufPath):
        """Import a serialized GraphDef and cache handles to the tensors used for prediction.

        Creates a new ``TFGraph``, imports the protobuf at ``graphProtoBufPath``
        into its graph (with an empty name prefix), opens ``self.session`` on
        it, and then fills these ``self.tfGraph`` containers by tensor name:

        * ``states``: final-state tensor per target,
        * ``softMax``: output tensor of categorical targets,
        * ``nnoutputs_concat``: output tensor of all other targets,
        * ``initialState``: per-layer ``LSTMStateTuple`` of zero-state tensors,
          keyed by target name unless ``graphConfig.sharedRNN`` is set, in which
          case the attribute itself is replaced by a single tuple,
        * ``inputPlaceholders``: placeholder per ``(name, eventAttr)`` predictor.

        ``self.graphConfig`` and ``self.data`` must already be set, because the
        attribute type of each target is looked up through ``self.data``.
        """
        config = self.graphConfig
        self.tfGraph = TFGraph()

        def fetchTensor(tensorName):
            return self.session.graph.get_tensor_by_name(tensorName)

        def initialStateTensors(prefix):
            # The first c tensor is the unsuffixed "zeros"; every later zero-state
            # tensor carries a running index (h of layer 0 is "zeros_1", etc.).
            perLayer = []
            for layerIdx in range(config.numLayers):
                if layerIdx == 0:
                    cSuffix = "/zeros:0"
                else:
                    cSuffix = "/zeros_{0}:0".format(layerIdx * 2)
                hSuffix = "/zeros_{0}:0".format(layerIdx * 2 + 1)
                perLayer.append(tf.nn.rnn_cell.LSTMStateTuple(c=fetchTensor(prefix + cSuffix),
                                                              h=fetchTensor(prefix + hSuffix)))
            return tuple(perLayer)

        with self.tfGraph.g.as_default():
            graphDef = tf.GraphDef()
            with open(graphProtoBufPath, "rb") as protoFile:
                serialized = protoFile.read()
            graphDef.ParseFromString(serialized)
            tf.import_graph_def(graphDef, name="")

            self.session = tf.Session(graph=self.tfGraph.g)

            for targetAttrs, isEventAttr in config.targets:
                attrType, _ = self.data.getAttrType(config.preds, targetAttrs, isEventAttr)
                # Joined attribute names, reduced to the characters used in the graph's scope names.
                name = re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(targetAttrs))

                self.tfGraph.states[name] = fetchTensor("LSTM/" + name + "/FinalStateTensor:0")

                if attrType == XESAttributeTypes.CATEGORICAL:
                    self.tfGraph.softMax[name] = fetchTensor("target/" + name + "/softMax:0")
                else:
                    self.tfGraph.nnoutputs_concat[name] = fetchTensor("target/" + name + "/outputConcat:0")

                if config.sharedRNN:
                    # One RNN for all targets: the attribute becomes a single tuple
                    # (re-assigned, with identical content, on every target).
                    self.tfGraph.initialState = initialStateTensors("LSTM")
                else:
                    self.tfGraph.initialState[name] = initialStateTensors("LSTM/" + name)

            for predictorAttrs, isEventAttr in config.predictors:
                name = re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(predictorAttrs))
                if isEventAttr:
                    scopeSuffix = "Event/inputPlaceholder:0"
                else:
                    scopeSuffix = "Case/inputPlaceholder:0"
                self.tfGraph.inputPlaceholders[(name, isEventAttr)] = fetchTensor("predictor/" + name + scopeSuffix)

    def getZeroState(self):
        """Build all-zero numpy values shaped like the graph's zero-state tensors.

        Returns:
            When ``graphConfig.sharedRNN`` is set, a tuple with one
            ``LSTMStateTuple`` of zero arrays per layer.  Otherwise a dict that
            maps each sanitized target name to such a tuple.
        """
        def zerosFor(tensor):
            # Same static shape and dtype as the graph tensor, filled with zeros.
            return numpy.zeros(shape=tensor.get_shape(), dtype=tensor.dtype.as_numpy_dtype)

        def zeroLayers(prefix):
            perLayer = []
            for layerIdx in range(self.graphConfig.numLayers):
                # The first c tensor is the unsuffixed "zeros"; all later
                # zero-state tensors carry a running index.
                if layerIdx == 0:
                    cSuffix = "/zeros:0"
                else:
                    cSuffix = "/zeros_{0}:0".format(layerIdx * 2)
                hSuffix = "/zeros_{0}:0".format(layerIdx * 2 + 1)
                cTensor = self.session.graph.get_tensor_by_name(prefix + cSuffix)
                hTensor = self.session.graph.get_tensor_by_name(prefix + hSuffix)
                perLayer.append(tf.nn.rnn_cell.LSTMStateTuple(c=zerosFor(cTensor), h=zerosFor(hTensor)))
            return tuple(perLayer)

        if self.graphConfig.sharedRNN:
            return zeroLayers("LSTM")

        zeroStates = dict()
        for targetAttrs, _isEventAttr in self.graphConfig.targets:
            name = re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(targetAttrs))
            zeroStates[name] = zeroLayers("LSTM/" + name)
        return zeroStates

    def predict(self, numPSteps, xesOutputPath, eocDetect):
        """Run the imported graph until ``numPSteps`` predictions are made, then write them out.

        Each iteration asks ``self.data.getFeedDict`` for a feed dict covering
        ``numUnrollSteps`` steps starting at the current window position, adds
        the current RNN state to it and runs the session.  When ``getFeedDict``
        reports a prediction index, the network outputs are handed to
        ``self.data.addPrediction`` and counted as one prediction step.  The
        window and the carried RNN state only move forward when no prediction
        index was reported or the index is the last step of the window;
        otherwise the same window is run again from the same starting state.

        Args:
            numPSteps: Number of predictions to make before stopping.
            xesOutputPath: Forwarded to ``self.data.writeToXES``.
            eocDetect: Forwarded unchanged to ``self.data.writeToXES``.
        """
        graphConfig = self.graphConfig
        tfGraph = self.tfGraph

        currentstep = 0
        currentPStep = 0
        stateValue = self.getZeroState()

        # Main loop: runs until the requested number of predictions has been made.
        while currentPStep < numPSteps:
            # get the data into the feeddict
            feeddict, predIndex = self.data.getFeedDict(numUnrollSteps=graphConfig.numUnrollSteps,
                                                        startstep=currentstep,
                                                        tfGraph=tfGraph,
                                                        graphConfig=graphConfig)

            # Fetch order is fixed by sorted key order so the results can be
            # matched back to their target names below.
            outputNames = sorted(tfGraph.nnoutputs_concat)
            softMaxNames = sorted(tfGraph.softMax)
            fetches = ([tfGraph.nnoutputs_concat[key] for key in outputNames]
                       + [tfGraph.softMax[key] for key in softMaxNames])

            # add the initial states into the feeddict and the final states to the fetches
            if not graphConfig.sharedRNN:
                stateNames = sorted(tfGraph.states)
                # items() instead of the Python-2-only iteritems(), so this
                # works on both Python 2 and Python 3.
                for attrname, stateTensor in tfGraph.initialState.items():
                    feeddict[stateTensor] = stateValue[attrname]
                fetches += [tfGraph.states[key] for key in stateNames]
            else:
                stateNames = None
                feeddict[tfGraph.initialState] = stateValue
                # NOTE(review): readGraph fills tfGraph.states per target even in
                # shared mode, so this fetches the whole container - confirm the
                # result is the array reconstructState expects.
                fetches.append(tfGraph.states)

            runResults = self.session.run(fetches=fetches, feed_dict=feeddict)

            # deconstruct the runResults list: outputs, then softmaxes, then state(s)
            softMaxOffset = len(outputNames)
            stateOffset = softMaxOffset + len(softMaxNames)
            nnoutputsArray = dict(zip(outputNames, runResults[:softMaxOffset]))
            softMaxArray = dict(zip(softMaxNames, runResults[softMaxOffset:stateOffset]))
            if stateNames is not None:
                finalState = dict()
                for targName, rawState in zip(stateNames, runResults[stateOffset:]):
                    finalState[targName] = self.reconstructState(rawState)
            else:
                finalState = self.reconstructState(runResults[stateOffset])

            if predIndex is not None:
                # put it back in the input
                self.data.addPrediction(nnoutputsArray, softMaxArray, predIndex, graphConfig)
                currentPStep += 1
            if predIndex is None or predIndex == graphConfig.numUnrollSteps - 1:
                # Window fully consumed: carry the state over and advance.
                stateValue = finalState
                currentstep += graphConfig.numUnrollSteps

        self.data.writeToXES(xesOutputPath, self.data.prefixLength, numPSteps, graphConfig, eocDetect)

    def reconstructState(self, a):
        """Convert a fetched state array into a per-layer tuple of ``LSTMStateTuple``.

        For every layer index below ``graphConfig.numLayers``, ``a[layer, 0]``
        is used as the ``c`` part and ``a[layer, 1]`` as the ``h`` part.
        """
        return tuple(
            tf.nn.rnn_cell.LSTMStateTuple(c=a[layerIdx, 0], h=a[layerIdx, 1])
            for layerIdx in range(self.graphConfig.numLayers)
        )

    def showResult(self):
        """Placeholder: not implemented, does nothing and returns ``None``."""
        return None

    def writeOutput(self):
        """Placeholder: not implemented, does nothing and returns ``None``."""
        return None

    def __init__(self, graphProtoBufPath, graphConfig, xesPath):
        """Load the XES data and then the serialized graph.

        Args:
            graphProtoBufPath: Path handed to ``readGraph``.
            graphConfig: Configuration object stored as ``self.graphConfig``.
            xesPath: Path handed to ``readData``.
        """
        self.graphConfig = graphConfig

        # Both are assigned for real by readGraph() below.
        self.tfGraph, self.session = None, None

        # Data must be loaded first: readGraph() queries self.data.
        self.data = self.readData(xesPath)
        self.readGraph(graphProtoBufPath)
