import Tkinter
import tkFileDialog
import tkMessageBox
import ttk
import math
import pickle
from utils import XESAttributeTypes, is_number, is_posint


def on_closing():
    """Handle a window-manager close request by doing nothing.

    Used as the dialog's WM_DELETE_WINDOW handler, so the title-bar close
    button has no effect and the dialog is dismissed via its own buttons.
    """


class ConfigurationDialog(object):
    """Modal Tk dialog collecting predictors, targets and network settings.

    Built from a ``parser`` exposing ``getAttributes()``, ``getClassifiers()``
    and ``getXESPath()``.  Call :meth:`show` to run the dialog; afterwards
    ``self.ok`` is True if the user pressed "OK" and False after "Cancel".
    The Tk variables created here (``classifierConfigVars``,
    ``eventConfigVars``, ``caseConfigVars``, ``batchSizeVar``, ...) are read
    by ``TFGraphConfig``.
    """

    # Padding (pixels) used around most widgets.
    _PAD = 5

    # Default values of the network/optimizer widgets, keyed by the attribute
    # that holds the Tk variable.  Shared by __init__ and reset_action so the
    # two cannot drift apart.
    _WIDGET_DEFAULTS = {
        'batchSizeVar': '20',
        'numUnrollStepsVar': '5',
        'dropoutProbVar': '0.2',
        'numLayersVar': '2',
        'maxGradNormVar': '5',
        'initScaleVar': '0.1',
        'numEpochsFullLRVar': '50',
        'numEpochsVar': '100',
        'baseLearningRateVar': '1.0',
        'lrDecayVar': '0.90',
        'usePeepHolesVar': True,
        'forgetBiasVar': '0.1',
        'numFoldsVar': '10',
        'rnnSizeVar': '100',
        'autoRNNSizeVar': True,
        'validateEpochVar': False,
        'sharedRNNVar': False,
        'rnnActivationVar': 'tanh',
        'dateScaleVar': 'Standardize',
        'numLossVar': 'MSE',
        'optimizerVar': 'GradientDescent',
    }

    # optimizer value -> (param1, param2), each either (label, default value)
    # or None when the optimizer does not use that parameter.
    _OPTIMIZER_PARAMS = {
        "GradientDescent": (None, None),
        "Adadelta": (("Rho (decay):", "0.95"), None),
        "Adagrad": (("Initial Accumulator:", "0.1"), None),
        "Momentum": (("Momentum:", "0.1"), None),
        "Adam": (("Beta1:", "0.9"), ("Beta2:", "0.999")),
        "FTRL": (("Learning Rate Power:", "-0.5"), ("Initial Accumulator:", "0.1")),
        "ProximalGradient": (("L1 Reg Strength:", "0.0"), ("L2 Reg Strength:", "0.0")),
        "ProximalAda": (("Initial Accumulator:", "0.1"), None),
        "RMSpro": (("Decay:", "0.9"), ("Momentum:", "0.0")),
    }

    def go_action(self):
        """'OK' callback: validate the selection and close with ``ok = True``.

        Shows an error and keeps the dialog open when no target or no
        predictor is selected, or when a numeric field cannot be parsed.
        """
        rowsWithTarget = list(self.classifierConfigVars.values()) + list(self.eventConfigVars.values())
        if not any(targ.get() for (_, _, _, targ) in rowsWithTarget):
            tkMessageBox.showerror("Configuration Error", "No targets selected")
            return
        hasPredictor = (any(pred.get() for (pred, _, _, _) in rowsWithTarget) or
                        any(pred.get() for (pred, _, _) in self.caseConfigVars.values()))
        if not hasPredictor:
            tkMessageBox.showerror("Configuration Error", "No predictors selected")
            return
        invalid = self._invalid_fields()
        if invalid:
            tkMessageBox.showerror("Configuration Error", "Invalid value for: " + ", ".join(invalid))
            return
        self._close(True)

    def cancel_action(self):
        """'Cancel' callback: close the dialog with ``ok = False``."""
        self._close(False)

    def _close(self, ok):
        """Record the outcome, release the grab and leave the dialog's main loop."""
        self.ok = ok
        self.top.grab_release()
        self.top.quit()
        self.top.withdraw()

    def _invalid_fields(self):
        """Return the labels of fields whose text cannot be converted to a number.

        Covers every field that TFGraphConfig converts with int()/float(),
        including all embedding sizes, so a bad value is reported here instead
        of raising later.
        """
        fields = [
            ("Batch size", self.batchSizeVar, int),
            ("Number of unroll steps", self.numUnrollStepsVar, int),
            ("Dropout probability", self.dropoutProbVar, float),
            ("Number of layers", self.numLayersVar, int),
            ("Max gradient", self.maxGradNormVar, float),
            ("Init scale", self.initScaleVar, float),
            ("Number of epochs w/ full learning rate", self.numEpochsFullLRVar, int),
            ("Number of epochs", self.numEpochsVar, int),
            ("Base learning rate", self.baseLearningRateVar, float),
            ("Learning rate decay", self.lrDecayVar, float),
            ("Forget bias", self.forgetBiasVar, float),
            ("Number of folds for crossvalidation", self.numFoldsVar, int),
            ("RNN Size", self.rnnSizeVar, int),
        ]
        for (name, configVars) in self.classifierConfigVars.iteritems():
            fields.append(("Embedding size of " + name, configVars[2], int))
        for (name, configVars) in self.eventConfigVars.iteritems():
            fields.append(("Embedding size of event attribute " + name, configVars[2], int))
        for (name, configVars) in self.caseConfigVars.iteritems():
            fields.append(("Embedding size of case attribute " + name, configVars[2], int))
        # Empty optimizer parameters are allowed: TFGraphConfig maps them to 0.0.
        for (label, var) in (("Optimizer parameter 1", self.optVar1), ("Optimizer parameter 2", self.optVar2)):
            if var.get() != "":
                fields.append((label, var, float))

        invalid = []
        for (label, var, convert) in fields:
            try:
                convert(var.get())
            except ValueError:
                invalid.append(label)
        return invalid

    def reset_action(self):
        """'Reset NN Params' callback: restore network and optimizer defaults.

        Predictor/target/categorical/embedding choices are left untouched.
        """
        for (attrName, value) in self._WIDGET_DEFAULTS.iteritems():
            getattr(self, attrName).set(value)
        self._sync_rnn_size_entry()
        self.optVar1.set(value="")
        self.optVar2.set(value="")
        self.optVar1Lbl.config(text="")
        self.optVar2Lbl.config(text="")
        self.optimizer_check()

    @staticmethod
    def categorical_check(var, box):
        """Enable ``box`` when the boolean Tk variable ``var`` is true, else disable it."""
        box.config(state=Tkinter.NORMAL if var.get() else Tkinter.DISABLED)

    @staticmethod
    def inverse_categorical_check(var, box):
        """Enable ``box`` when the boolean Tk variable ``var`` is false, else disable it."""
        box.config(state=Tkinter.DISABLED if var.get() else Tkinter.NORMAL)

    def _sync_rnn_size_entry(self):
        """Enable the RNN size entry only when automatic RNN sizing is off."""
        self.inverse_categorical_check(self.autoRNNSizeVar, self.rnnSizeEntry)

    def optimizer_check(self):
        """Relabel and enable/disable the optimizer parameter fields.

        Uses the currently selected optimizer.  A parameter that is in use
        and still empty is filled with its default.  Non-empty values are
        kept, so values loaded by :meth:`set` survive.  Unknown optimizer
        values leave the widgets untouched.
        """
        selected = self.optimizerVar.get()
        params = self._OPTIMIZER_PARAMS.get(selected)
        if params is None:
            return
        slots = ((self.optVar1Lbl, self.optVar1, self.optVar1Entry),
                 (self.optVar2Lbl, self.optVar2, self.optVar2Entry))
        for ((label, var, entry), spec) in zip(slots, params):
            if spec is None:
                label.config(text="")
                entry.config(state=Tkinter.DISABLED)
            else:
                (text, default) = spec
                label.config(text=text)
                if var.get() == "":
                    var.set(default)
                entry.config(state=Tkinter.NORMAL)
        # Remember which optimizer the fields currently describe.
        self._shownOptimizer = selected

    def _on_optimizer_selected(self):
        """Radio-button callback: show the defaults of a newly picked optimizer.

        Without clearing, the previous optimizer's values (e.g. Momentum's 0.1)
        would be reused as the new optimizer's parameters.
        """
        if self.optimizerVar.get() != self._shownOptimizer:
            self.optVar1.set("")
            self.optVar2.set("")
        self.optimizer_check()

    def __init__(self, parser):
        p = self._PAD

        self.ok = False

        self.attribs = parser.getAttributes()
        self.classifiers = parser.getClassifiers()
        self.xespath = parser.getXESPath()

        self.top = Tkinter.Toplevel()
        self.top.title("Configuration Dialog")
        self.top.resizable(False, False)
        # on_closing does nothing, so the window-manager close button is ignored.
        self.top.protocol("WM_DELETE_WINDOW", on_closing)
        self.top.configure(padx=p, pady=p)

        intCheck = self.top.register(is_posint)
        floatCheck = self.top.register(is_number)

        titles = ("Log", "Multi-attribute Classifiers", "Global Event Attributes",
                  "Global Case Attributes", "Neural Net Configuration", "Optimizer")
        frames = []
        for (row, title) in enumerate(titles):
            frame = Tkinter.LabelFrame(self.top, text=title, padx=p, pady=p)
            frame.grid(row=row, columnspan=8, sticky="WE")
            frames.append(frame)
        (xesPathFrame, classifierFrame, evAttribsFrame, csAttribsFrame, confVarFrame, optimizerFrame) = frames

        Tkinter.Label(xesPathFrame, text="XES File:").grid(row=0, column=0)
        Tkinter.Label(xesPathFrame, text=self.xespath, padx=p, pady=p).grid(row=0, column=1, sticky=Tkinter.W)

        self._build_classifier_rows(classifierFrame, intCheck)
        self._build_attribute_rows(evAttribsFrame, csAttribsFrame, intCheck)
        self._build_nn_config(confVarFrame, intCheck, floatCheck)
        self._build_optimizer(optimizerFrame, floatCheck)
        self._build_buttons()

    @staticmethod
    def _attribute_badge(attrtype):
        """Return the (text, colour) type badge for an XES attribute type, or None."""
        if attrtype == XESAttributeTypes.NUMERIC:
            return ("NUMERIC", "red")
        if attrtype == XESAttributeTypes.DATE:
            return ("DATE", "green")
        if attrtype == XESAttributeTypes.CATEGORICAL:
            return ("CATEGORICAL", "yellow")
        return None

    def _add_predictor_row(self, frame, row, badge, caption, embedSize,
                           categorical, embedEnabled, withTarget, intCheck):
        """Add one predictor row to ``frame`` and return ``(configVars, embedEntry)``.

        Columns: 0 type badge, 1 caption, 2 "Predictor?", 3 "Treat as
        categorical?", 4-5 embedding size, 6 "Target?" (only if withTarget).
        ``categorical`` pre-checks and locks the categorical checkbox;
        ``embedEnabled`` is the initial state of the embedding entry.
        ``configVars`` is (predictor, categorical, embedSize, target) when
        ``withTarget`` is true, otherwise (predictor, categorical, embedSize).
        """
        p = self._PAD
        if badge is not None:
            (text, colour) = badge
            Tkinter.Label(frame, text=text, background=colour).grid(row=row, column=0)
        Tkinter.Label(frame, text=caption, padx=p, pady=p).grid(row=row, column=1, sticky=Tkinter.W)
        predictorVar = Tkinter.BooleanVar(value=False)
        Tkinter.Checkbutton(frame, text="Predictor?", variable=predictorVar,
                            padx=p, pady=p).grid(row=row, column=2, sticky=Tkinter.W)
        Tkinter.Label(frame, text="Embedding size:").grid(row=row, column=4, sticky=Tkinter.W)
        embedSizeVar = Tkinter.StringVar(value='{}'.format(embedSize))
        # No ``text=`` option here: Tk would read it as an abbreviation of
        # -textvariable and could rebind the entry to a shared Tcl variable.
        embedEntry = Tkinter.Entry(frame, textvariable=embedSizeVar, validate='key',
                                   validatecommand=(intCheck, '%P'))
        embedEntry.grid(row=row, column=5, sticky=Tkinter.W)
        if not embedEnabled:
            embedEntry.config(state=Tkinter.DISABLED)
        categoricalVar = Tkinter.BooleanVar()
        categoricalBox = Tkinter.Checkbutton(
            frame, text="Treat as categorical?", variable=categoricalVar, padx=p, pady=p,
            command=lambda: self.categorical_check(categoricalVar, embedEntry))
        if categorical:
            categoricalVar.set(True)
            categoricalBox.config(state=Tkinter.DISABLED)
        categoricalBox.grid(row=row, column=3, sticky=Tkinter.W)
        if not withTarget:
            return ((predictorVar, categoricalVar, embedSizeVar), embedEntry)
        targetVar = Tkinter.BooleanVar(value=False)
        Tkinter.Checkbutton(frame, text="Target?", variable=targetVar,
                            padx=p, pady=p).grid(row=row, column=6, sticky=Tkinter.W)
        return ((predictorVar, categoricalVar, embedSizeVar, targetVar), embedEntry)

    def _build_classifier_rows(self, frame, intCheck):
        """Create one row per multi-attribute classifier; fill classifierConfigVars."""
        self.classifierConfigVars = dict()
        self._classifierEmbedEntries = dict()
        for (row, (name, (attributes, possibleNumber))) in enumerate(self.classifiers.iteritems()):
            numvals = 1
            for attrName in attributes:
                (_, _, val2id, _) = self.attribs[(attrName, True)]
                numvals *= len(val2id)
            # Integer division, as the original Python 2 ``/`` on ints did.
            embedSize = int(math.ceil(math.sqrt(numvals // len(attributes))))
            badge = ("NUMERIC", "red") if possibleNumber else ("CATEGORICAL", "yellow")
            caption = name + " = (" + ' '.join(attributes) + ")"
            # A classifier that cannot be numeric is forced categorical.
            # Otherwise the embedding entry starts disabled, matching the
            # unchecked "categorical" box, as the attribute rows do.
            (configVars, embedEntry) = self._add_predictor_row(
                frame, row, badge, caption, embedSize,
                categorical=not possibleNumber, embedEnabled=not possibleNumber,
                withTarget=True, intCheck=intCheck)
            self.classifierConfigVars[name] = configVars
            self._classifierEmbedEntries[name] = embedEntry

    def _build_attribute_rows(self, eventFrame, caseFrame, intCheck):
        """Create the rows of global event attributes and global case attributes."""
        self.eventConfigVars = dict()
        self.caseConfigVars = dict()
        self._eventEmbedEntries = dict()
        self._caseEmbedEntries = dict()
        sections = ((True, eventFrame, self.eventConfigVars, self._eventEmbedEntries),
                    (False, caseFrame, self.caseConfigVars, self._caseEmbedEntries))
        for (isEvent, frame, configVarsByName, entriesByName) in sections:
            rows = [(name, info) for ((name, eventAttr), info) in self.attribs.iteritems()
                    if bool(eventAttr) == isEvent]
            for (row, (name, (_, attrtype, val2id, _))) in enumerate(rows):
                caption = name + " ({0} unique values)".format(len(val2id))
                embedSize = int(math.ceil(math.sqrt(len(val2id))))
                (configVars, embedEntry) = self._add_predictor_row(
                    frame, row, self._attribute_badge(attrtype), caption, embedSize,
                    categorical=attrtype not in (XESAttributeTypes.NUMERIC, XESAttributeTypes.DATE),
                    embedEnabled=attrtype == XESAttributeTypes.CATEGORICAL,
                    withTarget=isEvent, intCheck=intCheck)
                configVarsByName[name] = configVars
                entriesByName[name] = embedEntry

    def _add_entry(self, frame, row, column, text, var, check):
        """Grid a label at (row, column) with a validated entry bound to ``var`` beside it."""
        p = self._PAD
        Tkinter.Label(frame, text=text, padx=p, pady=p).grid(row=row, column=column, sticky=Tkinter.W)
        entry = Tkinter.Entry(frame, textvariable=var, validate='key', validatecommand=(check, '%P'))
        entry.grid(row=row, column=column + 1)
        return entry

    def _add_checkbox(self, frame, row, column, text, var, command=None):
        """Grid a label at (row, column) with a checkbox bound to ``var`` beside it."""
        p = self._PAD
        Tkinter.Label(frame, text=text, padx=p, pady=p).grid(row=row, column=column, sticky=Tkinter.E)
        options = {'variable': var}
        if command is not None:
            options['command'] = command
        box = Tkinter.Checkbutton(frame, **options)
        box.grid(row=row, column=column + 1, sticky=Tkinter.W)
        return box

    def _add_combobox(self, frame, row, column, text, var, values, sticky=''):
        """Grid a label at (row, column) with a read-only combobox beside it."""
        p = self._PAD
        Tkinter.Label(frame, text=text, padx=p, pady=p).grid(row=row, column=column, sticky=Tkinter.E)
        box = ttk.Combobox(frame, textvariable=var, state='readonly')
        box['values'] = values
        box.grid(row=row, column=column + 1, sticky=sticky)
        return box

    def _build_nn_config(self, frame, intCheck, floatCheck):
        """Create the "Neural Net Configuration" widgets and their Tk variables."""
        d = self._WIDGET_DEFAULTS

        self.batchSizeVar = Tkinter.StringVar(value=d['batchSizeVar'])
        self._add_entry(frame, 0, 0, "Batch size:", self.batchSizeVar, intCheck)
        self.numUnrollStepsVar = Tkinter.StringVar(value=d['numUnrollStepsVar'])
        self._add_entry(frame, 0, 2, "Number of unroll steps:", self.numUnrollStepsVar, intCheck)
        self.dropoutProbVar = Tkinter.StringVar(value=d['dropoutProbVar'])
        self._add_entry(frame, 0, 4, "Dropout probability:", self.dropoutProbVar, floatCheck)

        self.numLayersVar = Tkinter.StringVar(value=d['numLayersVar'])
        self._add_entry(frame, 1, 0, "Number of layers", self.numLayersVar, intCheck)
        self.maxGradNormVar = Tkinter.StringVar(value=d['maxGradNormVar'])
        self._add_entry(frame, 1, 2, "Max gradient:", self.maxGradNormVar, floatCheck)
        self.initScaleVar = Tkinter.StringVar(value=d['initScaleVar'])
        self._add_entry(frame, 1, 4, "Init scale:", self.initScaleVar, floatCheck)

        self.numEpochsFullLRVar = Tkinter.StringVar(value=d['numEpochsFullLRVar'])
        self._add_entry(frame, 2, 0, "Number of epochs w/ full learning rate:", self.numEpochsFullLRVar, intCheck)
        self.numEpochsVar = Tkinter.StringVar(value=d['numEpochsVar'])
        self._add_entry(frame, 2, 2, "Number of epochs:", self.numEpochsVar, intCheck)
        self.baseLearningRateVar = Tkinter.StringVar(value=d['baseLearningRateVar'])
        self._add_entry(frame, 2, 4, "Base learning rate:", self.baseLearningRateVar, floatCheck)

        self.lrDecayVar = Tkinter.StringVar(value=d['lrDecayVar'])
        self._add_entry(frame, 3, 0, "Learning rate decay:", self.lrDecayVar, floatCheck)
        self.usePeepHolesVar = Tkinter.BooleanVar(value=d['usePeepHolesVar'])
        self._add_checkbox(frame, 3, 2, "Use peepholes?", self.usePeepHolesVar)
        self.forgetBiasVar = Tkinter.StringVar(value=d['forgetBiasVar'])
        self._add_entry(frame, 3, 4, "Forget bias:", self.forgetBiasVar, floatCheck)

        self.numFoldsVar = Tkinter.StringVar(value=d['numFoldsVar'])
        self._add_entry(frame, 4, 0, "Number of folds for crossvalidation:", self.numFoldsVar, intCheck)
        self.rnnSizeVar = Tkinter.StringVar(value=d['rnnSizeVar'])
        self.rnnSizeEntry = self._add_entry(frame, 4, 4, "RNN Size:", self.rnnSizeVar, intCheck)
        self.autoRNNSizeVar = Tkinter.BooleanVar(value=d['autoRNNSizeVar'])
        self._add_checkbox(frame, 4, 2, "Automatic RNN size?", self.autoRNNSizeVar,
                           command=self._sync_rnn_size_entry)
        self._sync_rnn_size_entry()

        self.validateEpochVar = Tkinter.BooleanVar(value=d['validateEpochVar'])
        self._add_checkbox(frame, 5, 0, "Validate each epoch?", self.validateEpochVar)
        self.sharedRNNVar = Tkinter.BooleanVar(value=d['sharedRNNVar'])
        self._add_checkbox(frame, 5, 2, "Shared RNN?", self.sharedRNNVar)
        self.rnnActivationVar = Tkinter.StringVar(value=d['rnnActivationVar'])
        self._add_combobox(frame, 5, 4, "RNN Activation Func", self.rnnActivationVar,
                           ('tanh', 'sigmoid'))

        self.dateScaleVar = Tkinter.StringVar(value=d['dateScaleVar'])
        self._add_combobox(frame, 6, 0, "Scale Datetime to:", self.dateScaleVar,
                           ('Standardize', 'Days', 'Hours', 'Minutes', 'Seconds'), sticky=Tkinter.W)
        self.numLossVar = Tkinter.StringVar(value=d['numLossVar'])
        self._add_combobox(frame, 6, 2, "Loss function for numeric:", self.numLossVar,
                           ('MSE', 'MAE', 'RMSE'), sticky=Tkinter.W)

    def _build_optimizer(self, frame, floatCheck):
        """Create the optimizer radio buttons and the two parameter fields."""
        p = self._PAD
        self.optimizerVar = Tkinter.StringVar(value=self._WIDGET_DEFAULTS['optimizerVar'])
        choices = (
            ("Gradient Descent", "GradientDescent", 0, 0),
            ("Ada Delta", "Adadelta", 0, 1),
            ("Ada Grad", "Adagrad", 0, 2),
            ("Momentum", "Momentum", 0, 3),
            ("Adam", "Adam", 0, 4),
            ("FTRL", "FTRL", 1, 0),
            ("Proximal Gradient", "ProximalGradient", 1, 1),
            ("Proximal Ada", "ProximalAda", 1, 2),
            ("RMS Pro", "RMSpro", 1, 3),
        )
        for (text, value, row, column) in choices:
            Tkinter.Radiobutton(frame, text=text, variable=self.optimizerVar, value=value,
                                padx=p, pady=p, command=self._on_optimizer_selected
                                ).grid(row=row, column=column, sticky=Tkinter.W)

        self.optVar1 = Tkinter.StringVar(value="")
        self.optVar2 = Tkinter.StringVar(value="")
        self.optVar1Lbl = Tkinter.Label(frame, text="", padx=p, pady=p)
        self.optVar1Lbl.grid(row=2, column=0, sticky=Tkinter.E)
        self.optVar2Lbl = Tkinter.Label(frame, text="", padx=p, pady=p)
        self.optVar2Lbl.grid(row=2, column=2, sticky=Tkinter.E)
        self.optVar1Entry = Tkinter.Entry(frame, textvariable=self.optVar1, validate='key',
                                          validatecommand=(floatCheck, '%P'), state=Tkinter.DISABLED)
        self.optVar1Entry.grid(row=2, column=1, sticky=Tkinter.E)
        self.optVar2Entry = Tkinter.Entry(frame, textvariable=self.optVar2, validate='key',
                                          validatecommand=(floatCheck, '%P'), state=Tkinter.DISABLED)
        self.optVar2Entry.grid(row=2, column=3, sticky=Tkinter.E)
        # Show the fields of the default optimizer and record it as displayed.
        self.optimizer_check()

    def _build_buttons(self):
        """Create the OK / Cancel / Reset / Load buttons along the bottom row."""
        p = self._PAD
        buttons = (("OK", self.go_action), ("Cancel", self.cancel_action),
                   ("Reset NN Params", self.reset_action), ("Load Config", self.load))
        for (index, (text, command)) in enumerate(buttons):
            Tkinter.Button(self.top, text=text, padx=p, pady=p, command=command
                           ).grid(row=6, column=2 * index, columnspan=2)

    def show(self):
        """Grab input and run a Tk main loop until OK or Cancel calls quit()."""
        self.top.grab_set()
        self.top.mainloop()

    def load(self):
        """Ask for a pickled configuration file and apply it via :meth:`set`.

        Errors while reading or applying the file are reported in a message
        box instead of being silently ignored.
        """
        configPath = tkFileDialog.askopenfilename(title="Select configuration file",
                                                  filetypes=[('Python Pickle Files', '*.pickle')])
        if not configPath:
            return
        # SECURITY NOTE: unpickling can execute arbitrary code; only load
        # configuration files from a trusted source.
        try:
            with open(configPath, "rb") as pickleFile:
                graphConfig = pickle.load(pickleFile)
            self.set(graphConfig)
        except (IOError, EOFError, pickle.UnpicklingError, AttributeError,
                ImportError, IndexError, ValueError) as err:
            tkMessageBox.showerror("Load Error", "Could not load configuration:\n{0}".format(err))

    def set(self, config):
        """Populate the dialog from a saved configuration object.

        ``config`` must expose the attributes written by ``TFGraphConfig``.
        Network/optimizer settings are always restored.  Per-attribute
        settings are restored only for entries of ``config.preds`` that match
        a row of this dialog; other rows keep their current values.
        """
        self.batchSizeVar.set(config.batchSize)
        self.numUnrollStepsVar.set(config.numUnrollSteps)
        self.dropoutProbVar.set(config.dropoutProb)
        self.numLayersVar.set(config.numLayers)
        self.maxGradNormVar.set(config.maxGradNorm)
        self.initScaleVar.set(config.initScale)
        self.numEpochsFullLRVar.set(config.numEpochsFullLR)
        self.numEpochsVar.set(config.numEpochs)
        self.baseLearningRateVar.set(config.baseLearningRate)
        self.lrDecayVar.set(config.lrDecay)
        self.usePeepHolesVar.set(config.usePeepHoles)
        self.forgetBiasVar.set(config.forgetBias)
        self.numFoldsVar.set(config.numFolds)
        self.autoRNNSizeVar.set(config.autoRNNSize)
        self.rnnSizeVar.set(config.rnnSize)
        # Keep the RNN size entry's enabled state in line with the loaded flag.
        self._sync_rnn_size_entry()
        self.validateEpochVar.set(config.validateEpoch)
        self.sharedRNNVar.set(config.sharedRNN)
        self.dateScaleVar.set(config.dateScale)
        self.numLossVar.set(config.numLossFunc)
        self.rnnActivationVar.set(config.rnnActivationFunc)
        self.optimizerVar.set(config.optimizer)
        self.optVar1.set(config.optimPar1)
        self.optVar2.set(config.optimPar2)
        # Relabel and enable/disable the optimizer fields; loaded values are kept.
        self.optimizer_check()

        for (attrnames, attrtype, embedSize, _, _, _, eventAttrib) in config.preds:
            isCategorical = attrtype == XESAttributeTypes.CATEGORICAL
            if eventAttrib and len(attrnames) > 1:
                # Multi-attribute entry: update every classifier built from
                # exactly these attributes.
                key = (attrnames, True)
                for (clsName, (names, _)) in self.classifiers.iteritems():
                    if names != attrnames or clsName not in self.classifierConfigVars:
                        continue
                    (predictorVar, categoricalVar, embedSizeVar, targetVar) = self.classifierConfigVars[clsName]
                    predictorVar.set(key in config.predictors)
                    categoricalVar.set(isCategorical)
                    embedSizeVar.set(embedSize)
                    targetVar.set(key in config.targets)
                    self.categorical_check(categoricalVar, self._classifierEmbedEntries[clsName])
            elif eventAttrib:
                name = attrnames[0]
                if name in self.eventConfigVars:
                    key = (attrnames, True)
                    (predictorVar, categoricalVar, embedSizeVar, targetVar) = self.eventConfigVars[name]
                    predictorVar.set(key in config.predictors)
                    categoricalVar.set(isCategorical)
                    embedSizeVar.set(embedSize)
                    targetVar.set(key in config.targets)
                    self.categorical_check(categoricalVar, self._eventEmbedEntries[name])
            else:
                name = attrnames[0]
                if name in self.caseConfigVars:
                    key = (attrnames, False)
                    (predictorVar, categoricalVar, embedSizeVar) = self.caseConfigVars[name]
                    predictorVar.set(key in config.predictors)
                    categoricalVar.set(isCategorical)
                    embedSizeVar.set(embedSize)
                    self.categorical_check(categoricalVar, self._caseEmbedEntries[name])


class TFGraphConfig:
    """Snapshot of a confirmed ConfigurationDialog's settings.

    Reads the dialog's Tk variables once and stores plain Python values.
    Instances are what ConfigurationDialog.load() unpickles and passes to
    ConfigurationDialog.set().

    Attributes built here:
      preds      -- one tuple per classifier/attribute:
                    (attrnames, attrtype, embedSize, val2id, id2val, rowNum, eventAttrib);
                    classifiers use val2id=id2val=None and rowNum=-1.
      targets    -- (attrnames, isEvent) pairs ticked as "Target?".
      predictors -- (attrnames, isEvent) pairs ticked as "Predictor?".

    NOTE(review): declared without a base class (a classic class in Python 2);
    changing that may affect unpickling of previously saved configs -- verify
    before changing.
    """

    def __init__(self, dlg, logDirPath):
        """Copy the settings of ``dlg`` and remember ``logDirPath``.

        Raises ValueError if a numeric field of the dialog cannot be parsed.
        """
        self.preds = []
        self.targets = []
        self.predictors = []
        self.scale = dict()

        # dlg.classifiers maps name -> (attributes, possibleNumber): the second
        # element is a flag, not an XESAttributeTypes value.
        for (name, (attrnames, _possibleNumber)) in dlg.classifiers.iteritems():
            (pred, cat, embedSize, targ) = dlg.classifierConfigVars[name]
            if targ.get():
                self.targets.append((attrnames, True))
            if pred.get():
                self.predictors.append((attrnames, True))
            # The dialog forces "categorical" on for classifiers that cannot be
            # numeric, so an unticked box always means a numeric classifier.
            if cat.get():
                attrtype = XESAttributeTypes.CATEGORICAL
            else:
                attrtype = XESAttributeTypes.NUMERIC
            self.preds.append((attrnames, attrtype, int(embedSize.get()), None, None, -1, True))

        for ((name, eventAttrib), (rowNum, attrtype, val2id, id2val)) in dlg.attribs.iteritems():
            if eventAttrib:
                (pred, cat, embedSize, targ) = dlg.eventConfigVars[name]
                if targ.get():
                    self.targets.append(([name], True))
            else:
                # Case attributes have no "Target?" checkbox.
                (pred, cat, embedSize) = dlg.caseConfigVars[name]
            if pred.get():
                self.predictors.append(([name], eventAttrib))

            # "Treat as categorical?" overrides the detected type; otherwise
            # the type reported by the parser is kept unchanged.
            if cat.get():
                attrtype = XESAttributeTypes.CATEGORICAL
            self.preds.append(([name], attrtype, int(embedSize.get()), val2id, id2val, rowNum, eventAttrib))

        self.batchSize = int(dlg.batchSizeVar.get())
        self.numUnrollSteps = int(dlg.numUnrollStepsVar.get())
        self.dropoutProb = float(dlg.dropoutProbVar.get())
        self.numLayers = int(dlg.numLayersVar.get())
        self.maxGradNorm = float(dlg.maxGradNormVar.get())
        self.initScale = float(dlg.initScaleVar.get())
        self.numEpochsFullLR = int(dlg.numEpochsFullLRVar.get())
        self.numEpochs = int(dlg.numEpochsVar.get())
        self.baseLearningRate = float(dlg.baseLearningRateVar.get())
        self.lrDecay = float(dlg.lrDecayVar.get())
        self.usePeepHoles = bool(dlg.usePeepHolesVar.get())
        self.forgetBias = float(dlg.forgetBiasVar.get())
        self.numFolds = int(dlg.numFoldsVar.get())
        self.autoRNNSize = bool(dlg.autoRNNSizeVar.get())
        self.rnnSize = int(dlg.rnnSizeVar.get())
        self.validateEpoch = bool(dlg.validateEpochVar.get())
        self.sharedRNN = bool(dlg.sharedRNNVar.get())
        self.dateScale = dlg.dateScaleVar.get()
        self.numLossFunc = dlg.numLossVar.get()
        self.rnnActivationFunc = dlg.rnnActivationVar.get()
        self.optimizer = dlg.optimizerVar.get()
        # Unused (empty) optimizer parameters are stored as 0.0.
        self.optimPar1 = float(dlg.optVar1.get()) if dlg.optVar1.get() != "" else 0.0
        self.optimPar2 = float(dlg.optVar2.get()) if dlg.optVar2.get() != "" else 0.0

        self.xesPath = dlg.xespath
        self.logDirPath = logDirPath
