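"""Loaders for UniverseMachine SFR mock binary catalogs and diffmah-fit HDF5 files,
plus a helper that cross-matches the two catalogs on halo ID.
"""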
import numpy as np
import h5py

# Record layout of the binary catalog read by load_umachine_sfr_mock.
# align=True inserts C-compiler-style padding; presumably this mirrors the C
# struct used by the code that writes these files — TODO confirm.
DTYPE = np.dtype(
    [
        ("id", np.int64),
        ("descid", np.int64),
        ("upid", np.int64),
        ("flags", np.int32),
        ("uparent_dist", np.float32),
        ("pos", np.float32, (6,)),
    ]
    # Every remaining field is a scalar float32, in file order.
    + [
        (name, np.float32)
        for name in (
            "vmp", "lvmp", "mp", "m", "v", "r",
            "rank1", "rank2", "ra", "rarank",
            "A_UV",
            "sm", "icl", "sfr", "obs_sm", "obs_sfr", "obs_uv",
            "empty",
        )
    ],
    align=True,
)


def load_umachine_sfr_mock(fn):
    """Read a UM SFR mock catalog stored as a raw binary dump.

    Parameters
    ----------
    fn : str, path-like, or open file
        Anything accepted as the ``file`` argument of ``numpy.fromfile``.

    Returns
    -------
    mock : structured ndarray of shape (n_um, )
        One record per entry, decoded with the module-level ``DTYPE``.
        NOTE(review): DTYPE uses native byte order, so this assumes the file
        was written on a machine with the same endianness — confirm.
    """
    records = np.fromfile(file=fn, dtype=DTYPE)
    return records


def load_diffmah_fits(fn):
    """Read every top-level member of an HDF5 file of diffmah fits into memory.

    Parameters
    ----------
    fn : str or path-like
        Path to the HDF5 file.

    Returns
    -------
    cat : dict
        Maps each top-level key of the file to the fully-read contents of
        that dataset.
    """
    with h5py.File(fn, "r") as hdf:
        # [...] reads the whole dataset, so the values stay valid once the
        # file is closed on leaving the ``with`` block.
        return {name: hdf[name][...] for name in hdf}


def crossmatch_diffmah_fits(um_mock, dmah_cat):
    """Cross-match the file storing diffmah fits against the UM SFR catalog.

    Parameters
    ----------
    um_mock : structured ndarray of shape (n_um, )
        Output of the load_umachine_sfr_mock function.
        Only its ``id`` column is used for matching.

    dmah_cat : dict
        Output of the load_diffmah_fits function.
        Must contain a ``halo_id`` column. Every column must share the leading
        dimension of ``halo_id``; trailing dimensions are allowed.
        NOTE(review): per the halotools documentation, ``crossmatch`` expects
        the values of ``halo_id`` to be unique — confirm for your inputs.

    Returns
    -------
    crossmatched_cat : dict
        Dict storing columns from dmah_cat cross-matched against um_mock.
        For each `col` in dmah_cat, crossmatched_cat has a key `diffmah_col`.
        That array has leading dimension n_um and the same dtype and trailing
        shape as ``dmah_cat[col]``.
        The ordering of crossmatched_cat matches the ordering in um_mock:
        i^th entry of a column in crossmatched_cat corresponds to i^th entry in um_mock.
        crossmatched_cat also has a boolean column `has_diffmah_params`.
        Rows of um_mock without a matching diffmah fit hold zero-valued
        placeholders in every `diffmah_*` column; mask them with
        `has_diffmah_params`.
    """
    # Imported inside the function, presumably so this module can be loaded
    # without halotools installed.
    from halotools.utils import crossmatch

    n_um = len(um_mock["id"])
    # um_idx[i] and dmah_idx[i] index the same halo in the two catalogs.
    um_idx, dmah_idx = crossmatch(um_mock["id"], dmah_cat["halo_id"])

    has_params = np.zeros(n_um, dtype=bool)
    has_params[um_idx] = True
    crossmatched_cat = {"has_diffmah_params": has_params}

    for key, column in dmah_cat.items():
        values = np.asarray(column)
        # Allocate directly in the column's dtype and trailing shape.
        # A 1-D float64 buffer cast afterwards cannot hold (n, k) rows and
        # wastes a temporary copy.
        out = np.zeros((n_um,) + values.shape[1:], dtype=values.dtype)
        out[um_idx] = values[dmah_idx]
        crossmatched_cat["diffmah_" + key] = out
    return crossmatched_cat
