import os.path
import sys
import tkMessageBox
import numpy as np

from xml.sax import make_parser, handler
from utils import XESAttributeTypes, is_number, is_date


def on_closing():
    """Do nothing and return None.

    NOTE(review): presumably a placeholder window-close callback — confirm with callers.
    """


class XESTerminateParseException(Exception):
    """Signal used to stop SAX parsing early.

    XESHandler raises it once the requested number of traces has been
    collected, and XESParserPrediction._parse catches it.
    """

    def __init__(self, value):
        # Payload describing why parsing stopped.
        self.value = value

    def __str__(self):
        # The string form is the repr of the payload.
        return '{!r}'.format(self.value)


class XESHandler(handler.ContentHandler):
    """SAX content handler that extracts traces from an XES event log.

    The handler records multi-key classifiers and the attribute keys declared
    in ``<global scope="event">`` / ``<global scope="trace">`` sections.
    A trace is accepted only if all three of these hold:

    * every declared attribute is present;
    * numeric/date values validate;
    * it has at least ``minPrefixLen`` events.

    Each accepted trace is appended to ``traceList`` as an object array of
    shape ``(len(eventGlobals) + len(traceGlobals), nEvents)``. Rows follow
    the sorted ``(attrName, isEventAttr)`` keys, and trace-level values are
    repeated in every event column. After ``numTraces`` accepted traces,
    XESTerminateParseException is raised to stop the parse.
    """

    def __init__(self, numTraces, minPrefixLen):
        handler.ContentHandler.__init__(self)

        self.numTraces = numTraces          # stop after this many accepted traces
        self.minPrefixLen = minPrefixLen    # traces with fewer events are skipped
        self.traceList = []                 # accepted traces, one 2-D object array each
        self.traceNum = 0                   # number of accepted traces so far

        self.classifiers = dict()           # classifier name -> key names (multi-key only)
        self.traceGlobals = dict()          # trace attribute key -> XESAttributeTypes value or None
        self.eventGlobals = dict()          # event attribute key -> XESAttributeTypes value or None

        # Parser position flags.
        self._inTrace = False
        self._inEvent = False
        self._inTraceGlobal = False
        self._inEventGlobal = False
        self._discardTrace = False          # once set, stays set until the next <trace>

        self._thisTraceEvents = None        # list of value columns for the current trace
        self._thisTraceAttribs = None       # {(attrName, False): value} for the current trace
        self._thisEvent = None              # {(attrName, True): value} for the current event

    @staticmethod
    def _attributeType(name):
        """Map an XES attribute element name to its XESAttributeTypes kind (None if unknown)."""
        if name in ("int", "float"):
            return XESAttributeTypes.NUMERIC
        if name in ("string", "id", "boolean", "list"):
            return XESAttributeTypes.CATEGORICAL
        if name == "date":
            return XESAttributeTypes.DATE
        return None

    def _validateValue(self, name, declaredType, value):
        """Mark the current trace for discarding if ``value`` does not match its type.

        The flag is only ever set here, never cleared, so a later valid value
        cannot undo an earlier failure.
        """
        if name in ("int", "float") or declaredType == XESAttributeTypes.NUMERIC:
            if not is_number(value):
                self._discardTrace = True
        if name == "date" or declaredType == XESAttributeTypes.DATE:
            if not is_date(value):
                self._discardTrace = True

    def startElement(self, name, attrs):
        """Handle an opening tag: structural elements first, then attributes."""
        atLogLevel = not self._inEvent and not self._inTrace

        if name == 'classifier' and atLogLevel:
            # Only classifiers combining several keys are recorded.
            if 'name' in attrs and 'keys' in attrs:
                keyNames = attrs.get('keys').split()
                if len(keyNames) > 1:
                    self.classifiers[attrs.get('name')] = keyNames
            return
        if name == 'global' and 'scope' in attrs and atLogLevel:
            if attrs['scope'] == 'event':
                self._inEventGlobal = True
            if attrs['scope'] == 'trace':
                self._inTraceGlobal = True
            return
        if name == 'event':
            # Always enter the event, even in a discarded trace, so that its
            # attributes are never mistaken for trace-level attributes.
            self._inEvent = True
            self._thisEvent = dict()
            return
        if name == 'trace':
            self._inTrace = True
            self._thisTraceAttribs = dict()
            self._thisTraceEvents = []
            self._discardTrace = False
            return

        # Everything below is an attribute element identified by its 'key'.
        if 'key' not in attrs:
            return
        attrName = attrs.get('key')

        if self._inEventGlobal:
            self.eventGlobals[attrName] = self._attributeType(name)
        elif self._inTraceGlobal:
            self.traceGlobals[attrName] = self._attributeType(name)
        elif self._inTrace and not self._inEvent:
            # Trace (case) attribute; undeclared keys are ignored.
            if attrName in self.traceGlobals and 'value' in attrs:
                value = attrs.get('value')
                self._validateValue(name, self.traceGlobals[attrName], value)
                self._thisTraceAttribs[(attrName, False)] = value
        elif self._inTrace and self._inEvent:
            # Event attribute; undeclared keys are ignored.
            if attrName in self.eventGlobals and 'value' in attrs:
                value = attrs.get('value')
                self._validateValue(name, self.eventGlobals[attrName], value)
                self._thisEvent[(attrName, True)] = value

    def _finishEvent(self):
        """Store the current event as a value column, or discard the trace if it is incomplete."""
        # Trace-level values are repeated in every event column.
        self._thisEvent.update(self._thisTraceAttribs)
        if len(self._thisEvent) != len(self.eventGlobals) + len(self.traceGlobals):
            # A missing attribute invalidates the whole trace.
            self._discardTrace = True
            return
        # Sorting the (attrName, isEventAttr) keys fixes the row order; the same
        # ordering is used by the attribute index in XESParserPrediction._parse.
        self._thisTraceEvents.append([value for _, value in sorted(self._thisEvent.items())])

    def _finishTrace(self):
        """Append the current trace to traceList if it is complete, valid and long enough."""
        if len(self._thisTraceAttribs) != len(self.traceGlobals) or self._discardTrace:
            return
        columns = self._thisTraceEvents
        if len(columns) < self.minPrefixLen:
            sys.stderr.write("Skipped a trace that is too short\n")
            return
        # Build the array once per trace instead of concatenating per event.
        width = len(self.eventGlobals) + len(self.traceGlobals)
        events = np.empty((width, len(columns)), dtype=np.object_)
        for col, values in enumerate(columns):
            events[:, col] = values
        self.traceList.append(events)
        self.traceNum += 1
        if self.traceNum == self.numTraces:
            raise XESTerminateParseException("Done after {0} traces".format(self.traceNum))

    def endElement(self, name):
        """Handle a closing tag, finalising events, traces and global sections."""
        if name == 'event':
            self._inEvent = False
            # Events outside a trace, or inside a discarded trace, are ignored.
            if self._inTrace and not self._discardTrace:
                self._finishEvent()
            return
        if name == 'trace':
            self._inTrace = False
            self._finishTrace()
            return
        if name == 'global':
            if self._inEventGlobal:
                self._inEventGlobal = False
            elif self._inTraceGlobal:
                self._inTraceGlobal = False


class XESParserPrediction:
    """Lazily parse an XES file into traces, an attribute index and classifiers.

    The getters trigger a parse the first time they are called; read()
    always re-parses.
    """

    def __init__(self, xespath, numTraces, minPrefixLen):
        self.traces = None        # list of per-trace arrays produced by XESHandler
        self.attribs = None       # {(attrName, isEventAttr): row index}
        self.classifiers = None   # classifier name -> key names

        self._numTraces = numTraces
        self._minPrefixLen = minPrefixLen
        self._filepath = xespath
        self._saxparser = None    # set only after a parse has completed

        if not os.path.isfile(self._filepath):
            # NOTE(review): this only reports the problem; a later parse of the
            # missing file will still be attempted (and will likely fail).
            tkMessageBox.showerror("IOError", "Cannot open file "+self._filepath)

    def getAttributes(self):
        if self._saxparser is None:
            self._parse()
        return self.attribs

    def getTraces(self):
        if self._saxparser is None:
            self._parse()
        return self.traces

    def getClassifiers(self):
        if self._saxparser is None:
            self._parse()
        return self.classifiers

    def getXESPath(self):
        return self._filepath

    @staticmethod
    def _buildAttribIndex(eventGlobals, traceGlobals):
        """Map each (attrName, isEventAttr) key to its row index in a trace array.

        The keys are sorted, which is the same ordering XESHandler uses for
        the rows of each trace.
        """
        keys = [(attrName, True) for attrName in eventGlobals]
        keys += [(attrName, False) for attrName in traceGlobals]
        return {key: index for index, key in enumerate(sorted(keys))}

    def _parse(self):
        """Run the SAX parse and publish traces, attribs and classifiers."""
        parser = make_parser()
        contentHandler = XESHandler(self._numTraces, self._minPrefixLen)
        parser.setContentHandler(contentHandler)
        try:
            parser.parse(self._filepath)
        except XESTerminateParseException:
            # Raised deliberately by the handler once numTraces traces were read.
            pass

        self.classifiers = contentHandler.classifiers
        self.traces = contentHandler.traceList
        self.attribs = self._buildAttribIndex(contentHandler.eventGlobals,
                                              contentHandler.traceGlobals)
        # Mark as parsed only after success, so the getters retry a failed parse
        # instead of silently returning None.
        self._saxparser = parser

    def read(self):
        self._parse()
