import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sncosmo
from astropy.io import fits
from tqdm import tqdm

# --- Selection cuts -------------------------------------------------------
z_lim = 0.36    # keep only SNe with heliocentric redshift <= z_lim
pia_cut = 0.9   # minimum SNN classification probability for the "_SNN" sample

# --- Per-realisation summaries (one entry appended per simulation) ---------
# True-type Ia sample, split by host mass (HM = high mass, LM = low mass).
all_mu_R_HMs = []       # mean simulated R_V
all_mu_R_LMs = []
all_sigma_R_HMs = []    # std of simulated R_V
all_sigma_R_LMs = []
all_tauE_HMs = []       # mean simulated E(B-V)
all_tauE_LMs = []
# Same quantities for the subset that also passes the P(Ia) > pia_cut selection.
all_mu_R_HMs_SNN = []
all_mu_R_LMs_SNN = []
all_sigma_R_HMs_SNN = []
all_sigma_R_LMs_SNN = []
all_tauE_HMs_SNN = []
all_tauE_LMs_SNN = []

# --- Pooled per-SN E(B-V) values across every realisation -----------------
all_HM_EBVs = []
all_LM_EBVs = []

# SIM_GENTYPE value treated as a true SN Ia by the cuts below (> 1 is counted as core-collapse).
_IA_GENTYPE = 1
# Host-mass split on HOSTGAL_LOGMASS; hosts with 9.9 <= logmass <= 10.1 fall in neither bin.
# NOTE(review): hosts with a sentinel/missing mass (e.g. a negative value) would land in the
# low-mass bin — confirm whether such values occur in these sims.
_HIGH_MASS_CUT = 10.1
_LOW_MASS_CUT = 9.9
_PROB_COL = 'PROB_SNNTRAINV19_z_TRAINDES_V19'
_SIM_ROOT = '/pscratch/sd/d/desctd/PIPPIN_OUTPUT/BP-DES-W-CC/1_SIM/SIMDATADES_P21'
_CLASS_ROOT = '/global/cfs/cdirs/lsst/www/DESC_TD_PUBLIC/users/mgrayling/BP-DESWCC-CLASS/3_CLAS'


def _load_pia_lookup(realisation):
    """Return ``{SNID: P(Ia)}`` read from one realisation's SNN predictions.csv.

    Building the dict once replaces a boolean mask over the whole DataFrame for
    every supernova. On duplicate SNIDs the first row wins, matching the previous
    ``.values[0]`` lookup.
    """
    predictions = pd.read_csv(
        f'{_CLASS_ROOT}/SNN_DES_{realisation}_SIMDATADES_P21_SNNTRAINV19_z_TRAINDES_V19/predictions.csv'
    )
    predictions = predictions.drop_duplicates(subset='SNID', keep='first')
    return dict(zip(predictions['SNID'].astype(int).tolist(), predictions[_PROB_COL].tolist()))


def _iter_light_curves(data_dir, list_name):
    """Yield every light-curve table from the SNANA FITS HEAD/PHOT pairs named in ``list_name``."""
    # atleast_1d: np.loadtxt returns a non-iterable 0-d array for a single-line list file.
    head_names = np.atleast_1d(np.loadtxt(os.path.join(data_dir, list_name), dtype=str))
    for head_name in head_names:
        head_file = os.path.join(data_dir, head_name)
        if not os.path.exists(head_file):
            head_file += '.gz'  # look for .fits.gz if .fits not found
        # Swap HEAD -> PHOT in the file name only, never in the directory part.
        phot_file = os.path.join(
            os.path.dirname(head_file), os.path.basename(head_file).replace('HEAD', 'PHOT')
        )
        yield from sncosmo.read_snana_fits(head_file, phot_file)


for i in tqdm(np.arange(1, 51)):
    RV_HMs, RV_LMs, EBV_HMs, EBV_LMs = [], [], [], []
    RV_HMs_SNN, RV_LMs_SNN, EBV_HMs_SNN, EBV_LMs_SNN = [], [], [], []
    N_HM_cc, N_LM_cc = 0, 0
    sim_name = f'PIP_BP-DES-W-CC_SIMDATADES_P21-{str(i).zfill(4)}'
    data_dir = f'{_SIM_ROOT}/{sim_name}'
    pia_lookup = _load_pia_lookup(i)

    for sn in _iter_light_curves(data_dir, f'{sim_name}.LIST'):
        meta = sn.meta
        if meta['REDSHIFT_HELIO'] > z_lim:
            continue
        gentype = meta['SIM_GENTYPE']
        logmass = meta['HOSTGAL_LOGMASS']
        is_high_mass = logmass > _HIGH_MASS_CUT
        is_low_mass = logmass < _LOW_MASS_CUT

        if gentype != _IA_GENTYPE:
            # Count CC SNe per mass bin *before* skipping them. Previously this ran
            # after the non-Ia `continue`, so both counters were always 0.
            if gentype > _IA_GENTYPE:
                if is_high_mass:
                    N_HM_cc += 1
                elif is_low_mass:
                    N_LM_cc += 1
            continue

        if is_high_mass:
            rv_bin, ebv_bin, rv_snn, ebv_snn, pooled_ebv = (
                RV_HMs, EBV_HMs, RV_HMs_SNN, EBV_HMs_SNN, all_HM_EBVs)
        elif is_low_mass:
            rv_bin, ebv_bin, rv_snn, ebv_snn, pooled_ebv = (
                RV_LMs, EBV_LMs, RV_LMs_SNN, EBV_LMs_SNN, all_LM_EBVs)
        else:
            continue  # intermediate-mass host: excluded from both bins

        sn_name = meta['SNID']
        if isinstance(sn_name, bytes):
            sn_name = sn_name.decode('utf-8')
        # KeyError here means the SN is missing from predictions.csv (previously an IndexError).
        pia = pia_lookup[int(sn_name)]

        rv = meta['SIM_RV']
        ebv = meta['SIM_AV'] / rv  # E(B-V) = A_V / R_V
        rv_bin.append(rv)
        ebv_bin.append(ebv)
        pooled_ebv.append(ebv)
        if pia > pia_cut:
            rv_snn.append(rv)
            ebv_snn.append(ebv)

    all_mu_R_HMs.append(np.mean(RV_HMs))
    all_mu_R_LMs.append(np.mean(RV_LMs))
    all_sigma_R_HMs.append(np.std(RV_HMs))
    all_sigma_R_LMs.append(np.std(RV_LMs))
    all_tauE_HMs.append(np.mean(EBV_HMs))
    all_tauE_LMs.append(np.mean(EBV_LMs))
    all_mu_R_HMs_SNN.append(np.mean(RV_HMs_SNN))
    all_mu_R_LMs_SNN.append(np.mean(RV_LMs_SNN))
    all_sigma_R_HMs_SNN.append(np.std(RV_HMs_SNN))
    all_sigma_R_LMs_SNN.append(np.std(RV_LMs_SNN))
    all_tauE_HMs_SNN.append(np.mean(EBV_HMs_SNN))
    all_tauE_LMs_SNN.append(np.mean(EBV_LMs_SNN))

    # Per-realisation diagnostics: each statistic is followed by its SNN-selected counterpart.
    print(i,
          f'N_HM: {len(RV_HMs)}, {N_HM_cc}, {len(RV_HMs) - N_HM_cc}',
          f'N_LM: {len(RV_LMs)}, {N_LM_cc}, {len(RV_LMs) - N_LM_cc}',
          len(RV_HMs), len(RV_HMs_SNN),
          np.mean(RV_HMs), np.mean(RV_HMs_SNN), np.std(RV_HMs), np.std(RV_HMs_SNN),
          np.mean(RV_LMs), np.mean(RV_LMs_SNN), np.std(RV_LMs), np.std(RV_LMs_SNN),
          np.mean(EBV_HMs), np.mean(EBV_HMs_SNN), np.mean(EBV_LMs), np.mean(EBV_LMs_SNN))

# For each host-mass bin (high, then low): histogram the E(B-V) values pooled over all
# realisations, print their mean, then display the figure.
for pooled_ebvs in (all_HM_EBVs, all_LM_EBVs):
    plt.hist(pooled_ebvs)
    print(np.mean(pooled_ebvs))
    plt.show()

# Convert the per-realisation lists into arrays for the cross-realisation summary.
all_mu_R_HMs, all_mu_R_LMs, all_sigma_R_HMs, all_sigma_R_LMs, all_tauE_HMs, all_tauE_LMs = map(
    np.array,
    (all_mu_R_HMs, all_mu_R_LMs, all_sigma_R_HMs, all_sigma_R_LMs, all_tauE_HMs, all_tauE_LMs),
)
all_mu_R_HMs_SNN, all_mu_R_LMs_SNN, all_sigma_R_HMs_SNN, all_sigma_R_LMs_SNN, all_tauE_HMs_SNN, all_tauE_LMs_SNN = map(
    np.array,
    (all_mu_R_HMs_SNN, all_mu_R_LMs_SNN, all_sigma_R_HMs_SNN, all_sigma_R_LMs_SNN,
     all_tauE_HMs_SNN, all_tauE_LMs_SNN),
)


def _print_mean_and_sem(values):
    """Print the mean of ``values`` and its standard error, std(ddof=0) / sqrt(len(values)).

    Uses the array's own length, rather than always dividing by the length of the
    ``mu_R`` array.
    """
    print(values.mean(), values.std() / np.sqrt(len(values)))


for _stat in (all_mu_R_HMs, all_mu_R_LMs, all_sigma_R_HMs, all_sigma_R_LMs,
              all_tauE_HMs, all_tauE_LMs):
    _print_mean_and_sem(_stat)
print('-----')
for _stat in (all_mu_R_HMs_SNN, all_mu_R_LMs_SNN, all_sigma_R_HMs_SNN, all_sigma_R_LMs_SNN,
              all_tauE_HMs_SNN, all_tauE_LMs_SNN):
    _print_mean_and_sem(_stat)
