import numpy as np
import re
import datetime as dt
from utils import XESAttributeTypes


class XESPredictionData:
    """Traces of a parsed XES log, prepared for and extended by model predictions.

    Attributes (all filled from the parser passed to the constructor):
      traces       -- list with one 2-D array per trace; rows are attributes and
                      columns are events (the code slices ``trace[:, 0:n]``).
                      After ``finalizeData`` every trace is a list of 1-D rows.
      attribs      -- dict mapping ``(name, isEventAttribute)`` to a row number.
                      Combined classifier rows use ``(tuple(names), True)``.
      classifiers  -- dict mapping a classifier name to its attribute names.
      prefixLength -- common trace length set by ``_chopData`` (None before).

    The code runs on Python 2 and Python 3.
    """

    def __init__(self, parser):
        """Take traces, attributes and classifiers from an already parsed log."""
        self.traces = parser.getTraces()
        self.attribs = parser.getAttributes()
        self.classifiers = parser.getClassifiers()
        self.prefixLength = None

    @staticmethod
    def _tensorName(attrnames):
        """Build the graph tensor name: names joined by '-', other characters stripped."""
        return re.sub("[^A-Za-z0-9.\\-]*", "", "-".join(attrnames))

    @staticmethod
    def _asciiText(value):
        """Return value with all non-ASCII characters dropped, as a native str."""
        text = value.encode('ascii', 'ignore')
        if not isinstance(text, str):
            # Python 3: encode() returns bytes, which cannot be joined with str
            text = text.decode('ascii')
        return text

    @staticmethod
    def _appendValue(trace, rowNum, value):
        """Grow row rowNum of trace by one element and store value in it."""
        row = np.resize(trace[rowNum], new_shape=(trace[rowNum].shape[0] + 1))
        row[-1] = value
        trace[rowNum] = row

    def _rowOf(self, attrnames, eventAttrib):
        """Return the row number of a single attribute or of a classifier row."""
        key = tuple(attrnames) if len(attrnames) > 1 else attrnames[0]
        return self.attribs[(key, eventAttrib)]

    def _checkData(self, graphConfig):
        """Return True if the log offers every target and predictor of the config.

        Entries with several attribute names must match the attribute list of
        one of the log's classifiers; single names must be known attributes.
        Targets are always looked up as event attributes.
        """
        def isKnown(attrnames, eventattr):
            if len(attrnames) > 1:
                return attrnames in self.classifiers.values()
            return (attrnames[0], eventattr) in self.attribs

        if not all(isKnown(attrnames, True) for (attrnames, _) in graphConfig.targets):
            return False
        return all(isKnown(attrnames, eventattr)
                   for (attrnames, eventattr) in graphConfig.predictors)

    def _chopData(self):
        """Cut every trace down to the length of the shortest trace."""
        if len(self.traces) == 0:
            # Nothing to cut; min() of an empty sequence would raise
            self.prefixLength = 0
            return
        shortest = min(trace.shape[1] for trace in self.traces)
        self.prefixLength = shortest
        self.traces = [trace[:, 0:shortest] for trace in self.traces]

    def getAttrType(self, preds, names, event):
        """Look up attribute type and value-to-id mapping of an attribute.

        preds is a sequence of 7-tuples ``(attrnames, attrtype, embedsize,
        val2id, id2val, rowNum, eventAttrib)``. names may be a single name, a
        list or a tuple; it is compared as a list against ``attrnames``.

        Returns ``(attrtype, val2id)`` of the first match, or None if there is
        no match.
        """
        if isinstance(names, tuple):
            names = list(names)
        if not isinstance(names, list):
            names = [names]
        for (attrnames, attrtype, _, val2id, _, _, eventAttrib) in preds:
            if attrnames == names and eventAttrib == event:
                return attrtype, val2id
        return None

    def getClassifier(self, names):
        """Return the name of the classifier built from names, or None."""
        for (clsname, attrnames) in self.classifiers.items():
            if attrnames == names:
                return clsname
        return None

    def _appendClassifierRows(self, config):
        """Append one combined row per classifier to every trace.

        Categorical classifiers add their attribute rows with ``+``; numeric
        ones multiply them as floats and store the result as strings. The new
        row is registered in ``self.attribs`` under ``(tuple(names), True)``.
        Classifiers that config.preds does not list are skipped.

        Raises ValueError for a classifier of any other attribute type.
        """
        for (clsName, attrNames) in self.classifiers.items():
            found = self.getAttrType(config.preds, attrNames, True)
            if found is None:
                # The model does not know this classifier, so no row is needed
                continue
            attrtype = found[0]
            isCategorical = attrtype == XESAttributeTypes.CATEGORICAL
            if not isCategorical and not attrtype == XESAttributeTypes.NUMERIC:
                raise ValueError("classifier '%s' has an attribute type that cannot be combined"
                                 % clsName)

            for t, trace in enumerate(self.traces):
                parts = [trace[self.attribs[(name, True)]] for name in attrNames]
                if isCategorical:
                    row = parts[0]
                    for part in parts[1:]:
                        row = row + part
                else:
                    row = parts[0].astype(float)
                    for part in parts[1:]:
                        row = row * part.astype(float)
                    row = row.astype(str)
                trace = np.concatenate((trace, np.reshape(row, (1, -1))), axis=0)
                self.attribs[(tuple(attrNames), True)] = trace.shape[0] - 1
                self.traces[t] = trace

    @staticmethod
    def _dateRowToDeltas(stamps):
        """Turn a datetime64 row into float32 seconds since the previous entry.

        The first entry becomes 0. An entry equal to ``np.datetime64('0')``
        (presumably the end-of-case padding -- verify against the parser)
        becomes 0 as well, and so does the entry right after it, because its
        difference to the padding value would be meaningless.
        """
        if len(stamps) == 0:
            return np.zeros(0, dtype='float32')
        zero = np.timedelta64(0, 's')
        padding = np.datetime64('0')
        deltas = np.empty((len(stamps)), dtype=np.object_)
        deltas[0] = zero
        i = 1
        while i < len(stamps):
            if stamps[i] != padding:
                deltas[i] = stamps[i] - stamps[i-1]
            else:
                deltas[i] = zero
                if i + 1 < len(stamps):
                    deltas[i+1] = zero
                    i += 1
            i += 1
        deltas = deltas.astype('timedelta64')
        deltas = np.divide(deltas, np.timedelta64(1, 's'))
        return deltas.astype(dtype='float32')

    def _encodeRow(self, row, attrtype, val2id):
        """Convert one attribute row to the numeric form of its attribute type.

        CATEGORICAL values are replaced by their ids (int64), NUMERIC rows
        become float32 with '[EOC]' read as 0, DATE rows become float32 deltas
        in seconds. Rows of any other type are returned unchanged.
        """
        if attrtype == XESAttributeTypes.CATEGORICAL:
            # Use the val2id dictionary from the training data
            for value, ident in val2id.items():
                row[row == value] = ident
            return row.astype(dtype='int64')
        if attrtype == XESAttributeTypes.NUMERIC:
            row[row == '[EOC]'] = '0'
            return row.astype(dtype='float32')
        if attrtype == XESAttributeTypes.DATE:
            return self._dateRowToDeltas(row.astype(dtype='datetime64'))
        return row

    def finalizeData(self, config):
        """Bring the traces into the numeric form the model works on.

        If the log lacks attributes required by config (see ``_checkData``)
        nothing is changed. Otherwise the traces are cut to a common length,
        classifier rows are appended, and every trace is replaced by a list of
        1-D rows, each encoded according to its type in config.preds. Rows of
        attributes that config.preds does not list are kept as they are.
        """
        if not self._checkData(config):
            return

        self._chopData()
        self._appendClassifierRows(config)

        # The encoding of a row is the same for all traces, so resolve it once
        encoders = []
        for ((names, eventAttrib), rowNum) in self.attribs.items():
            found = self.getAttrType(config.preds, names, eventAttrib)
            if found is None:
                continue
            encoders.append((rowNum, found[0], found[1]))

        for t, trace in enumerate(self.traces):
            # Iterating a 2-D array yields its rows as 1-D views, also for
            # traces holding a single event (a plain squeeze would give 0-d)
            rows = list(trace)
            for (rowNum, attrtype, val2id) in encoders:
                rows[rowNum] = self._encodeRow(rows[rowNum], attrtype, val2id)
            self.traces[t] = rows

    def getFeedDict(self, numUnrollSteps, startstep, tfGraph, graphConfig):
        """Fill the input placeholders of the graph with a window of each trace.

        For every predictor an array with the shape and dtype of its
        placeholder is created and filled with the values at positions
        ``startstep .. startstep + numUnrollSteps - 1`` of each trace; slots
        beyond the end of a row stay 0.

        Returns ``(feeddict, predIndex)``. predIndex is the unroll step at
        which the last row that was looked at ran out of values, or None if no
        row ended inside the window.
        """
        feeddict = dict()
        predIndex = None
        for (attrnames, eventAttrib) in graphConfig.predictors:
            inputTensor = tfGraph.inputPlaceholders[(self._tensorName(attrnames), eventAttrib)]
            batch = np.zeros(shape=inputTensor.get_shape(), dtype=inputTensor.dtype.as_numpy_dtype)
            rowNum = self._rowOf(attrnames, eventAttrib)

            for t, trace in enumerate(self.traces):
                row = trace[rowNum]
                for s in range(0, numUnrollSteps):
                    pos = startstep + s
                    if pos < len(row):
                        batch[t, s] = row[pos]
                    if pos + 1 >= len(row):
                        predIndex = s
                        break

            feeddict[inputTensor] = batch

        return feeddict, predIndex

    def addPrediction(self, nnoutputs, softMaxs, predIndex, graphConfig):
        """Append the prediction made at unroll step predIndex to every trace.

        nnoutputs and softMaxs are dicts keyed by tensor name. NUMERIC targets
        append the network output, DATE targets append the output plus the
        previous value of the row, CATEGORICAL targets append an id sampled
        from the softmax distribution.
        """
        for (attrnames, eventAttrib) in graphConfig.targets:
            attrtype, val2id = self.getAttrType(graphConfig.preds, attrnames, eventAttrib)
            rowNum = self._rowOf(attrnames, eventAttrib)
            tensorName = self._tensorName(attrnames)

            isDate = attrtype == XESAttributeTypes.DATE
            if isDate or attrtype == XESAttributeTypes.NUMERIC:
                stepOutput = np.split(nnoutputs[tensorName], graphConfig.numUnrollSteps, 1)[predIndex]
                for i, trace in enumerate(self.traces):
                    value = stepOutput[i]
                    if isDate:
                        value = value + trace[rowNum][-1]
                    self._appendValue(trace, rowNum, value)

            if attrtype == XESAttributeTypes.CATEGORICAL:
                stepProbs = np.split(softMaxs[tensorName], graphConfig.numUnrollSteps, 0)[predIndex]
                # Sorted so that position j of the probabilities belongs to the
                # j-th id; dict order gives no such guarantee. A list is needed
                # because a dict view is not accepted as a 1-D array.
                ids = sorted(val2id.values())
                for i, trace in enumerate(self.traces):
                    # float32 softmax rows often miss NumPy's sum-to-one check
                    probs = np.asarray(stepProbs[i], dtype=np.float64)
                    probs = probs / probs.sum()
                    self._appendValue(trace, rowNum, np.random.choice(ids, p=probs))

    def writeToXES(self, xesOutPath, startPos, numPSteps, graphConfig, eocDetect):
        """Write positions startPos .. startPos + numPSteps - 1 of all traces as XES.

        Only the targets of graphConfig are written. Categorical ids are
        mapped back to their values (non-ASCII characters are dropped),
        NUMERIC and DATE values are rescaled with ``graphConfig.scale``. With
        eocDetect set, a trace ends at the first event in which a categorical
        target holds the '[EOC]' value.

        NOTE(review): keys and values are written without XML escaping;
        confirm whether the parser hands back raw or unescaped text before
        adding it.
        """
        # Everything below is constant per target, so resolve it only once
        targets = []
        for (attrnames, eventAttrib) in graphConfig.targets:
            attrtype, val2id = self.getAttrType(graphConfig.preds, attrnames, eventAttrib)
            rowNum = self._rowOf(attrnames, eventAttrib)
            key = self.getClassifier(attrnames) if len(attrnames) > 1 else attrnames[0]
            id2val = None
            moveScale = None
            if attrtype == XESAttributeTypes.CATEGORICAL:
                id2val = dict((v, k) for (k, v) in val2id.items())
            elif attrtype == XESAttributeTypes.NUMERIC or attrtype == XESAttributeTypes.DATE:
                moveScale = graphConfig.scale[(tuple(attrnames), eventAttrib)]
            targets.append((attrnames, attrtype, rowNum, key, id2val, moveScale))

        header = (
            '<?xml version="1.0" encoding="UTF-8" ?>\n',
            '<log xes.version="1.0" xes.features="nested-attributes" openxes.version="1.0RC7" xmlns="http://www.xes-standard.org/">\n',
            '\t<extension name="Lifecycle" prefix="lifecycle" uri="http://www.xes-standard.org/lifecycle.xesext"/>\n',
            '   <extension name="Organizational" prefix="org" uri="http://www.xes-standard.org/org.xesext"/>\n',
            '   <extension name="Time" prefix="time" uri="http://www.xes-standard.org/time.xesext"/>\n',
            '\t<extension name="Concept" prefix="concept" uri="http://www.xes-standard.org/concept.xesext"/>\n',
        )

        with open(xesOutPath, "w") as f:
            for line in header:
                f.write(line)

            f.write('   <global scope="event">\n')
            for (attrnames, attrtype, _, _, _, _, eventAttrib) in graphConfig.preds:
                if not (eventAttrib and (attrnames, eventAttrib) in graphConfig.targets):
                    continue
                key = attrnames[0] if len(attrnames) == 1 else self.getClassifier(attrnames)
                if attrtype == XESAttributeTypes.CATEGORICAL:
                    f.write('       <string key="' + key + '" value="UNKNOWN"/>\n')
                if attrtype == XESAttributeTypes.NUMERIC:
                    f.write('       <float key="' + key + '" value="0.0"/>\n')
                if attrtype == XESAttributeTypes.DATE:
                    f.write('       <date key="' + key + '" value="1970-01-01T00:00:00+01:00"/>\n')
            f.write('   </global>\n')

            for t in self.traces:
                f.write('   <trace>\n')
                for s in range(startPos, startPos + numPSteps):
                    # EOC detection really only works for categorical targets,
                    # as dates and numeric ones may hold valid zeros
                    if eocDetect and any(
                            attrtype == XESAttributeTypes.CATEGORICAL
                            and id2val[t[rowNum][s]] == '[EOC]' * len(attrnames)
                            for (attrnames, attrtype, rowNum, _, id2val, _) in targets):
                        break

                    f.write('       <event>\n')
                    for (attrnames, attrtype, rowNum, key, id2val, moveScale) in targets:
                        if attrtype == XESAttributeTypes.CATEGORICAL:
                            # Classifier lines have always been indented one
                            # space less; kept so the output stays identical
                            indent = '          ' if len(attrnames) > 1 else '           '
                            f.write(indent + '<string key="' + key + '" value="'
                                    + self._asciiText(id2val[t[rowNum][s]]) + '"/>\n')
                        if attrtype == XESAttributeTypes.NUMERIC:
                            (move, scale) = moveScale
                            # str() is required: a float cannot be joined to a string
                            f.write('           <float key="' + key + '" value="'
                                    + str(t[rowNum][s] * scale + move) + '"/>\n')
                        if attrtype == XESAttributeTypes.DATE:
                            (move, scale) = moveScale
                            stamp = dt.datetime.fromtimestamp(int(t[rowNum][s] * scale + move))
                            f.write('           <date key="' + key + '" value="'
                                    + stamp.isoformat() + '"/>\n')
                    f.write('       </event>\n')
                f.write('   </trace>\n')
            f.write('</log>\n')
