import re
import tensorflow as tf

from utils import XESAttributeTypes


class TFGraph:
    """Container for a TensorFlow graph plus handles to the tensors/ops in it.

    Per-attribute handles live in dicts (filled in by whatever builds the
    graph); graph-wide handles start out as ``None`` and are assigned later.
    """

    def __init__(self):
        # Per-attribute handles, keyed by attribute name (or name/flag tuple).
        self.inputPlaceholders = {}
        self.targetPlaceholders = {}
        self.embeddings = {}
        self.meanCorrectPredictions = {}
        self.correct_predictions = {}
        self.nnoutputs_concat = {}
        self.loss = {}
        self.cost = {}
        self.initialState = {}
        self.states = {}
        self.trainOp = {}
        self.oneHotTargets = {}
        self.softMax = {}

        # Graph-wide handles, not yet created.
        self.mergedSummaries = self.summaryWriter = self.learningRate = None
        self.initOp = self.graphSaver = None

        # The TensorFlow graph every handle above belongs to.
        self.g = tf.Graph()


def on_closing():
    """Do nothing and return None (placeholder callback; presumably a close handler — confirm at call sites)."""
    return None


class TFGraphBuilder:
    """Build the TensorFlow (1.x graph-mode) model described by ``config``.

    ``build()`` creates a fresh :class:`TFGraph` and populates it with:

    * input placeholders for every configured predictor (numeric/date values
      are fed directly, categorical values go through a trainable embedding);
    * an input-preparation layer that concatenates the predictor inputs per
      unroll step and, unless ``config.autoRNNSize`` is set, projects them to
      ``config.rnnSize``;
    * a stacked LSTM, either one shared by all targets (``config.sharedRNN``)
      or one per target;
    * an output head, loss and summaries for every configured target;
    * one gradient-clipped training op per target.

    Progress is reported via ``progressBox.updateMessage`` / ``makeTick``.
    """

    # Characters stripped from attribute names so they form valid TensorFlow
    # op / name-scope names.
    _INVALID_NAME_CHARS = re.compile("[^A-Za-z0-9.\\-]*")

    def __init__(self, config, progressBox):
        """Store the configuration and progress reporter; nothing is built yet.

        Args:
            config: hyper-parameter object whose attributes are read by
                ``build()`` (preds, predictors, targets, batchSize,
                numUnrollSteps, rnnSize, ...).
            progressBox: object providing ``updateMessage(str)`` and
                ``makeTick()``.
        """
        self.graph = None
        self.config = config
        self.progressBox = progressBox

    def build(self):
        """Create and populate a new :class:`TFGraph`.

        Returns:
            The populated ``TFGraph`` (also stored in ``self.graph``).

        Raises:
            ValueError: if a predictor/target has an unsupported attribute
                type, or ``config.numLossFunc`` is not 'MSE', 'MAE' or 'RMSE'
                while a numeric/date target is configured.
        """
        self.graph = TFGraph()

        self.progressBox.updateMessage("Building Predictors")

        with self.graph.g.as_default():
            inputs = self._build_predictor_inputs()

            self.progressBox.updateMessage("Building Input Preparation")
            nninputs, defaultSize = self._build_input_prep(inputs)

            # Without an input projection the RNN width must equal the width
            # of the concatenated inputs.
            hiddenSize = defaultSize if self.config.autoRNNSize else self.config.rnnSize

            self.progressBox.updateMessage("Building RNN")
            if self.config.sharedRNN:
                self._build_shared_rnn(nninputs, hiddenSize)
            else:
                self._build_separate_rnns(nninputs, hiddenSize)

            self.progressBox.updateMessage("Training Setup")
            self.graph.learningRate = tf.Variable(0.0, dtype=tf.float32, trainable=False, name="LearningRate")
            optimizer = self._build_optimizer()

            self.progressBox.updateMessage("Computing gradients")
            self._build_train_ops(optimizer)

            self.graph.mergedSummaries = tf.summary.merge_all()
            self.graph.initOp = tf.global_variables_initializer()

        return self.graph

    def _tf_name(self, attrnames):
        """Join attribute names with '-' and drop characters TensorFlow rejects in names."""
        return self._INVALID_NAME_CHARS.sub("", "-".join(attrnames))

    def _iter_targets(self):
        """Yield ``(sanitized name, attrtype, val2id)`` for each configured target, in preds order."""
        for (attrnames, attrtype, _, val2id, _, _, eventattr) in self.config.preds:
            if (attrnames, eventattr) in self.config.targets:
                yield self._tf_name(attrnames), attrtype, val2id

    def _build_predictor_inputs(self):
        """Create input placeholders (and embeddings) for every configured predictor.

        Returns:
            dict mapping ``(sanitized name, eventattr)`` to a list of
            ``numUnrollSteps`` per-step tensors of shape ``[batchSize, d]``.

        Raises:
            ValueError: for a predictor with an unsupported attribute type.
        """
        cfg = self.config
        inputs = dict()
        for (attrnames, attrtype, embedSize, val2id, _, _, eventattr) in sorted(cfg.preds):
            if (attrnames, eventattr) not in cfg.predictors:
                continue
            name = self._tf_name(attrnames)
            key = (name, eventattr)
            scope = "predictor/" + name + ("Event" if eventattr else "Case")
            with tf.name_scope(scope):
                if attrtype == XESAttributeTypes.DATE or attrtype == XESAttributeTypes.NUMERIC:
                    placeholder = tf.placeholder(tf.float32, [cfg.batchSize, cfg.numUnrollSteps], name="inputPlaceholder")
                    self.graph.inputPlaceholders[key] = placeholder
                    # [batchSize, numUnrollSteps] -> [batchSize, numUnrollSteps, 1],
                    # so numeric inputs have the same rank as embedded ones.
                    stepInput = tf.expand_dims(placeholder, 2, name="expandDims-" + name)
                elif attrtype == XESAttributeTypes.CATEGORICAL:
                    placeholder = tf.placeholder(tf.int64, (cfg.batchSize, cfg.numUnrollSteps), name="inputPlaceholder")
                    self.graph.inputPlaceholders[key] = placeholder
                    # Pinned to the CPU; the original author notes embedding
                    # lookups could not be done on the GPU.
                    with tf.device("/cpu:0"):
                        embedding = tf.Variable(tf.random_uniform([len(val2id), embedSize], -cfg.initScale, cfg.initScale), name="embedding")
                        self.graph.embeddings[key] = embedding
                        # -> [batchSize, numUnrollSteps, embedSize]
                        stepInput = tf.nn.embedding_lookup(embedding, placeholder, name="embeddedInput")
                        if cfg.dropoutProb > 0:
                            # NOTE(review): the keep probability is a constant baked
                            # into the graph, so dropout is active on every run.
                            stepInput = tf.nn.dropout(stepInput, 1.0 - cfg.dropoutProb, name="droppedoutInput")
                else:
                    raise ValueError("Unsupported attribute type {0!r} for predictor {1!r}".format(attrtype, name))
                # Split along the time axis into numUnrollSteps pieces of shape
                # [batchSize, 1, d] and squeeze each to [batchSize, d].
                inputs[key] = [tf.squeeze(piece, [1], name="squeezeInput")
                               for piece in tf.split(axis=1, num_or_size_splits=cfg.numUnrollSteps, value=stepInput, name="splitInput")]
        return inputs

    def _build_input_prep(self, inputs):
        """Concatenate the predictor inputs per unroll step, optionally projecting them.

        Args:
            inputs: result of ``_build_predictor_inputs``.

        Returns:
            ``(nninputs, defaultSize)``: the per-step RNN input tensors and the
            width of the concatenated, unprojected input (``config.rnnSize``
            if there are no unroll steps).
        """
        cfg = self.config
        defaultSize = cfg.rnnSize
        nninputs = [None] * cfg.numUnrollSteps
        # Fixed (sorted-by-key) predictor order, identical for every step.
        orderedRows = [row for (_, row) in sorted(inputs.items())]
        with tf.name_scope("inputPrep"):
            for i in range(cfg.numUnrollSteps):
                # The RNN consumes one concatenated input tensor per unroll step.
                nninputs[i] = tf.concat(axis=1, values=[row[i] for row in orderedRows], name="nninput-{0}".format(i))
                # Default total input width (and hence state/output width).
                defaultSize = int(nninputs[i].get_shape()[1])
                if not cfg.autoRNNSize:
                    # User-chosen RNN width: project each step's input to rnnSize.
                    inputProjectionW = tf.Variable(tf.random_uniform([defaultSize, cfg.rnnSize], -cfg.initScale, cfg.initScale), name="inputProjectionW-{0}".format(i))
                    inputProjectionB = tf.Variable(tf.random_uniform([cfg.rnnSize], -cfg.initScale, cfg.initScale), name="inputProjectionB-{0}".format(i))
                    nninputs[i] = tf.matmul(nninputs[i], inputProjectionW, name="inputProjectionMatMul-{0}".format(i))
                    nninputs[i] = tf.add(nninputs[i], inputProjectionB, name="inputProjectionAdd-{0}".format(i))
                self.progressBox.makeTick()
        return nninputs, defaultSize

    def _make_layer_cell(self, hiddenSize):
        """Create one LSTM layer cell, wrapped with probabilistic output dropout if configured."""
        cfg = self.config
        activation = tf.sigmoid if cfg.rnnActivationFunc == 'sigmoid' else tf.tanh
        cell = tf.nn.rnn_cell.LSTMCell(hiddenSize, use_peepholes=cfg.usePeepHoles, forget_bias=0.0, activation=activation)
        if cfg.dropoutProb > 0:
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0 - cfg.dropoutProb)
        return cell

    def _make_stacked_cell(self, hiddenSize):
        """Create a ``numLayers``-deep LSTM stack with an independent cell per layer."""
        # Every layer gets its own cell object: `[cell] * n` would reuse one
        # cell for all layers, which raises or ties weights across layers
        # depending on the TF 1.x release. NOTE: this changes variable layout
        # versus checkpoints saved by the old code.
        return tf.nn.rnn_cell.MultiRNNCell([self._make_layer_cell(hiddenSize) for _ in range(self.config.numLayers)])

    def _build_separate_rnns(self, nninputs, hiddenSize):
        """Build one LSTM stack plus output head per target."""
        for name, attrtype, val2id in self._iter_targets():
            with tf.name_scope("LSTM/" + name):
                cell = self._make_stacked_cell(hiddenSize)
                # Initial state of all cells is zero.
                self.graph.initialState[name] = cell.zero_state(self.config.batchSize, tf.float32)
                nnoutputs, self.graph.states[name] = tf.nn.static_rnn(cell, nninputs, initial_state=self.graph.initialState[name], scope="RNN/" + name)
                # Result discarded here; presumably created so the final state
                # can be fetched by op name elsewhere — confirm.
                tf.convert_to_tensor(self.graph.states[name], name="FinalStateTensor")
                self.progressBox.makeTick()
            self._build_target_head(name, attrtype, val2id, nnoutputs, hiddenSize)

    def _build_shared_rnn(self, nninputs, hiddenSize):
        """Build one LSTM stack shared by all targets, then one output head per target."""
        with tf.name_scope("LSTM"):
            cell = self._make_stacked_cell(hiddenSize)
            # Initial state of all cells is zero; replaces the per-target dict.
            self.graph.initialState = cell.zero_state(self.config.batchSize, tf.float32)
            nnoutputs, self.graph.states = tf.nn.static_rnn(cell, nninputs, initial_state=self.graph.initialState, scope="RNN")
            # Result discarded here; presumably created so the final state can
            # be fetched by op name elsewhere — confirm.
            tf.convert_to_tensor(self.graph.states, name="FinalStateTensor")
            self.progressBox.makeTick()
        for name, attrtype, val2id in self._iter_targets():
            self._build_target_head(name, attrtype, val2id, nnoutputs, hiddenSize)

    def _build_target_head(self, name, attrtype, val2id, nnoutputs, hiddenSize):
        """Create the output projection, loss and summaries for one target.

        Args:
            name: sanitized target name (dict key and name-scope suffix).
            attrtype: ``XESAttributeTypes`` value of the target attribute.
            val2id: value-to-id mapping; its length is the class count for a
                categorical target.
            nnoutputs: per-step RNN outputs of width ``hiddenSize``.
            hiddenSize: width of each RNN output.

        Raises:
            ValueError: for an unsupported attribute type or loss function.
        """
        with tf.name_scope("target/" + name):
            if attrtype == XESAttributeTypes.DATE or attrtype == XESAttributeTypes.NUMERIC:
                self._build_numeric_head(name, nnoutputs, hiddenSize)
            elif attrtype == XESAttributeTypes.CATEGORICAL:
                self._build_categorical_head(name, len(val2id), nnoutputs, hiddenSize)
            else:
                raise ValueError("Unsupported attribute type {0!r} for target {1!r}".format(attrtype, name))
            self.progressBox.makeTick()

    def _build_numeric_head(self, name, nnoutputs, hiddenSize):
        """Regression head: one scalar per step, scored with MSE, MAE or RMSE."""
        cfg = self.config
        targets = tf.placeholder(tf.float32, [cfg.batchSize, cfg.numUnrollSteps], name="targetPlaceholder")
        self.graph.targetPlaceholders[name] = targets
        projectedOutputs = []
        for i in range(cfg.numUnrollSteps):
            outputProjectionW = tf.Variable(tf.random_uniform([hiddenSize, 1], -cfg.initScale, cfg.initScale), name="outputProjectionW-{0}".format(i))
            outputProjectionB = tf.Variable(initial_value=0, dtype=tf.float32, name="outputProjectionB-{0}".format(i))
            projected = tf.matmul(nnoutputs[i], outputProjectionW, name="outputProjectionMatMul-{0}".format(i))
            projectedOutputs.append(tf.add(projected, outputProjectionB, name="outputProjectionAdd-{0}".format(i)))
        # Each step yields [batchSize, 1]; concatenating on axis 1 matches the
        # [batchSize, numUnrollSteps] target placeholder.
        predictions = tf.concat(axis=1, values=projectedOutputs, name="outputConcat")
        self.graph.nnoutputs_concat[name] = predictions
        if cfg.numLossFunc == 'MSE':
            loss = tf.reduce_mean(tf.squared_difference(predictions, targets, name="squarediff"), name="meanloss")
        elif cfg.numLossFunc == 'MAE':
            loss = tf.reduce_mean(tf.abs(predictions - targets, name="absdiff"), name="meanloss")
        elif cfg.numLossFunc == 'RMSE':
            loss = tf.sqrt(tf.reduce_mean(tf.squared_difference(predictions, targets, name="squarediff"), name="meanDiff"), name="meanloss")
        else:
            raise ValueError("Unsupported numLossFunc {0!r}; expected 'MSE', 'MAE' or 'RMSE'".format(cfg.numLossFunc))
        self.graph.loss[name] = loss
        tf.summary.scalar('Loss', loss)
        # Constant zero so "correct predictions" can be fetched regardless of
        # attribute type.
        self.graph.meanCorrectPredictions[name] = tf.zeros(shape=(), dtype=tf.float32)

    def _build_categorical_head(self, name, numClasses, nnoutputs, hiddenSize):
        """Classification head: per-step logits, softmax cross-entropy and top-1 accuracy."""
        cfg = self.config
        targets = tf.placeholder(tf.int64, [cfg.batchSize, cfg.numUnrollSteps], name="targetPlaceholder")
        self.graph.targetPlaceholders[name] = targets
        projectedOutputs = []
        for i in range(cfg.numUnrollSteps):
            outputProjectionW = tf.Variable(tf.random_uniform([hiddenSize, numClasses], -cfg.initScale, cfg.initScale), name="outputProjectionW-{0}".format(i))
            outputProjectionB = tf.Variable(tf.random_uniform([numClasses], -cfg.initScale, cfg.initScale), name="outputProjectionB-{0}".format(i))
            projected = tf.matmul(nnoutputs[i], outputProjectionW, name="outputProjectionMatMul-{0}".format(i))
            projectedOutputs.append(tf.add(projected, outputProjectionB, name="outputProjectionAdd-{0}".format(i)))
        # Stack steps on axis 1: [batchSize, numUnrollSteps, numClasses].
        logits = tf.stack(values=projectedOutputs, axis=1, name="outputPack")

        self.graph.oneHotTargets[name] = tf.one_hot(indices=targets, depth=numClasses, axis=-1, name="onehot")
        self.graph.softMax[name] = tf.nn.softmax(logits, dim=-1, name="softMax")
        self.graph.loss[name] = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.graph.oneHotTargets[name], dim=-1, name="loss"), name="meanloss")
        tf.summary.scalar('Loss', self.graph.loss[name])
        # Flatten batch and time so in_top_k sees [N, numClasses] / [N].
        self.graph.correct_predictions[name] = tf.cast(tf.nn.in_top_k(tf.reshape(self.graph.softMax[name], [-1, numClasses]), tf.reshape(targets, [-1]), 1), tf.float32)
        self.graph.meanCorrectPredictions[name] = tf.reduce_mean(self.graph.correct_predictions[name], name="SumCorrectPredictions")
        tf.summary.scalar('Mean Correct Predictions', self.graph.meanCorrectPredictions[name])

    def _build_optimizer(self):
        """Return the optimizer named by ``config.optimizer``; unknown names fall back to gradient descent."""
        cfg = self.config
        lr = self.graph.learningRate
        kind = cfg.optimizer
        if kind == "Adadelta":
            return tf.train.AdadeltaOptimizer(learning_rate=lr, rho=cfg.optimPar1)
        elif kind == "Adagrad":
            return tf.train.AdagradOptimizer(learning_rate=lr, initial_accumulator_value=cfg.optimPar1)
        elif kind == "Momentum":
            return tf.train.MomentumOptimizer(learning_rate=lr, momentum=cfg.optimPar1)
        elif kind == "Adam":
            return tf.train.AdamOptimizer(learning_rate=lr, beta1=cfg.optimPar1, beta2=cfg.optimPar2)
        elif kind == "FTRL":
            return tf.train.FtrlOptimizer(learning_rate=lr, learning_rate_power=cfg.optimPar1, initial_accumulator_value=cfg.optimPar2)
        elif kind == "ProximalGradient":
            return tf.train.ProximalGradientDescentOptimizer(learning_rate=lr, l1_regularization_strength=cfg.optimPar1, l2_regularization_strength=cfg.optimPar2)
        elif kind == "ProximalAda":
            return tf.train.ProximalAdagradOptimizer(learning_rate=lr, initial_accumulator_value=cfg.optimPar1)
        elif kind == "RMSpro":
            # NOTE(review): matched literally as "RMSpro" — confirm this is the
            # spelling the configuration produces.
            return tf.train.RMSPropOptimizer(learning_rate=lr, decay=cfg.optimPar1, momentum=cfg.optimPar2)
        return tf.train.GradientDescentOptimizer(learning_rate=lr)

    def _build_train_ops(self, optimizer):
        """Create one gradient-clipped training op per target (each loss is optimized separately)."""
        trainableVars = tf.trainable_variables()
        for name, _, _ in self._iter_targets():
            # Variables not on this loss's path get a None gradient;
            # clip_by_global_norm ignores None entries and apply_gradients
            # skips them.
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.graph.loss[name], trainableVars), self.config.maxGradNorm)
            self.graph.trainOp[name] = optimizer.apply_gradients(list(zip(grads, trainableVars)))
            self.progressBox.makeTick()
