import numpy as np
import Tkinter
import ttk
import re
from utils import XESAttributeTypes


class XESData:
    """Event data prepared from an XES log, plus fold and mini-batch helpers.

    ``events`` is expected to start as a 2-D array with one row per attribute;
    each pred's ``rowNum`` indexes into it.  ``finalizeData`` turns it into a
    list of 1-D numeric rows, ``createDataFolds``/``createTrainValidData`` split
    those rows, and ``getFeedDict`` slices them into unrolled mini-batches.  A
    small Tkinter window reports progress while ``finalizeData`` runs.
    """

    # Characters removed from "-"-joined attribute names to build placeholder keys.
    _PLACEHOLDER_STRIP = re.compile("[^A-Za-z0-9.\\-]*")

    def __init__(self, events):
        """Store ``events`` and build the "Data Preparation" progress window.

        Args:
            events: attribute rows of the event log, indexed later by each
                pred's ``rowNum``.
        """
        p = 5  # padding used around the window and its labels

        self._folds = None
        self._nfolds = None

        self._valid = None
        self._train = None

        self.preds = None
        self.targets = None
        self.predictors = None

        self.events = events

        self.top = Tkinter.Toplevel()
        self.ProgressVar = Tkinter.IntVar(value=0)
        self.CurrentOpVar = Tkinter.StringVar(value='')

        self.top.title("Data Preparation")
        self.top.configure(padx=p, pady=p)
        self.top.resizable(False, False)

        lbl = Tkinter.Label(self.top, text="Progress:", padx=p, pady=p)
        lbl.grid(row=0, column=0, sticky=Tkinter.W)
        progressBar = ttk.Progressbar(self.top, orient="horizontal", length=100, mode="determinate", maximum=100, value=0, variable=self.ProgressVar)
        progressBar.grid(row=0, column=1)
        lbl = Tkinter.Label(self.top, text="Current Op:", padx=p, pady=p)
        lbl.grid(row=1, column=0, sticky=Tkinter.W)
        lbl = Tkinter.Label(self.top, textvariable=self.CurrentOpVar, padx=p, pady=p)
        lbl.grid(row=1, column=1, sticky=Tkinter.W)

    def finalizeData(self, config):
        """Materialise combined attributes and convert every used row to numbers.

        Each pred is a 7-tuple
        ``(names, attrtype, embedSize, val2id, id2val, rowNum, eventAttrib)``.

        Phase 1 appends one row to ``self.events`` for every pred whose
        ``names`` lists several attributes (joined with ``+`` for categorical
        attributes, multiplied as floats for numeric ones), then keeps only the
        preds whose ``rowNum`` is not -1.  Phase 2 splits ``self.events`` into a
        list of 1-D rows and, for rows used as a predictor or target, maps
        categorical values to ids, standardises numeric values, and turns dates
        into (scaled) seconds since the previous event.

        Args:
            config: reads ``preds``, ``targets``, ``predictors`` and
                ``dateScale``; ``config.preds`` is replaced by the filtered
                list and ``config.scale[(tuple(names), eventAttrib)]`` receives
                the ``(move, scale)`` applied to each numeric/date row.

        Returns:
            The same ``config`` object.

        Raises:
            ValueError: a multi-attribute pred is neither categorical nor numeric.
        """
        self.preds = config.preds
        self.targets = config.targets
        self.predictors = config.predictors

        # Row index of every single-attribute event attribute, keyed by name.
        rownums = {name[0]: rowNum
                   for (name, _, _, _, _, rowNum, eventAttrib) in self.preds
                   if eventAttrib and len(name) == 1}

        newpreds = []
        for i, (names, attrtype, embedSize, val2id, id2val, rowNum, eventAttrib) in enumerate(self.preds):
            self._showProgress('Finalizing Data (1) for ' + ''.join(names),
                               (i * 100) // (2 * len(self.preds)))
            if len(names) > 1:
                row = self._combinedRow(names, attrtype, rownums)
                urow = np.unique(row)
                val2id = dict(zip(urow, range(0, len(urow))))
                id2val = dict(zip(range(0, len(urow)), urow))
                self.events = np.concatenate((self.events, np.reshape(row, (1, -1))), axis=0)
                rowNum = self.events.shape[0] - 1
                newpreds.append((names, attrtype, embedSize, val2id, id2val, rowNum, eventAttrib))

        self.preds.extend(newpreds)
        self.preds = [pred for pred in self.preds if pred[5] != -1]
        config.preds = self.preds

        # One 1-D array per attribute row so each row can change dtype on its own.
        self.events = [np.squeeze(tmp) for tmp in np.split(self.events, self.events.shape[0], axis=0)]
        for i, (names, attrtype, _, val2id, _, rowNum, eventAttrib) in enumerate(self.preds):
            # Conversion may be expensive (dates in particular), so unused rows are skipped.
            if (names, eventAttrib) not in self.predictors and (names, eventAttrib) not in self.targets:
                continue
            self._showProgress('Finalizing Data (2) for ' + ''.join(names),
                               50 + (i * 100) // (2 * len(self.preds)))

            scaleKey = (tuple(names), eventAttrib)
            row = self.events[rowNum]
            if attrtype == XESAttributeTypes.CATEGORICAL:
                self.events[rowNum] = self._encodeCategorical(row, val2id)
            elif attrtype == XESAttributeTypes.NUMERIC:
                row[row == '[EOC]'] = '0'
                values = row.astype(dtype='float32')
                # standardize and remember the scaling factors
                move = np.mean(values)
                scale = self._nonZeroScale(np.std(values))
                self.events[rowNum] = (values - move) / scale
                config.scale[scaleKey] = (move, scale)
            elif attrtype == XESAttributeTypes.DATE:
                seconds = self._secondsSincePrevious(row.astype(dtype='datetime64'))
                move, scale = self._dateScaling(config.dateScale, seconds)
                self.events[rowNum] = (seconds - move) / scale
                config.scale[scaleKey] = (move, scale)

        # Once all rows are converted: leave any Tk main loop and hide the window.
        self.top.quit()
        self.top.withdraw()

        return config

    def _showProgress(self, text, percent):
        """Show ``text``/``percent`` in the progress window and let Tk redraw it."""
        self.CurrentOpVar.set(text)
        self.ProgressVar.set(percent)
        self.top.update_idletasks()

    def _combinedRow(self, names, attrtype, rownums):
        """Combine the event rows of all attributes in ``names`` into one row.

        Categorical rows are joined with ``+``; numeric rows are multiplied as
        floats and the product is converted back to ``str``.

        Raises:
            ValueError: ``attrtype`` is neither categorical nor numeric.
        """
        if attrtype == XESAttributeTypes.CATEGORICAL:
            rows = (self.events[rownums[name]] for name in names)
            combined = next(rows)
            for other in rows:
                combined = combined + other
            return combined
        if attrtype == XESAttributeTypes.NUMERIC:
            rows = (self.events[rownums[name]].astype(float) for name in names)
            combined = next(rows)
            for other in rows:
                combined = combined * other
            return combined.astype(str)
        raise ValueError('Cannot combine attributes %s of type %r' % ('-'.join(names), attrtype))

    @staticmethod
    def _encodeCategorical(row, val2id):
        """Return ``row`` with each value replaced by its id in ``val2id``, as int64.

        Values are matched against the untouched original row and ids are
        written into an object copy, so an id can neither be re-matched by a
        later key (e.g. values '0'/'1') nor truncated by a fixed-width string
        dtype.  Values missing from ``val2id`` go to the int64 cast unchanged.
        """
        codes = row.astype(object)
        for value, valueId in val2id.items():
            codes[row == value] = valueId
        return codes.astype(dtype='int64')

    @staticmethod
    def _nonZeroScale(scale):
        """Return ``scale``, or 1 when it is 0 (constant data), to avoid 0/0 NaNs."""
        return 1 if scale == 0 else scale

    @staticmethod
    def _secondsSincePrevious(dates):
        """Return float32 seconds between each date and the date before it.

        The first entry is 0.  An entry equal to ``np.datetime64('0')`` gets 0,
        and so does the entry right after it, so no gap is measured across it
        (it appears to mark a case boundary -- TODO confirm with the loader).
        """
        oneSecond = np.timedelta64(1, 's')
        separator = np.datetime64('0')
        seconds = np.zeros(len(dates), dtype='float64')
        j = 1
        while j < len(dates):
            if dates[j] != separator:
                seconds[j] = (dates[j] - dates[j - 1]) / oneSecond
            else:
                j += 1  # the entry after the separator keeps its 0 gap
            j += 1
        return seconds.astype(dtype='float32')

    @classmethod
    def _dateScaling(cls, dateScale, seconds):
        """Return ``(move, scale)`` for date gaps according to ``dateScale``.

        'Standardize' uses the mean/std of ``seconds``; 'Days', 'Hours' and
        'Minutes' divide by that unit's length in seconds; anything else
        leaves the values unchanged.
        """
        if dateScale == 'Standardize':
            move = np.mean(seconds)
            return move, cls._nonZeroScale(np.std(seconds))
        unitSeconds = {'Days': 24 * 60 * 60, 'Hours': 60 * 60, 'Minutes': 60}
        return 0, unitSeconds.get(dateScale, 1)

    def createDataFolds(self, nfolds):
        """Split every event row into ``nfolds`` folds (kept whole when ``nfolds`` <= 1)."""
        self._nfolds = nfolds
        if nfolds > 1:
            self._folds = [np.array_split(events, nfolds) for events in self.events]
        else:
            self._folds = self.events

    def createTrainValidData(self, validFold):
        """Use fold ``validFold`` for validation and the remaining folds for training.

        Without folds (``nfolds`` <= 1) training and validation share all data.
        """
        if self._nfolds > 1:
            self._valid = [rowFolds[validFold] for rowFolds in self._folds]
            self._train = [np.concatenate([rowFolds[i] for i in range(0, len(rowFolds)) if i != validFold])
                           for rowFolds in self._folds]
        else:
            self._valid = self._train = self._folds

    def getFeedDict(self, batchSize, numUnrollSteps, startstep, train, tfGraph):
        """Build the feed dict for one unrolled window of the train or validation data.

        The data is cut into ``batchSize`` consecutive parts of
        ``len // batchSize`` events.  Batch row ``p`` takes ``numUnrollSteps``
        events starting at ``p * partSize + startstep`` as input, and the same
        window shifted one event ahead as target.

        Args:
            tfGraph: object with ``inputPlaceholders`` keyed by
                ``(key, eventAttrib)`` and ``targetPlaceholders`` keyed by
                ``key``, where ``key`` is the "-"-joined attribute names with
                characters outside ``[A-Za-z0-9.-]`` removed.

        Raises:
            StopIteration: the window (including the last part's shifted
                target window) no longer fits in the data.
        """
        if train:
            data, exhausted = self._train, 'No more training data'
        else:
            data, exhausted = self._valid, 'No more validation data'

        dataLen = len(data[0])
        partSize = dataLen // batchSize
        # Targets are shifted one event ahead, so the final part also needs the
        # event just past its input window; without this check np.stack would
        # receive a shorter slice when dataLen is a multiple of batchSize.
        lastTargetEnd = (batchSize - 1) * partSize + startstep + numUnrollSteps + 1
        if startstep + numUnrollSteps > partSize or lastTargetEnd > dataLen:
            raise StopIteration(exhausted)

        starts = [partNum * partSize + startstep for partNum in range(0, batchSize)]
        r = [np.stack([row[s:s + numUnrollSteps] for s in starts], axis=0) for row in data]
        t = [np.stack([row[s + 1:s + numUnrollSteps + 1] for s in starts], axis=0) for row in data]

        feeddict = dict()
        for (names, attrtype, _, val2id, _, rowNum, eventAttrib) in self.preds:
            key = self._PLACEHOLDER_STRIP.sub("", "-".join(names))
            if (names, eventAttrib) in self.predictors:
                feeddict[tfGraph.inputPlaceholders[(key, eventAttrib)]] = r[rowNum]
            if (names, eventAttrib) in self.targets:
                feeddict[tfGraph.targetPlaceholders[key]] = t[rowNum]

        return feeddict
