"""
Joerg Evermann, 2016

Adapted from and based upon the TensorFlow RNN tutorial provided by Google

"""

import tensorflow as tf
#from tensorflow.models.rnn import rnn
import rnn_cell
import rnn
import numpy
import time
import os
import h5py
import readWords
import string
import pyxdameraulevenshtein as dl

# ---------------------------------------------------------------------------
# Model / training hyper-parameters
# ---------------------------------------------------------------------------
batchSize = 20
numUnrollSteps = 5
hiddenSize = 32
# Dropout is applied both to the LSTM cell outputs and to the embedded inputs
# (keep probability is 1.0 - dropoutProb); set to 0 to disable.
dropoutProb = 0.2
numLayers = 2
# Gradients are clipped to this global norm before being applied.
maxGradNorm = 5
# Embedding and softmax weights are initialised uniformly in [-initScale, initScale].
initScale = 0.10

# The learning rate stays at baseLearningRate for the first numEpochsFullLR
# epochs and is then multiplied by lrDecay once per additional epoch.
numEpochsFullLR = 50
numEpochs = 100

baseLearningRate = 1.0
lrDecay = 0.75
basicLSTM = False
# NOTE(review): forgetBias is only written to the result file; the LSTM cells
# built in the dataset loop pass forget_bias=0.0 instead of this value.
forgetBias = 0.1
usePeepHoles = True

numRuns = 1
# use numFolds = 1 if you don't want cross-validation
numFolds = 1
# Passed through to readWords.dataFolds(); presumably shuffles the cases before
# splitting them into folds -- confirm in readWords.
shuffle = False
# Selects how batches are counted/iterated (non-overlapping vs. sliding
# windows, presumably -- see readWords.words_iterator).
dataOverlap = False

#
# The following control the hallucinations
#
# NOTE: when maxHallucinations > 0, the hallucination code further down loops
# over every samplingSeedType / samplingMode / samplingKind combination, so the
# three values below are overridden there.
#
# Sampling Seed Type
# 0 : reuse previous state
# 1 : random inits
# 2 : Use first traces
samplingSeedType = 0
# sampling seed length is only used when sampling type is 1 or 2
samplingSeedLen = 5
# Sampling Mode
# 0 : element wise
# 1 : in sets of numUnrollSteps
samplingMode = 0
# Sampling Kind
# 0 : Max
# 1 : Probabilistically
samplingKind = 1
# Vocabulary token that separates cases; used for the hallucinations
endOfCaseChar = '[EOC]'
# If you don't want to generate hallucinations, set this to 0
maxHallucinations = 0
# maxHallucinations = 50 * numUnrollSteps
#
# If you want to export states for visualization, set this to 1
#
exportStates = 0
#
# If you want to export the final embedding, set this to 1
#
exportEmbedding = 0
#
# If you want to predict suffixes, set this to 1
#
predictSuffixes = 1
#
# We'll only produce suffixes of this length to check for eoc
maxSuffixLen = 20 * numUnrollSteps
maxNumBatches = 100
#
#

# ---------------------------------------------------------------------------
# Result log: a header listing the run parameters, followed by one CSV row per
# training epoch / validation pass written by the dataset loop below.
# The file stays open for the whole script, so no `with` block here.
# ---------------------------------------------------------------------------
resultFile = open("ResultFile" + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + ".logging.csv", 'w')

# (label, printf-style format, value) for each numeric parameter to record.
_loggedParameters = [
    ("numFolds", "%0.d", numFolds),
    ("numRuns", "%0.d", numRuns),
    ("Batchsize", "%0.d", batchSize),
    ("numUnrollSteps", "%0.d", numUnrollSteps),
    ("hiddenSize", "%0.d", hiddenSize),
    ("dropoutProb", "%.3f", dropoutProb),
    ("numLayers", "%0.d", numLayers),
    ("maxGradNorm", "%.3f", maxGradNorm),
    ("initScale", "%.3f", initScale),
    ("numEpochsFullLR", "%0.d", numEpochsFullLR),
    ("numEpochs", "%0.d", numEpochs),
    ("baseLearningRate", "%.3f", baseLearningRate),
    ("lrDecay", "%.3f", lrDecay),
    ("forgetBias", "%.3f", forgetBias),
]
for _label, _fmt, _value in _loggedParameters:
    resultFile.write(("%s, " + _fmt + "\n") % (_label, _value))

# Boolean model/data options are recorded as single keywords.
resultFile.write("BasicLSTM\n" if basicLSTM else "LSTM\n")
resultFile.write("Peepholes\n" if usePeepHoles else "NoPeepholes\n")
resultFile.write("Shuffle\n" if shuffle else "NoShuffle\n")

resultFile.write("\nDataset, VocabSize, TrainWords, ValidWords, Run, Fold, Epoch, TrainPrecision, TrainPerplexity, TrainCrossEntropy, Epoch, ValidPrecision, ValidPerplexity, ValidCrossEntropy\n")
# Force the header to disk now so it survives a crash during training.
resultFile.flush()
os.fsync(resultFile.fileno())

# Directory containing the event-log extracts listed in the dataset loop.
path = "./bpi_data"

for dataset in ([
    # Table 1
        ["BPI_Challenge_2013_incidents.extract.txt","BPI_Challenge_2013_incidents.extract.txt"],
        ["BPI_Challenge_2013_problems.extract.txt","BPI_Challenge_2013_problems.extract.txt"],

        ["BPIC_Challenge_2012.extract.complete.txt","BPIC_Challenge_2012.extract.complete.txt"],
    #    ["BPIC_Challenge_2012.extract.txt","BPIC_Challenge_2012.extract.txt"],

        ["BPIC_Challenge_2012_W.extract.complete.txt","BPIC_Challenge_2012_W.extract.complete.txt"],
    #    ["BPIC_Challenge_2012_A.extract.complete.txt","BPIC_Challenge_2012_A.extract.complete.txt"],
    #    ["BPIC_Challenge_2012_O.extract.complete.txt","BPIC_Challenge_2012_O.extract.complete.txt"],

        ["BPIC_Challenge_2012_W.extract.txt","BPIC_Challenge_2012_W.extract.txt"],
        ["BPIC_Challenge_2012_A.extract.txt","BPIC_Challenge_2012_A.extract.txt"],
        ["BPIC_Challenge_2012_O.extract.txt","BPIC_Challenge_2012_O.extract.txt"],
    # Table 2
    #    ["BPI_Challenge_2013_incidents.extract.with.group.txt","BPI_Challenge_2013_incidents.extract.with.group.txt"],
    #    ["BPI_Challenge_2013_problems.extract.with.group.txt","BPI_Challenge_2013_problems.extract.with.group.txt"],

    #    ["BPIC_Challenge_2012.extract.complete.with.resource.txt","BPIC_Challenge_2012.extract.complete.with.resource.txt"],
    #    ["BPIC_Challenge_2012.extract.with.resource.txt","BPIC_Challenge_2012.extract.with.resource.txt"],

    #    ["BPIC_Challenge_2012_W.extract.complete.with.resource.txt","BPIC_Challenge_2012_W.extract.complete.with.resource.txt"],
    #    ["BPIC_Challenge_2012_A.extract.complete.with.resource.txt", "BPIC_Challenge_2012_A.extract.complete.with.resource.txt"],
    #    ["BPIC_Challenge_2012_O.extract.complete.with.resource.txt", "BPIC_Challenge_2012_O.extract.complete.with.resource.txt"],

    #    ["BPIC_Challenge_2012_W.extract.with.resource.txt", "BPIC_Challenge_2012_W.extract.with.resource.txt"],
    #    ["BPIC_Challenge_2012_A.extract.with.resource.txt","BPIC_Challenge_2012_A.extract.with.resource.txt"],
    #    ["BPIC_Challenge_2012_O.extract.with.resource.txt","BPIC_Challenge_2012_O.extract.with.resource.txt"],

    # Table 3
    #    ["BPI_Challenge_2013_incidents.extract.with.group.txt","BPI_Challenge_2013_incidents.extract.txt"],
    #    ["BPI_Challenge_2013_problems.extract.with.group.txt","BPI_Challenge_2013_problems.extract.txt"],

    #    ["BPIC_Challenge_2012.extract.complete.with.resource.txt","BPIC_Challenge_2012.extract.complete.txt"],
    #    ["BPIC_Challenge_2012.extract.with.resource.txt","BPIC_Challenge_2012.extract.txt"],

    #    ["BPIC_Challenge_2012_W.extract.complete.with.resource.txt","BPIC_Challenge_2012_W.extract.complete.txt"],
    #    ["BPIC_Challenge_2012_A.extract.complete.with.resource.txt","BPIC_Challenge_2012_A.extract.complete.txt"],
    #    ["BPIC_Challenge_2012_O.extract.complete.with.resource.txt","BPIC_Challenge_2012_O.extract.complete.txt"],

    #    ["BPIC_Challenge_2012_W.extract.with.resource.txt", "BPIC_Challenge_2012_W.extract.txt"],
    #    ["BPIC_Challenge_2012_A.extract.with.resource.txt","BPIC_Challenge_2012_A.extract.txt"],
    #    ["BPIC_Challenge_2012_O.extract.with.resource.txt","BPIC_Challenge_2012_O.extract.txt"]

    # Extra for the DSS revision
    #     ["bpic2012_w_durations.txt", "bpic2012_w_durations.txt"]
]):
    # Read the data
    inputVocabulary = readWords.vocabSize(path, dataset[0])
    targetVocabulary = readWords.vocabSize(path, dataset[1])

    print("Input  Dataset: " + dataset[0] + ", Vocabulary Size: %.0f\n" % (inputVocabulary))
    print("Target Dataset: " + dataset[1] + ", Vocabulary Size: %.0f\n" % (targetVocabulary))

    # Need to rebuild the graph for each dataset because of the different vocabulary sizes
    tf.reset_default_graph()

    # Create placeholders for inputs and targets into which we can insert data at runtime
    input_data = tf.placeholder(tf.int64, [batchSize, numUnrollSteps], name="InputData")
    targets = tf.placeholder(tf.int64, [batchSize, numUnrollSteps], name="TargetData")
    with tf.name_scope("lstm"):
        # Create the individual LSTM layers
        if (basicLSTM == True):
            lstm_cell = rnn_cell.BasicLSTMCell(hiddenSize, forget_bias=0.0)
        else:
            lstm_cell = rnn_cell.LSTMCell(hiddenSize, use_peepholes=usePeepHoles, forget_bias=0.0)
        # Add a propabilistic dropout to cells of the LSTM layer
        if (dropoutProb > 0):
            lstm_cell = rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=1.0-dropoutProb)
        # Replicate this (including dropout) to additional layers
        cell = rnn_cell.MultiRNNCell([lstm_cell] * numLayers)
    # initial state of all cells is zero
    initialState = cell.zero_state(batchSize, tf.float32)

    with tf.device("/cpu:0"):
        # Embedding lookups can't be done on the GPU
        embedding = tf.Variable(tf.random_uniform([inputVocabulary, hiddenSize], -initScale, initScale),  name="Embedding")
        inputs = tf.nn.embedding_lookup(embedding, input_data, name="EmbeddingLookup")
        # Add a probabilistic dropout to inputs as well
        if (dropoutProb > 0):
            inputs = tf.nn.dropout(inputs, 1.0-dropoutProb, name="InputDropout")

    # reshape inputs
    inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, numUnrollSteps, inputs)]
    # connect inputs and outputs to the RNN layers
    outputs, states = rnn.rnn(cell, inputs, initial_state=initialState)
    outputs_concat = tf.concat(1, outputs)
    states_concat = tf.concat(1, states)
    # reshape outputs appropriately
    output = tf.reshape(outputs_concat, [-1, hiddenSize])

    # Do a softmax to identify the highest probability word
    softmax_w = tf.Variable(tf.random_uniform([hiddenSize, targetVocabulary], -initScale, initScale), name="softmax_w")
    softmax_b = tf.Variable(tf.random_uniform([targetVocabulary],  -initScale, initScale), name="softmax_b")
    logits = tf.matmul(output, softmax_w, name="MatMul") + softmax_b
    softmax = tf.nn.softmax(logits, name="softmax")
    loss = tf.nn.seq2seq.sequence_loss_by_example( [logits], [tf.reshape(targets, [-1])], weights=[tf.ones([batchSize * numUnrollSteps])], name="SeqLossByExample" )
    cost = tf.reduce_sum(loss, name="SumLoss") / batchSize
    finalState = states[numUnrollSteps-1]

    # Compute the accuracy
    correct_prediction = tf.cast(tf.nn.in_top_k(logits, tf.reshape(targets, [-1]), 1), tf.float32)
    _ , prediction_indices = tf.nn.top_k(softmax, 1)
    prediction_indices = tf.reshape(prediction_indices, [batchSize, -1])
    numCorrectPredictions = tf.reduce_sum(correct_prediction, name="SumCorrectPredictions")
    accuracy = tf.reduce_mean(correct_prediction, name="MeanCorrectPredictions")

    # Compute the cross entropy
    oneHotTargets = tf.one_hot(targets, depth=targetVocabulary, on_value=1.0, off_value=0.0)
    reshapedTargets = tf.reshape(oneHotTargets, shape=(batchSize*numUnrollSteps, targetVocabulary))
    crossEntropy = tf.reduce_mean(-tf.reduce_sum(reshapedTargets * tf.log(tf.sigmoid(logits)), reduction_indices=[1]), reduction_indices=[0], name="MeanCrossEntropy")

    # learning rate is a variable, but not trainable
    learningRate = tf.Variable(0.0, trainable=False, name="LearningRate")
    trainableVars = tf.trainable_variables()
    # compute gradients of all trainable variables w.r.t cost, then clip the gradients, prevent from getting too large too fast
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, trainableVars), maxGradNorm)
    # Define the optimizer
    optimizer = tf.train.GradientDescentOptimizer(learningRate)
    # and tell it to work on the gradients for the trainable variables
    train_op = optimizer.apply_gradients(zip(grads, trainableVars))

    # with the computational network defined, we can now move to runtime
    for runNum in range(0, numRuns):
        # We call dataFolds() again, because it randomly shuffles the dataset, so that we get different folds for the next run
        inputFolds, targetFolds, inputWord2Id, targetWord2Id, inputVocabulary, targetVocabulary = readWords.dataFolds(path, dataset[0], dataset[1], numFolds, shuffle)
        # dataFolds, word2id, vocabSize = readWords.dataFolds(path, dataset, numFolds)

        for fold in range(0, numFolds):

            # Get the validation words and training words
            inputValidateWords, inputTrainWords, targetValidateWords, targetTrainWords = readWords.validationAndTrainingData(inputFolds, targetFolds, inputWord2Id, targetWord2Id, fold)

            if (dataOverlap == False):
                epochNumBatches = ((len(inputTrainWords) // batchSize) - 1) // numUnrollSteps
            else:
                epochNumBatches = ((len(inputTrainWords) // batchSize) - numUnrollSteps)

            # Initialize all variables in the computational graph
            init_op = tf.initialize_all_variables()
            sess = tf.Session()
            sess.run(init_op)

            # Keep track of the best
            bestTrainPrecision = 0
            bestTrainPerplexity = 0
            bestTrainCrossEntropy = 0
            bestTrainEpoch = 0
            # Keep track of the best
            bestValidPrecision = 0
            bestValidPerplexity = 0
            bestValidCrossEntropy = 0
            bestValidEpoch = 0
            # Training starts here
            for i in range(numEpochs):
                # Adjust the learning rate by decaying it
                # and set the appropriate variable in the graph
                sess.run(tf.assign(learningRate, baseLearningRate*lrDecay**max(i - numEpochsFullLR, 0.0)))
                # Get the learning rate and print it
                print("Dataset: [" + dataset[0] + "/" + dataset[1] + "] Run: %d Fold: %d Epoch: %d Learning rate: %.3f" % (runNum, fold, i + 1, sess.run(learningRate)))

                start_time = time.time()
                # accumulated cross entropy
                accumCrossEnt = 0.0
                # accumulated costs over the unroll steps
                accumCosts = 0.0
                # number of correct predictions
                accumNumCorrPred = 0
                # number of iterations/unroll steps
                iters = 0
                state = initialState.eval(session=sess)
                for batchNum, (x, y) in enumerate(readWords.words_iterator(inputTrainWords, targetTrainWords, batchSize, numUnrollSteps, dataOverlap)):
                    batchNumCorrPred, batchCrossEnt, batchCost, state, _ = sess.run([numCorrectPredictions, crossEntropy, cost, finalState, train_op], {input_data: x, targets: y, initialState: state})
                    accumCosts += batchCost
                    accumNumCorrPred += batchNumCorrPred
                    accumCrossEnt += batchCrossEnt
                    iters += numUnrollSteps
                    if (epochNumBatches > 10):
                        if batchNum % (epochNumBatches // 10) == 10:
                            print("Dataset: [" + dataset[0] + "/" + dataset[1] + "] Run: %d Fold: %d Epoch percent: %.3f perplexity: %.3f speed: %.0f wps number of correct predictions: %.0f precision: %.3f cross-entropy: %.3f" %
                                (runNum, fold, batchNum * 1.0 / epochNumBatches, numpy.exp(accumCosts / iters), iters * batchSize / (time.time() - start_time), accumNumCorrPred, accumNumCorrPred / (iters * batchSize), accumCrossEnt / batchNum))
                thisPrecision = accumNumCorrPred / (iters * batchSize)
                thisPerplexity = numpy.exp(accumCosts / iters)
                thisCrossEntropy = accumCrossEnt / batchNum
                print("Dataset: [" + dataset[0] + "/" + dataset[1] + "] Run: %d Fold: %d Epoch summary: perplexity: %.3f speed: %.0f wps number of correct predictions: %.0f precision: %.3f cross-entropy: %.3f" %
                      (runNum, fold, numpy.exp(accumCosts / iters), iters * batchSize / (time.time() - start_time), accumNumCorrPred, thisPrecision, thisCrossEntropy))

                if (thisPrecision > bestTrainPrecision):
                    bestTrainPrecision = thisPrecision
                    bestTrainPerplexity = thisPerplexity
                    bestTrainCrossEntropy = thisCrossEntropy
                    bestTrainEpoch = i

                resultFile.write("[" + dataset[0] + "/" + dataset[1] + "], %d, %d, %d, %d, %d, %.0f, %.3f, %.3f, %.3f\n" % (inputVocabulary, len(inputTrainWords), len(inputValidateWords), runNum, fold, i, thisPrecision, thisPerplexity, thisCrossEntropy))
                resultFile.flush()
                os.fsync(resultFile.fileno())

            finalTrainingState = state
            if (numFolds > 1):
                # Validation starts here. Differences are the use of validwords for data and tf.no_op() instead of the train_op() built into the graph
                if (dataOverlap==False):
                    validateNumBatches = ((len(inputValidateWords) // batchSize) - 1) // numUnrollSteps
                else:
                    validateNumBatchs = ((len(inputValidateWords)// batchSize - numUnrollSteps))
                start_time = time.time()
                # accumulated cross entropy
                accumCrossEnt = 0.0
                # accumulated costs over the unroll steps
                accumCosts = 0.0
                # number of correct predictions
                accumNumCorrPred = 0
                # number of iterations/unroll steps
                iters = 0
                state = initialState.eval(session=sess)
                # state = finalTrainingState
                for batchNum, (x, y) in enumerate(readWords.words_iterator(inputValidateWords, targetValidateWords, batchSize, numUnrollSteps, dataOverlap)):
                    batchNumCorrPred, batchCrossEnt, batchCost, state, _ = sess.run([numCorrectPredictions, crossEntropy, cost, finalState, tf.no_op()], {input_data: x, targets: y, initialState: state})
                    accumCosts += batchCost
                    accumNumCorrPred += batchNumCorrPred
                    accumCrossEnt += batchCrossEnt
                    iters += numUnrollSteps
                    if (validateNumBatches > 10):
                        if batchNum % (validateNumBatches // 10) == 10:
                            print(
                            "Validation Epoch percent: %.3f perplexity: %.3f speed: %.0f wps number of correct predictions: %.0f precision: %.3f cross-entropy: %.3f" %
                            (batchNum * 1.0 / validateNumBatches, numpy.exp(accumCosts / iters),
                             iters * batchSize / (time.time() - start_time), accumNumCorrPred,
                             accumNumCorrPred / (iters * batchSize), accumCrossEnt / batchNum))
                thisPrecision = accumNumCorrPred / (iters * batchSize)
                thisPerplexity = numpy.exp(accumCosts / iters)
                thisCrossEntropy = accumCrossEnt / batchNum
                print(
                "Validation summary: perplexity: %.3f speed: %.0f wps number of correct predictions: %.0f precision: %.3f cross-entropy: %.3f" %
                (numpy.exp(accumCosts / iters), iters * batchSize / (time.time() - start_time), accumNumCorrPred, thisPrecision, thisCrossEntropy))
                if (thisPrecision > bestValidPrecision):
                    bestValidPrecision = thisPrecision
                    bestValidPerplexity = thisPerplexity
                    bestValidCrossEntropy = thisCrossEntropy
                    bestValidEpoch = i

            resultFile.write("%d, %.3f, %.3f, %.3f\n" % (i, thisPrecision, thisPerplexity, thisCrossEntropy))
            resultFile.flush()
            os.fsync(resultFile.fileno())

            # Invert the Word dictionary
            inputId2Word = {v: k for k, v in inputWord2Id.iteritems()}

            if maxHallucinations > 0:
                # Hallucinations start here,
                endOfCase = inputWord2Id[endOfCaseChar]
                # An zero array for the targets (which we don't need)
                y = numpy.zeros((batchSize, numUnrollSteps), dtype=numpy.int32)
                # Results go here
                h = [ [None]*batchSize for i in range(maxHallucinations)]
#                outputFile = open("Hallucation{0}Run{1}Fold{2}.txt".format(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()), runNum, fold), "w")
#                outputFile.write("InputDataSet: " + dataset[0] + "\n")
#                outputFile.write("numFolds: %0.d\n" % (numFolds))
#                outputFile.write("numRuns: %0.d\n" % (numRuns))
#                outputFile.write("Batchsize: %0.d\n" % (batchSize))
#                outputFile.write("numUnrollSteps: %0.d\n" % (numUnrollSteps))
#                outputFile.write("hiddenSize: %0.d\n" % (hiddenSize))
#                outputFile.write("droputProb: %.3f\n" % (dropoutProb))
#                outputFile.write("numLayers: %0.d\n" % (numLayers))
#                outputFile.write("maxGradNorm: %.3f\n" % (maxGradNorm))
#                outputFile.write("initScale: %.3f\n" % (initScale))
#                outputFile.write("numEpochsFullLR: %0.d\n" % (numEpochsFullLR))
#                outputFile.write("numEpochs: %0.d\n" % (numEpochs))
#                outputFile.write("baseLearningRate: %.3f\n" % (baseLearningRate))
#                outputFile.write("lrDecay: %.3f\n" % (lrDecay))
#                outputFile.write("forgetBias: %.3f\n" % (forgetBias))
#                if (basicLSTM):
#                    outputFile.write("BasicLSTM\n")
#                else:
#                    outputFile.write("LSTM\n")
#                if (usePeepHoles):
#                    outputFile.write("Peepholes\n")
#                else:
#                    outputFile.write("NoPeepholes\n")
#                if (shuffle):
#                    outputFile.write("Shuffle\n")
#                else:
#                    outputFile.write("NoShuffle\n")
#                outputFile.write("\n")

                for samplingSeedType in [0,1,2]:
                    for samplingMode in [0,1]:
                        for samplingKind in [0,1]:
                            print("Hallucinating ...\n")

                            if samplingSeedType==0:
                                state = finalTrainingState
                                x = numpy.full((batchSize, numUnrollSteps), endOfCase, dtype=numpy.int32)
                                # x = numpy.zeros((batchSize, numUnrollSteps), dtype=numpy.int32)
                                # x[:,0] = batchPredictions[:,numUnrollSteps]
                            if samplingSeedType==1:
                                state = initialState.eval(session=sess)
                                numpy.random.seed(123)
                                for i in range(samplingSeedLen):
                                    # seeding with random integers
                                    x = numpy.random.randint(0, inputVocabulary, (batchSize, numUnrollSteps))
                                    batchPredictions, batchFinalState, _ = sess.run([prediction_indices, finalState, tf.no_op()], {input_data: x, targets: y, initialState: state})
                                    state = batchFinalState
                                x = batchPredictions
                                # x[:,0] = batchPredictions[:,numUnrollSteps]
                            if samplingSeedType==2:
                                state = initialState.eval(session=sess)
                                iter = enumerate(readWords.words_iterator(inputValidateWords, targetValidateWords, batchSize, numUnrollSteps, dataOverlap))
                                for i in range(samplingSeedLen):
                                    _, (x, y) = next(iter)
                                    batchPredictions, batchFinalState, _ = sess.run([prediction_indices, finalState, tf.no_op()], {input_data: x, targets: y, initialState: state})
                                    state = batchFinalState
                                x = batchPredictions
                                # x[:,0] = batchPredictions[:,numUnrollSteps]

                            # Sampling Mode 0 means we sample element wise
                            if samplingMode == 0:
                                for i in range(maxHallucinations):
                                    s = i % numUnrollSteps
                                    batchPredictions, batchSoftMax, batchFinalState, _ = sess.run([prediction_indices, softmax, finalState, tf.no_op()], {input_data: x, targets: y, initialState: state})
                                    if samplingKind == 0:
                                        # We sample max (comes out of the TF graph)
                                        predVal = batchPredictions[:,s]
                                    if samplingKind == 1:
                                        # We sample randomly, using the softmax probs in numpy
                                        batchSoftMax = numpy.reshape(batchSoftMax, (-1, numUnrollSteps * inputVocabulary))
                                        batchSoftMax = numpy.split(batchSoftMax, numUnrollSteps, 1)
                                        predVal = numpy.empty((batchSize), dtype=numpy.int32)
                                        for j in range(batchSize):
                                            predVal[j] = numpy.random.choice(inputVocabulary, size=None, replace=False, p=batchSoftMax[s][j])
                                    for j in range(batchSize):
                                        h[i][j] = inputId2Word[predVal[j]]
                                    if (s < numUnrollSteps-1):
                                        x[:,s+1] = predVal
                                    else:
                                        x[:,0] = predVal
                                        state = batchFinalState
                                        i += 1
                            # Sampling Mode 1 means we sample in blocks of numUnrollSteps
                            if samplingMode == 1:
                                for i in range(maxHallucinations // numUnrollSteps):
                                    batchPredictions, batchSoftMax, batchFinalState, _ = sess.run([prediction_indices, softmax, finalState, tf.no_op()], {input_data: x, targets: y, initialState: state})
                                    if samplingKind == 0:
                                        # We sample max (comes out of the TF graph)
                                        predVal = batchPredictions
                                    if samplingKind == 1:
                                        # We sample randomly, using the softmax probs in numpy
                                        batchSoftMax = numpy.reshape(batchSoftMax, (-1, numUnrollSteps * inputVocabulary))
                                        batchSoftMax = numpy.split(batchSoftMax, numUnrollSteps, 1)
                                        predVal = numpy.empty((batchSize, numUnrollSteps), dtype=numpy.int32)
                                        for k in range(numUnrollSteps):
                                            for j in range(batchSize):
                                                predVal[j,k] = numpy.random.choice(inputVocabulary, size=None, replace=False, p=batchSoftMax[k][j])
                                    for j in range(batchSize):
                                        for k in range(numUnrollSteps):
                                            h[i*numUnrollSteps + k][j] = inputId2Word[predVal[j, k]]
                                    x = predVal
                                    state = batchFinalState

                            outputFile = open(
                                "Hallucation{0}Run{1}Fold{2}Dataset{3}batchSize{4}unrollSteps{5}hiddenSize{6}InitializationType{7}SamplingMode{8}SamplingKind{9}SeedLength{10}.csv".format(
                                    time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()), runNum, fold, dataset[0],batchSize, numUnrollSteps, hiddenSize,samplingSeedType,samplingMode,samplingKind,samplingSeedLen), "w")
                            outputFile.write("CaseID, Start, Complete, Activity, Resource, Role\n")
#                            outputFile.write("samplingSeedType: {0}\n".format(samplingSeedType))
#                            outputFile.write("SamplingMode: {0}\n".format(samplingMode))
#                            outputFile.write("SamplingKind: {0}\n".format(samplingKind))
#                            outputFile.write("SamplingSeedLength: {0}\n\n".format(samplingSeedLen))
                            for i in range(batchSize):
                                thisCase = 0
                                for j in range(maxHallucinations):
                                    if (h[j][i] == endOfCaseChar):
                                        thisCase += 1
                                    else:
                                        if thisCase != 0:
                                            outputFile.write("{0}, , {1}, {2}, , \n".format(10000*i + thisCase, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), h[j][i]))
#                                    outputFile.write(h[j][i]+" ")
#                                outputFile.write("\n")
#                            outputFile.write("\n")
                            outputFile.close()
#                outputFile.close()

            if exportEmbedding==1:
                outputFile = open("Vocab{0}Run{1}Fold{2}Dataset{3}batchSize{4}unrollSteps{5}hiddenSize{6}.csv".format(
                    time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()), runNum, fold, dataset[0], batchSize, numUnrollSteps, hiddenSize), "w")
                for idx in sorted(inputId2Word.keys()):
                    outputFile.write("{0}, ".format(inputId2Word[idx]))
                outputFile.close()

                embeddingTensor = sess.run(embedding)
                numpy.savetxt("Embedding{0}Run{1}Fold{2}Dataset{3}batchSize{4}unrollSteps{5}hiddenSize{6}.csv".format(
                    time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()), runNum, fold, dataset[0], batchSize, numUnrollSteps, hiddenSize), embeddingTensor, delimiter=",")

            if exportStates==1:
                # Validation for Visualization starts here. Differences are the use of validwords for data and tf.no_op() instead of the train_op() built into the graph
                if (dataOverlap == False):
                    validateNumBatches = ((len(inputValidateWords) // batchSize) - 1) // numUnrollSteps
                else:
                    validateNumBatchs = ((len(inputValidateWords) // batchSize - numUnrollSteps))
                # state = initialState.eval(session=sess)
                state = finalTrainingState
                stateCollection = numpy.zeros([batchSize, validateNumBatches * numUnrollSteps * cell.state_size ])
                inputCollection = numpy.zeros([batchSize, validateNumBatches * numUnrollSteps], dtype=int)
                for batchNum, (x, y) in enumerate(readWords.words_iterator(inputValidateWords, targetValidateWords, batchSize, numUnrollSteps, dataOverlap)):
                    state, batchStates, _ = sess.run([finalState, states_concat, tf.no_op()], {input_data: x, targets: y, initialState: state})
                    if batchNum % (validateNumBatches // 10) == 10:
                        print("Validation Epoch percent: %.3f" % (batchNum * 1.0 / validateNumBatches))
                    if (dataOverlap==False):
                        stateCollection[:,(batchNum*numUnrollSteps*cell.state_size):((batchNum+1)*numUnrollSteps*cell.state_size)] = batchStates
                        inputCollection[:,(batchNum*numUnrollSteps):(batchNum+1)*numUnrollSteps] = x
                    else:
                        stateCollection[:,(batchNum*cell.state_size):((batchNum+1)*cell.state_size)]
                        inputCollection[:,(batchNum):(batchNum+1)] = x

                stateCollection = numpy.reshape(stateCollection, [-1, cell.state_size])
                # Use one timestamp for every exported artifact so that the file names recorded in
                # lstm.yml match the files actually written. Separate strftime calls can straddle a
                # second boundary while the HDF5 files are being written.
                exportStamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                statesFileName = "StatesOutputs{0}Run{1}Fold{2}.h5".format(exportStamp, runNum, fold)
                wordsFileName = "Words{0}Run{1}Fold{2}.h5".format(exportStamp, runNum, fold)

                # Write the validation states to HDF5 for visualization. Within each layer's
                # 2*hiddenSize column slice, the first hiddenSize columns go to "states<i>" and the
                # next hiddenSize go to "outputs<i>".
                layerWidth = stateCollection.shape[1] // numLayers // 2
                with h5py.File(statesFileName, "w") as stateFile:
                    for i in range(numLayers):
                        layerStart = i * hiddenSize * 2
                        dset = stateFile.create_dataset("states{0}".format(i + 1), shape=(stateCollection.shape[0], layerWidth), dtype="f")
                        dset[:] = stateCollection[:, layerStart:layerStart + hiddenSize]
                        dset = stateFile.create_dataset("outputs{0}".format(i + 1), shape=(stateCollection.shape[0], layerWidth), dtype="f")
                        dset[:] = stateCollection[:, layerStart + hiddenSize:layerStart + 2 * hiddenSize]

                # Flatten the 2-D word-id matrix row by row into one 1-D sequence.
                inputCollection = inputCollection.flatten()
                with h5py.File(wordsFileName, "w") as wordFile:
                    wordFile.create_dataset("words", data=inputCollection)
                readWords.write_dict("word.dict", inputWord2Id)

                # One "state" entry per layer, followed by one "output" entry per layer.
                # For numLayers == 2 the output matches the previous hard-coded list exactly.
                layerEntries = ["{{type: state, layer: {0}, path: states{0}}}".format(i + 1) for i in range(numLayers)]
                layerEntries += ["{{type: output, layer: {0}, path: outputs{0}}}".format(i + 1) for i in range(numLayers)]
                with open("lstm.yml", "w") as configFile:
                    configFile.write("name: " + dataset[0] + "\n")
                    configFile.write("description: {0} layers, hiddenSize={1}, unrollSteps={2}, dropoutProb={3}\n".format(numLayers, hiddenSize, numUnrollSteps, dropoutProb))
                    configFile.write("files: \n")
                    configFile.write("  states: " + statesFileName + "\n")
                    configFile.write("  word_ids: " + wordsFileName + "\n")
                    configFile.write("  dict: word.dict\n")
                    configFile.write("word_sequence: \n")
                    configFile.write("  file: word_ids \n")
                    configFile.write("  path: words \n")
                    configFile.write("  dict_file: dict \n")
                    configFile.write("states: \n")
                    configFile.write("  file: states\n")
                    configFile.write("  types: [ \n")
                    configFile.write(", \n".join("    " + entry for entry in layerEntries) + " \n")
                    configFile.write("  ]")

            if (predictSuffixes > 0) and (inputVocabulary < 27) and (targetVocabulary == inputVocabulary):
                # Fail fast with a clear message instead of an undefined predVal later on.
                if samplingKind not in (0, 1):
                    raise ValueError("samplingKind must be 0 (max) or 1 (probabilistic), got {0}".format(samplingKind))

                suffixFileName = "SuffixPrediction{0}Run{1}Fold{2}Dataset{3}batchSize{4}unrollSteps{5}hiddenSize{6}SamplingKind{7}.txt".format(
                    time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()), runNum, fold, dataset[0],
                    batchSize, numUnrollSteps, hiddenSize, samplingKind)
                # "with" ensures the prediction file is closed even if a batch raises.
                with open(suffixFileName, "w") as outputFile:
                    outputFile.write("NormedDLDistance\n")

                    print("Suffix Prediction ...\n")

                    endOfCaseId = inputWord2Id[endOfCaseChar]
                    # A zero array for the targets; it is fed to the graph but not used here.
                    y = numpy.zeros((batchSize, numUnrollSteps), dtype=numpy.int32)
                    # Map word ids 0..25 to letters for the Damerau-Levenshtein comparison.
                    num2alpha = dict(zip(range(0, 26), string.ascii_lowercase))
                    # Number of batches to send through, capped at maxNumBatches.
                    numBatches = min(readWords.numBatches(os.path.join(path, dataset[0]), batchSize, numUnrollSteps), maxNumBatches)

                    for batchNum in range(numBatches):

                        prefixes, suffixes = readWords.readBatchOfSentences(os.path.join(path, dataset[0]), batchSize, numUnrollSteps, batchNum * batchSize)
                        # Convert the words to ids.
                        prefixes = [[inputWord2Id[element] for element in prefix] for prefix in prefixes]
                        suffixes = [[inputWord2Id[element] for element in suffix] for suffix in suffixes]

                        # Run the trace prefixes from the initial state. tf.no_op() is deliberately not
                        # fetched: creating it inside these loops would add a graph node on every call.
                        state = initialState.eval(session=sess)
                        batchPredictions, batchFinalState = sess.run([prediction_indices, finalState], {input_data: prefixes, targets: y, initialState: state})
                        state = batchFinalState
                        # NOTE(review): generation is seeded with the prefix run's whole prediction row.
                        # x[:, 0] therefore presumably holds the prediction that follows prefix[0], not
                        # the element after the full prefix, while `state` is already past the prefix.
                        # Confirm this seeding is intended.
                        x = batchPredictions
                        # Generated suffix ids for this batch, one column per generated element.
                        h = numpy.empty((batchSize, maxSuffixLen), dtype=numpy.int32)
                        # Sample element-wise. State is advanced only once a full window has been filled.
                        for i in range(maxSuffixLen):
                            s = i % numUnrollSteps
                            batchPredictions, batchSoftMax, batchFinalState = sess.run([prediction_indices, softmax, finalState], {input_data: x, targets: y, initialState: state})
                            if samplingKind == 0:
                                # Max sampling; the argmax comes out of the TF graph.
                                predVal = batchPredictions[:, s]
                            else:
                                # Probabilistic sampling from the softmax at step s. This assumes softmax
                                # rows are ordered batch-major (batch, step), as the previous split did.
                                stepProbs = numpy.reshape(batchSoftMax, (-1, numUnrollSteps, inputVocabulary))[:, s, :]
                                predVal = numpy.empty(batchSize, dtype=numpy.int32)
                                for j in range(batchSize):
                                    # Renormalize in float64 so float32 rounding cannot trip the
                                    # probabilities-sum-to-1 check in numpy.random.choice.
                                    probs = stepProbs[j].astype(numpy.float64)
                                    predVal[j] = numpy.random.choice(inputVocabulary, p=probs / probs.sum())

                            h[:, i] = predVal
                            if s < numUnrollSteps - 1:
                                x[:, s + 1] = predVal
                            else:
                                # The window is full: restart at position 0 from the state after it.
                                x[:, 0] = predVal
                                state = batchFinalState

                        # Truncate each generated row just after the first end-of-case id (kept in).
                        suffixes_predicted = []
                        for row in h:
                            rowIds = row.tolist()
                            try:
                                eocIndex = rowIds.index(endOfCaseId) + 1
                            except ValueError:
                                # No end-of-case was generated, so keep the whole row.
                                eocIndex = None
                            suffixes_predicted.append(rowIds[:eocIndex])

                        suffixes_predicted_alpha = [''.join([num2alpha[element] for element in suffix]) for suffix in suffixes_predicted]
                        suffixes_alpha = [''.join([num2alpha[element] for element in suffix]) for suffix in suffixes]

                        for j in range(batchSize):
                            outputFile.write("{0}\n".format(dl.normalized_damerau_levenshtein_distance(suffixes_predicted_alpha[j], suffixes_alpha[j])))

                        outputFile.flush()
                        os.fsync(outputFile.fileno())
                        print("Batch {} of {} ".format(batchNum, numBatches))

        resultFile.write("\n")
        resultFile.flush()
        os.fsync(resultFile.fileno())

resultFile.close()

