import glob
import pandas as pd
import os
import numpy as np
import pickle

# Map each results directory (used verbatim as a glob prefix, so a trailing
# slash matters) to the tag used in the generated LaTeX macro names.
# Directories tagged "IGN" are skipped when writing output.
thingdict = {
    "DES-SINGLE-RV/":"ORVOEBV",
    "DES-TWO-RV/":'TRVOEBV',
    "DES-TWO-RV-TWO-EBV-MULTI-BETA/":"TRVTEBVMB",
    "DES-TWO-RV-TWO-EBV/":"TRVTEBV",
    "DES-INTRINSIC-DIFF/":"IGN",
    "DES-NO-CINT_z0.4/":"IGN",
    "DES-NO-CINT_sig0/":"IGN",
    "DES-NO-CINT-NO-MWEBV/":"IGN",
    "DES-NO-CINT-NO-MWEBV_sig0/":"IGN",
    "DES-NO-CINT/":"IGN",
    "CC/training_mass_split/PIa_-99":"CCNOCUT",
    "CC/training_mass_split/PIa_0.5":"CCSOMECUT",
    "CC/training_mass_split/PIa_0.9":"CCMORECUT"
}

def write_out(vals, Ns, thingdict, directory, n_total=50):
    r"""Append LaTeX ``\def`` macros summarising one directory to LATEX-OUT-MG.tex.

    For each of the six quantities (ordered HMRVM, LMRVM, HMEBV, LMEBV,
    HMRVS, LMRVS) three macros are written:

    * ``<tag><prefix>``    -- inverse-variance weighted mean of the per-file
      means, with the standard error of those means in parentheses;
    * ``<tag><prefix>std`` -- standard deviation of the per-file means;
    * ``<tag><prefix>err`` -- mean of the per-file uncertainties.

    Then one ``<tag><prefix>N`` macro per entry of ``Ns`` holds the count
    followed by that count as a percentage of ``n_total``.

    Parameters
    ----------
    vals : sequence of six sequences of (mean, err) pairs
        One pair per processed file for each quantity, in the order above.
    Ns : sequence of ints
        Per-quantity counts, same order as ``vals``.
    thingdict : dict
        Maps ``directory`` to the macro tag; the tag ``"IGN"`` means skip.
    directory : str
        Key into ``thingdict``.
    n_total : int, optional
        Denominator for the percentage printed next to each count. The
        default of 50 reproduces the previous hard-coded ``2 * N`` value.

    Returns
    -------
    None. Nothing is written if the directory is ignored or any quantity
    has no (or malformed) data.
    """
    prefixes = ["HMRVM", "LMRVM", "HMEBV", "LMEBV", "HMRVS", "LMRVS"]
    tag = thingdict[directory]
    if tag == "IGN":
        print(f"This directory {directory} has been set to ignore, skipping")
        return

    # Split every quantity's (mean, err) pairs before opening the file, so an
    # empty or malformed quantity aborts cleanly instead of leaving only the
    # macros for the earlier quantities in the output.
    columns = []
    for pairs in vals:
        try:
            means, errs = zip(*pairs)
        except ValueError:
            print(f"This directory, {directory}, is otherwise empty (perhaps this is not yet coded for the mass split results?). Skipping for now.")
            return
        columns.append((means, errs))

    lines = []
    for prefix, (means, errs) in zip(prefixes, columns):
        spread = np.std(means)
        print(f"{prefix}: mean={np.mean(means)}, std={spread}, mean err={np.mean(errs)}")
        # Inverse-variance weighted mean of the per-file estimates.
        # NOTE(review): a zero error would give an infinite weight (nan result).
        val = np.around(np.average(means, weights=1 / np.power(errs, 2)), 3)
        valstderr = np.around(spread / np.sqrt(len(means)), 3)
        valstd = np.around(spread, 3)
        valerr = np.around(np.mean(errs), 3)
        # Values below 0.2 are shown to 3 decimals, everything else to 2.
        prec = 3 if val < 0.2 else 2
        macro = "\\def\\" + tag + prefix
        lines.append(f"{macro}{{${val:.{prec}f}({valstderr:.{prec}f})$}}")
        lines.append(f"{macro}std{{${valstd:.{prec}f}$}}")
        lines.append(f"{macro}err{{${valerr:.{prec}f}$}}")

    for prefix, count in zip(prefixes, Ns):
        # Percentage of n_total; with the default of 50 this equals 2 * count.
        pct = int(round(100 * count / n_total))
        macro = "\\def\\" + tag + prefix
        lines.append(f"{macro}N{{{count}({pct}\\%)}}")

    with open("LATEX-OUT-MG.tex", "a") as f:
        f.write("".join(line + "\n" for line in lines))
    print("Done")
# Start from a clean output file: write_out() opens LATEX-OUT-MG.tex in
# append mode, so macros from a previous run would otherwise pile up.
try:
    print("Removing existing output file")
    os.remove("LATEX-OUT-MG.tex")
except FileNotFoundError:
    pass

dirs = thingdict.keys()

# Reference ("truth") population parameters, one row per thingdict entry in
# the same insertion order (the main loop indexes this by position).
# Columns: [HM mu_R, HM sigma_R, HM tauA/mu_R, LM mu_R, LM sigma_R, LM tauA/mu_R].
# Empty rows are only used for directories tagged "IGN".
truths = [
        [3, 0.5, 0.08, 3, 0.5, 0.08],       # DES-SINGLE-RV/
        [3, 0.5, 0.08, 2, 0.5, 0.08],       # DES-TWO-RV/
        [3, 0.5, 0.125, 2, 0.5, 0.08],      # DES-TWO-RV-TWO-EBV-MULTI-BETA/
        [3, 0.5, 0.125, 2, 0.5, 0.08],      # DES-TWO-RV-TWO-EBV/
        [3, 0.5, 0.125, 2, 0.5, 0.08],      # DES-INTRINSIC-DIFF/ (IGN)
        [],                                 # DES-NO-CINT_z0.4/ (IGN)
        [],                                 # DES-NO-CINT_sig0/ (IGN)
        [],                                 # DES-NO-CINT-NO-MWEBV/ (IGN)
        [],                                 # DES-NO-CINT-NO-MWEBV_sig0/ (IGN)
        [],                                 # DES-NO-CINT/ (IGN)
        [3,  0.5, 0.125, 2, 0.5, 0.08],     # CC/training_mass_split/PIa_-99
        [3,  0.5, 0.125, 2, 0.5, 0.08],     # CC/training_mass_split/PIa_0.5
        [3,  0.5, 0.125, 2, 0.5, 0.08]      # CC/training_mass_split/PIa_0.9
        ]

# Rows are matched to directories by position; fail loudly if they drift apart.
if len(truths) != len(thingdict):
    raise ValueError(
        f"truths has {len(truths)} rows but thingdict has {len(thingdict)} entries"
    )

def _summarise(samples):
    """Return ``[mean, std]`` of ``samples``, each rounded to 6 decimals."""
    return [np.around(np.mean(samples), 6), np.around(np.std(samples), 6)]


def _within_two_sigma(samples, truth_value):
    """True if ``truth_value`` is within 2 sample std of the sample mean."""
    return np.abs((np.mean(samples) - truth_value) / np.std(samples)) < 2


def _inside_central_95(samples, truth_value):
    """True if ``truth_value`` lies strictly inside the 2.5%-97.5% quantiles."""
    lower, upper = np.quantile(samples, [0.025, 0.975])
    return lower < truth_value < upper


# Quantity order expected by write_out() for both the values and the counts.
QUANTITIES = ("HMRVM", "LMRVM", "HMEBV", "LMEBV", "HMRVS", "LMRVS")

for i, d in enumerate(dirs):
    # Chains live in one of three directory layouts; try each in turn.
    files = glob.glob(f"{d}*training_popRV/*/chains.pkl")
    if not files:
        files = glob.glob(f"{d}*training_mass_split/*/chains.pkl")
    if not files:
        files = glob.glob(f"{d}/*/chains.pkl")
    print(f'Found {len(files)} finished files in {d}')
    if thingdict[d] == "IGN":
        # Skip before touching truths[i]: ignored directories may have an
        # empty truth row, which would raise IndexError during scoring.
        print(f"This directory {d} has been set to ignore, skipping")
        continue
    truth = truths[i]
    summaries = {q: [] for q in QUANTITIES}
    n95 = dict.fromkeys(QUANTITIES, 0)
    for f in files:
        # NOTE: pickle.load can execute arbitrary code; only use trusted chains.
        with open(f, "rb") as input_file:
            e = pickle.load(input_file)
        if "mu_R" in e:
            # Single-population fit: the same posterior fills both the HM and
            # LM rows, and both are scored against truth[0:3].
            # NOTE(review): confirm the LM rows should not use truth[3:6] here.
            populations = (("HM", "", truth[0:3]), ("LM", "", truth[0:3]))
        else:
            # Mass-split fit: separate *_HM / *_LM posteriors and truths.
            populations = (("HM", "_HM", truth[0:3]), ("LM", "_LM", truth[3:6]))
        for pop, suffix, (rv_true, rvs_true, ebv_true) in populations:
            mu_r = e["mu_R" + suffix]
            sigma_r = e["sigma_R" + suffix]
            ebv = e["tauA" + suffix] / mu_r
            summaries[pop + "RVM"].append(_summarise(mu_r))
            summaries[pop + "EBV"].append(_summarise(ebv))
            summaries[pop + "RVS"].append(_summarise(sigma_r))
            n95[pop + "RVM"] += int(_within_two_sigma(mu_r, rv_true))
            n95[pop + "RVS"] += int(_inside_central_95(sigma_r, rvs_true))
            n95[pop + "EBV"] += int(_within_two_sigma(ebv, ebv_true))
    write_out(
        [summaries[q] for q in QUANTITIES],
        [n95[q] for q in QUANTITIES],
        thingdict,
        d,
    )
