import os.path
import tkMessageBox
import numpy as np
import Tkinter
import ttk

from xml.sax import make_parser, handler
from utils import XESAttributeTypes, is_number, is_date


def on_closing():
    """Do nothing; used as a window-close handler that ignores the request."""
    return None


class XESHandler(handler.ContentHandler):
    """SAX content handler that collects the complete traces of an XES log.

    Only attributes declared in the log's ``<global scope="event">`` and
    ``<global scope="trace">`` sections are kept.  A trace is accepted when it
    carries every global trace attribute, contains at least one event, every
    event carries every global event attribute, and every numeric/date value
    passes ``is_number``/``is_date``.

    Accepted traces are appended to ``events``: one column per event, one row
    per attribute (rows ordered by the sorted ``(name, isEventAttribute)``
    key), followed by one synthetic end-of-case column per trace.

    Results are exposed through ``classifiers``, ``eventGlobals``,
    ``traceGlobals`` and ``events`` (``None`` until a trace is accepted).
    """

    def __init__(self, progressbar, validcountvariable, invalidcountvariable, validcountlabel, invalidcountlabel):
        """Store the progress widgets and reset the parsing state.

        progressbar           -- object whose ``update()`` is called after each trace
        validcountvariable    -- variable (``get``/``set``) counting accepted traces
        invalidcountvariable  -- variable (``get``/``set``) counting rejected traces
        validcountlabel       -- widget refreshed (``update()``) after an accepted trace
        invalidcountlabel     -- widget refreshed (``update()``) after a rejected trace
        """
        handler.ContentHandler.__init__(self)

        self.validcountvariable = validcountvariable
        self.invalidcountvariable = invalidcountvariable
        self.validcountlabel = validcountlabel
        self.invalidcountlabel = invalidcountlabel
        self.progressbar = progressbar

        # classifier name -> list of attribute keys
        self.classifiers = dict()
        # attribute key -> XESAttributeTypes value (or None for unknown tags)
        self.traceGlobals = dict()
        self.eventGlobals = dict()
        # attribute-by-event matrix of all accepted traces
        self.events = None

        self._inTrace = False
        self._inEvent = False
        self._inTraceGlobal = False
        self._inEventGlobal = False
        # Sticky for the duration of one trace; reset when the next trace starts.
        self._discardTrace = False

        self._thisTraceEvents = None
        self._thisTraceAttribs = None
        self._thisEvent = None

    @staticmethod
    def _declaredType(tag):
        """Map an XES attribute element name to its XESAttributeTypes value."""
        if tag in ("int", "float"):
            return XESAttributeTypes.NUMERIC
        if tag in ("string", "id", "boolean", "list"):
            return XESAttributeTypes.CATEGORICAL
        if tag == "date":
            return XESAttributeTypes.DATE
        return None

    @staticmethod
    def _valueIsValid(tag, declaredType, value):
        """Check a value against the type implied by its tag and its global.

        Every applicable check must pass: a numeric tag or numeric global
        requires ``is_number``; a date tag or date global requires ``is_date``.
        """
        valid = True
        if tag in ("int", "float") or declaredType == XESAttributeTypes.NUMERIC:
            valid = is_number(value)
        if tag == "date" or declaredType == XESAttributeTypes.DATE:
            valid = valid and is_date(value)
        return valid

    @staticmethod
    def _eocValue(attrType):
        """Return the end-of-case placeholder for an attribute type."""
        if attrType == XESAttributeTypes.NUMERIC or attrType == XESAttributeTypes.DATE:
            return '0'
        return '[EOC]'

    @staticmethod
    def _increment(countVariable, countLabel):
        """Add one to a counter variable and refresh the label showing it."""
        countVariable.set(int(countVariable.get()) + 1)
        countLabel.update()

    def startElement(self, name, attrs):
        """Track the parser position and record declared attribute values."""
        atLogLevel = not self._inEvent and not self._inTrace

        if name == 'classifier' and atLogLevel:
            if 'name' in attrs and 'keys' in attrs:
                keyNames = attrs.get('keys').split()
                # Only classifiers combining several attributes are recorded.
                if len(keyNames) > 1:
                    self.classifiers[attrs.get('name')] = keyNames
            return
        if name == 'global' and 'scope' in attrs and atLogLevel:
            scope = attrs['scope']
            if scope == 'event':
                self._inEventGlobal = True
            if scope == 'trace':
                self._inTraceGlobal = True
            return
        if name == 'event':
            # Always enter the event, even in an already discarded trace, so
            # that its attributes are never mistaken for trace attributes.
            # Events outside of a trace are ignored.
            if self._inTrace:
                self._inEvent = True
                self._thisEvent = dict()
            return
        if name == 'trace':
            self._inTrace = True
            self._thisTraceAttribs = dict()
            self._thisTraceEvents = np.empty((len(self.traceGlobals) + len(self.eventGlobals), 0), dtype=np.object_)
            self._discardTrace = False
            return

        if self._inEventGlobal or self._inTraceGlobal:
            # Declaration of a global attribute (for all events / all traces).
            if 'key' in attrs:
                declared = self.eventGlobals if self._inEventGlobal else self.traceGlobals
                declared[attrs.get('key')] = self._declaredType(name)
            return

        if not self._inTrace or 'key' not in attrs:
            return

        attrName = attrs.get('key')
        if self._inEvent:
            # event attribute
            declared, store, storeKey = self.eventGlobals, self._thisEvent, (attrName, True)
        else:
            # case (trace-level) attribute
            declared, store, storeKey = self.traceGlobals, self._thisTraceAttribs, (attrName, False)

        if attrName not in declared or 'value' not in attrs:
            return

        attrValue = attrs.get('value')
        # Once a value fails validation the whole trace stays discarded; a
        # later valid value must not clear the flag.
        if not self._valueIsValid(name, declared[attrName], attrValue):
            self._discardTrace = True
        store[storeKey] = attrValue

    def endElement(self, name):
        """Finish events and traces; append accepted traces to ``events``."""
        rowCount = len(self.eventGlobals) + len(self.traceGlobals)

        if name == 'event':
            wasInEvent = self._inEvent
            self._inEvent = False
            if not wasInEvent or self._discardTrace:
                return

            # The event inherits the trace-level attributes of its trace.
            self._thisEvent.update(self._thisTraceAttribs)
            if len(self._thisEvent) != rowCount:
                # An incomplete event invalidates the entire trace.
                self._discardTrace = True
                return

            event = np.asarray([val for (_, val) in sorted(self._thisEvent.items())]).reshape(rowCount, 1)
            self._thisTraceEvents = np.concatenate((self._thisTraceEvents, event), axis=1)
            return

        if name == 'trace':
            self._inTrace = False

            complete = (len(self._thisTraceAttribs) == len(self.traceGlobals)
                        and self._thisTraceEvents.shape[1] > 0
                        and not self._discardTrace)
            if complete:
                # Synthetic end-of-case event terminating the trace.
                eocEvent = dict()
                for (attrName, attrType) in self.eventGlobals.items():
                    eocEvent[(attrName, True)] = self._eocValue(attrType)
                for (attrName, attrType) in self.traceGlobals.items():
                    eocEvent[(attrName, False)] = self._eocValue(attrType)
                eocColumn = np.asarray([val for (_, val) in sorted(eocEvent.items())]).reshape(rowCount, 1)

                columns = (self._thisTraceEvents, eocColumn)
                if self.events is not None:
                    columns = (self.events,) + columns
                self.events = np.concatenate(columns, axis=1)
                self._increment(self.validcountvariable, self.validcountlabel)
            else:
                self._increment(self.invalidcountvariable, self.invalidcountlabel)

            self.progressbar.update()
            return

        if name == 'global':
            if self._inEventGlobal:
                self._inEventGlobal = False
            elif self._inTraceGlobal:
                self._inTraceGlobal = False
            return

    def endDocument(self):
        """Nothing to finalise; results are read from the public attributes."""
        pass


class XESParser:
    """Lazily parses an XES log file while showing a Tk progress dialog.

    After parsing:

    * ``events`` is the attribute-by-event matrix produced by ``XESHandler``
      (``None`` when no valid trace was found);
    * ``attribs`` maps ``(attributeName, isEventAttribute)`` to
      ``(rowNumber, attributeType, val2id, id2val)``; when no valid trace was
      found it maps to the row number only;
    * ``classifiers`` maps a classifier name to ``(attributeKeys, type)``.
    """

    def __init__(self, xespath):
        """Remember the log path; show an error box if it is not a file.

        No exception is raised for a missing file here; parsing is deferred
        until the first getter or ``read()`` call.
        """
        self.events = None
        self.attribs = None
        self.classifiers = None

        self._filepath = xespath
        # None means "not parsed yet" and triggers parsing in the getters.
        self._saxparser = None

        if not os.path.isfile(self._filepath):
            tkMessageBox.showerror("IOError", "Cannot open file "+self._filepath)

    def getAttributes(self):
        """Return the attribute table, parsing the file first if needed."""
        if self._saxparser is None:
            self._parse()
        return self.attribs

    def getEvents(self):
        """Return the event matrix, parsing the file first if needed."""
        if self._saxparser is None:
            self._parse()
        return self.events

    def getClassifiers(self):
        """Return the classifier table, parsing the file first if needed."""
        if self._saxparser is None:
            self._parse()
        return self.classifiers

    def getXESPath(self):
        """Return the path given to the constructor."""
        return self._filepath

    def _parse(self):
        """Parse the file behind a progress dialog that is always torn down.

        Any exception raised while parsing or post-processing propagates to
        the caller, but the dialog is closed first and the parser is marked
        as "not parsed" so a later getter call retries instead of silently
        returning ``None``.
        """
        top = Tkinter.Toplevel()
        top.title("Reading XES File")
        top.resizable(False, False)
        # The dialog cannot be closed by the user while parsing is running.
        top.protocol("WM_DELETE_WINDOW", on_closing)
        top.configure(padx=5, pady=5)

        lbl = Tkinter.Label(top, text="Progress:", padx=5, pady=5)
        lbl.grid(row=0, column=0, sticky=Tkinter.W)
        progressBar = ttk.Progressbar(top, orient="horizontal", length=200, mode="indeterminate")
        progressBar.grid(row=0, column=1)

        lbl = Tkinter.Label(top, text="Valid traces:", padx=5, pady=5)
        lbl.grid(row=1, column=0, sticky=Tkinter.W)
        validtraceNumVar = Tkinter.StringVar(value='0')
        validcountlbl = Tkinter.Label(top, textvariable=validtraceNumVar, padx=5, pady=5)
        validcountlbl.grid(row=1, column=1, sticky=Tkinter.W)

        lbl = Tkinter.Label(top, text="Invalid or empty traces:", padx=5, pady=5)
        lbl.grid(row=2, column=0, sticky=Tkinter.W)
        invalidtraceNumVar = Tkinter.StringVar(value='0')
        invalidcountlbl = Tkinter.Label(top, textvariable=invalidtraceNumVar, padx=5, pady=5)
        invalidcountlbl.grid(row=2, column=1, sticky=Tkinter.W)

        try:
            progressBar.start()
            top.update()
            self._load(progressBar, validtraceNumVar, invalidtraceNumVar, validcountlbl, invalidcountlbl)
        except Exception:
            # A failed parse must not look like a finished one.
            self._saxparser = None
            raise
        finally:
            # Without this the un-closable dialog would stay open on errors.
            progressBar.stop()
            top.quit()
            top.withdraw()

    def _load(self, progressBar, validtraceNumVar, invalidtraceNumVar, validcountlbl, invalidcountlbl):
        """Run the SAX parser and derive ``events``, ``attribs``, ``classifiers``.

        Raises ``KeyError`` if a classifier refers to a key that is not a
        global event attribute.
        """
        contentHandler = XESHandler(progressBar, validtraceNumVar, invalidtraceNumVar, validcountlbl, invalidcountlbl)
        self._saxparser = make_parser()
        self._saxparser.setContentHandler(contentHandler)
        self._saxparser.parse(self._filepath)

        # Get the results from the contentHandler
        eventGlobals = contentHandler.eventGlobals
        traceGlobals = contentHandler.traceGlobals
        self.classifiers = contentHandler.classifiers
        self.events = contentHandler.events

        # Row numbers follow the sorted (name, isEventAttribute) order, which
        # is the row order XESHandler uses when it builds the event columns.
        attribKeys = sorted([(attribName, True) for attribName in eventGlobals] +
                            [(attribName, False) for attribName in traceGlobals])
        self.attribs = dict(zip(attribKeys, range(len(attribKeys))))

        if self.events is None:
            tkMessageBox.showerror("Error", "No valid traces found.")
            return

        progressBar.update()

        # check the attribute types and get the number of unique values
        for ((attribName, eventAttrib), rowNum) in list(self.attribs.items()):
            attribType = eventGlobals[attribName] if eventAttrib else traceGlobals[attribName]
            row = self.events[rowNum, :]
            urow = np.unique(row)
            val2id = dict(zip(urow, range(len(urow))))
            id2val = dict(enumerate(urow))
            if attribType == XESAttributeTypes.CATEGORICAL:
                # A categorical attribute whose values (placeholders excluded)
                # all convert to float or to datetime64 is re-typed.
                erow = row[(row != '[EOC]') & (row != '0')]
                try:
                    erow.astype(dtype=float, copy=False)
                    attribType = XESAttributeTypes.NUMERIC
                except ValueError:
                    try:
                        erow.astype(dtype='datetime64', copy=False)
                        attribType = XESAttributeTypes.DATE
                    except ValueError:
                        pass
            self.attribs[(attribName, eventAttrib)] = (rowNum, attribType, val2id, id2val)
            progressBar.update()

        # A classifier is numeric only if all of its attributes are numeric.
        for (clsName, clsAttribs) in list(self.classifiers.items()):
            possibleNumber = True
            for clsAttrib in clsAttribs:
                (_, attribType, _, _) = self.attribs[(clsAttrib, True)]
                possibleNumber = possibleNumber and (attribType == XESAttributeTypes.NUMERIC)
            self.classifiers[clsName] = (clsAttribs, XESAttributeTypes.NUMERIC if possibleNumber else XESAttributeTypes.CATEGORICAL)
            progressBar.update()

    def read(self):
        """Parse the file now, even if it has been parsed before."""
        self._parse()
