"""
"""
import os
import numpy as np
import h5py
from astropy.cosmology import Planck15


# Data directory used as the default `data_drn` of load_tng_mah_data.
TASSO = "/Users/aphearin/work/DATA/SIMS/TNG_parametric_scatter"
# Alternate data directory; pass it as `data_drn` to read from there instead.
# NOTE(review): presumably the location of the same data on a cluster — confirm.
BEBOP = "/lcrc/project/halotools/TNG300/parametric_scatter"
# Basename of a halo catalog file; not used by load_tng_mah_data.
# NOTE(review): judging by the name this is the snapshot-99 (z=0) catalog — confirm.
bn_halos = "TNG300_L1_Snap099_z0p00.hdf5"
# Basename of the HDF5 file opened by load_tng_mah_data (default `bn_mahs`).
BN_MAHS = "TNG300_L1_tracing.hdf5"
# Default `log_mah_fit_min`; load_tng_mah_data returns it unchanged to the caller.
LOG_MAH_FIT_MIN = 10.0


def _log10_peak_mass(mah, pad):
    """Return log10 of the running peak of ``mah`` along axis 1.

    Entries whose running peak is not strictly positive are set to ``pad``
    instead of being passed to ``np.log10``.
    """
    peak = np.maximum.accumulate(mah, axis=1)
    has_mass = peak > 0
    # Swap non-positive entries for 1.0 before taking the log so that
    # np.log10 never sees them (no NaN/-inf and no RuntimeWarning); those
    # slots are overwritten with ``pad`` on the next line anyway.
    safe_peak = np.where(has_mass, peak, 1.0)
    return np.where(has_mass, np.log10(safe_peak), pad)


def load_tng_mah_data(data_drn=TASSO, bn_mahs=BN_MAHS, log_mah_fit_min=LOG_MAH_FIT_MIN):
    """Load halo mass histories from the TNG tracing HDF5 file.

    The file ``os.path.join(data_drn, bn_mahs)`` is opened read-only. Every
    top-level key starting with ``"halo_"`` is treated as one halo, and its
    ``"M500crit_Msun"``, ``"M200mean_Msun"`` and ``"Mvir_Msun"`` datasets are
    read. Each history is replaced by its running maximum (peak mass) and
    converted to log10.

    Parameters
    ----------
    data_drn : str
        Directory containing the HDF5 file.
    bn_mahs : str
        Basename of the HDF5 file.
    log_mah_fit_min : float
        Not used in any computation here; returned unchanged as the last
        element of the result.

    Returns
    -------
    halo_ids : ndarray of int
        Integer parsed from the text after the first underscore of each
        ``halo_*`` key.
    log_mah_500c, log_mah_200m, log_mah_mvir : ndarray
        log10 of the running peak mass for each mass definition, one row per
        halo. Entries with no positive peak mass yet are set to ``-99.0``.
    tng_t : ndarray
        ``Planck15.age`` evaluated at the stored redshifts (astropy reports
        this in Gyr), in the same reversed order as the mass histories.
    log_mah_fit_min : float
        The input value, passed through.
    """
    mah_def_fitter_keys = ["M500crit_Msun", "M200mean_Msun", "Mvir_Msun"]

    with h5py.File(os.path.join(data_drn, bn_mahs), "r") as hdf:
        # Stored arrays are reversed on read.
        # NOTE(review): presumably the file is ordered from late to early
        # times, so reversing makes time increase along the array — confirm.
        tng_z = hdf["Redshifts"][...][::-1]
        tng_t = Planck15.age(tng_z).value

        all_halo_mah_keys = [s for s in hdf.keys() if s.startswith("halo_")]
        halo_ids = np.array([int(key.split("_")[1]) for key in all_halo_mah_keys])

        # One 2-d array (halo, snapshot) per mass definition.
        mah_data = [
            np.array([hdf[key][mah_def_key][...][::-1] for key in all_halo_mah_keys])
            for mah_def_key in mah_def_fitter_keys
        ]

    # The pad is applied to the *log* mass. Writing -99 into the mass array
    # before the log would give log10(-99) = NaN instead of the sentinel.
    pad = -99.0
    log_mah_500c, log_mah_200m, log_mah_mvir = (
        _log10_peak_mass(mah, pad) for mah in mah_data
    )

    return halo_ids, log_mah_500c, log_mah_200m, log_mah_mvir, tng_t, log_mah_fit_min
