"""Script to fit BACCO MAHs with the diffmah model, distributing halos over MPI ranks."""
import numpy as np
import os
from mpi4py import MPI
import argparse
from time import time

from diffmah.fit_mah_helpers import get_header, get_outline_bad_fit
from diffmah.fit_mah_helpers import get_loss_data
from diffmah.fit_mah_helpers import get_outline
from diffmah.fit_mah_helpers import diffmah_fitter

from load_bacco_data import load_bacco_mahs
import subprocess
import h5py

# Time value handed to get_outline_bad_fit for halos whose fit does not
# terminate; presumably present-day cosmic time in Gyr -- TODO confirm.
TODAY: float = 13.8
# Suffix of the per-rank scratch files; "{0}" is filled with the MPI rank.
TMP_OUTPAT: str = "_tmp_mah_fits_rank_{0}.dat"


def _write_collated_data(outname, data):
    """Write collated fit results to an HDF5 file, one dataset per column.

    Parameters
    ----------
    outname : str
        Path of the HDF5 file to create; an existing file is truncated.
    data : array-like
        2-d array of fit results, one row per halo and one column per
        name in the header returned by ``get_header()``.

    Raises
    ------
    ValueError
        If ``data`` is not 2-d or its column count differs from the header.
    """
    data = np.asarray(data)
    # The header is a comment line; drop its first character (presumably the
    # comment marker) to recover the whitespace-separated column names.
    colnames = get_header()[1:].strip().split()
    if data.ndim != 2 or data.shape[1] != len(colnames):
        # Explicit raise rather than assert so the check survives `python -O`.
        raise ValueError(
            "data mismatched with header: expected shape (n, {0}), got {1}".format(
                len(colnames), data.shape
            )
        )
    with h5py.File(outname, "w") as hdf:
        for i, name in enumerate(colnames):
            column = data[:, i]
            if name == "halo_id":
                # NOTE: if `data` is float64 (e.g. read back with np.loadtxt),
                # integer IDs above 2**53 already lost precision before this cast.
                column = column.astype("i8")
            hdf[name] = column


def _str2bool(value):
    """Parse a command-line string such as "True"/"false"/"1"/"0" into a bool.

    argparse's ``type=bool`` is wrong for this because ``bool("False")`` is True.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {0!r}".format(value)
    )


def _fit_output_basename(fname):
    """Return the basename of the collated HDF5 output for input file ``fname``.

    ``foo.h5`` maps to ``foo.fits.h5``; any other name gets ``.fits.h5``
    appended, so the output can never share the input file's name.
    """
    bname = os.path.basename(fname)
    root, ext = os.path.splitext(bname)
    if ext == ".h5":
        return root + ".fits.h5"
    return bname + ".fits.h5"


def _rank_fname(outdir, outbase, rank):
    """Return the path of the temporary text file written by MPI rank ``rank``."""
    # Format only the pattern, not the joined path, so braces appearing in
    # outdir or outbase are never interpreted as format fields.
    return os.path.join(outdir, outbase + TMP_OUTPAT.format(rank))


def _collate_rank_files(fnames):
    """Load the per-rank text files and stack their rows into one 2-d array."""
    # ndmin=2 keeps a file holding a single halo as shape (1, ncols) instead of
    # squeezing it to 1-d, which would break concatenation along axis 0.
    arrays = [np.loadtxt(fn, ndmin=2) for fn in fnames]
    # A rank assigned no halos writes only the header line (presumably skipped
    # by np.loadtxt as a comment), giving an empty array of the wrong shape.
    arrays = [arr for arr in arrays if arr.size > 0]
    if not arrays:
        raise ValueError("No fit results found in {0}".format(fnames))
    return np.concatenate(arrays)


if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank, nranks = comm.Get_rank(), comm.Get_size()

    parser = argparse.ArgumentParser()

    parser.add_argument("fname", help="Filename storing BACCO MAH data")
    parser.add_argument("outdir", help="Output directory")
    parser.add_argument("-nstep", help="Num opt steps per halo", type=int, default=200)
    # Accepts "-test", "-test True" and "-test False" (type=bool would treat
    # any non-empty string, including "False", as True).
    parser.add_argument(
        "-test",
        help="Short test run?",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
    )
    parser.add_argument(
        "-fittol", help="Tolerance parameter of the fitter", type=float, default=0.001
    )

    args = parser.parse_args()
    outbase = _fit_output_basename(args.fname)
    rank_outname = _rank_fname(args.outdir, outbase, rank)
    nstep = args.nstep
    fittol = args.fittol

    halo_ids, log_mahs, tarr, lgm_min = load_bacco_mahs(args.fname)

    # Ensure the target MAHs are cumulative peak masses
    log_mahs = np.maximum.accumulate(log_mahs, axis=1)

    if args.test:
        # Clamp so a small input file cannot produce out-of-range indices.
        nhalos_tot = min(nranks * 5, len(halo_ids))
    else:
        nhalos_tot = len(halo_ids)

    # Split halos [0, nhalos_tot) into contiguous, nearly equal chunks; rank r
    # takes chunk r, so concatenating the rank outputs in rank order restores
    # the original halo order.
    indx = np.array_split(np.arange(nhalos_tot, dtype="i8"), nranks)[rank]
    halo_ids_for_rank = halo_ids[indx]
    log_mahs_for_rank = log_mahs[indx]

    start = time()

    with open(rank_outname, "w") as fout:
        fout.write(get_header())

        for halo_id, lgmah in zip(halo_ids_for_rank, log_mahs_for_rank):
            p_init, loss_data = get_loss_data(tarr, lgmah, lgm_min)
            _res = diffmah_fitter(
                p_init,
                loss_data,
                n_adam_step=nstep,
                n_adam_warmup=1,
                tol=fittol,
            )
            p_best, loss_best, _, _, fit_terminates = _res

            if fit_terminates == 1:
                outline = get_outline(halo_id, loss_data, p_best, loss_best)
            else:
                # NOTE(review): TODAY is a fixed constant; confirm it matches
                # the final time in tarr for the BACCO cosmology.
                outline = get_outline_bad_fit(halo_id, lgmah[-1], TODAY)

            fout.write(outline)

    # Every rank has closed its file (end of `with`) before rank 0 reads them.
    comm.Barrier()
    end = time()

    msg = (
        "\n\nWallclock runtime to fit {0} galaxies with {1} ranks = {2:.1f} seconds\n\n"
    )
    if rank == 0:
        runtime = end - start
        print(msg.format(nhalos_tot, nranks, runtime))

        #  collate data from ranks and rewrite to disk
        fit_data_fnames = [_rank_fname(args.outdir, outbase, i) for i in range(nranks)]
        all_fit_data = _collate_rank_files(fit_data_fnames)
        outname = os.path.join(args.outdir, outbase)
        _write_collated_data(outname, all_fit_data)

        #  clean up temporary files: remove exactly the files this run wrote,
        #  without building a shell command from user-supplied paths
        for fn in fit_data_fnames:
            os.remove(fn)
