"""
"""
from jax import jit as jjit
from jax import numpy as jnp
from jax import value_and_grad
from diffmah.utils import jax_adam_wrapper
from diffmah.individual_halo_assembly import _u_rolling_plaw_vs_logt
from diffmah.fit_mah_helpers import get_loss_data, DLOGM_CUT, T_FIT_MIN


def get_bacco_loss_data(
    t_sim, log_mah_sim, lgm_min, dlogm_cut=DLOGM_CUT, t_fit_min=T_FIT_MIN
):
    """Build the initial parameters and loss data for a BACCO halo MAH fit.

    Wraps ``get_loss_data``. The final entry of the loss data it returns
    (``logmp_init``) is moved to the front of the parameter array, so the
    fit varies ``logm0`` alongside the other parameters. ``log_mah_mse_loss``
    unpacks ``params`` as ``(logm0, logtc, ue, ul)``.

    Parameters
    ----------
    t_sim, log_mah_sim, lgm_min, dlogm_cut, t_fit_min
        Forwarded unchanged to ``get_loss_data``.

    Returns
    -------
    p_init : jax array
        ``(logmp_init, *p_init_from_get_loss_data)``.
    loss_data : sequence
        The loss data from ``get_loss_data`` with its last entry dropped, i.e.
        ``(logt_target, log_mah_target, logt0, fixed_k)``.
    """
    base_params, full_loss_data = get_loss_data(
        t_sim, log_mah_sim, lgm_min, dlogm_cut, t_fit_min
    )
    # Unpacking into exactly five targets also enforces the expected layout.
    _, _, _, _, logmp_init = full_loss_data
    params_with_mass = jnp.array([logmp_init, *base_params])
    return params_with_mass, full_loss_data[:-1]


@jjit
def _mse(pred, target):
    """Mean square error used to define loss functions."""
    diff = pred - target
    return jnp.mean(diff * diff)


@jjit
def log_mah_mse_loss(params, loss_data):
    """MSE between the model and target ``log_mah`` of one halo.

    Parameters
    ----------
    params : sequence of 4 values
        ``(logm0, logtc, ue, ul)``. These are the values being fit.
    loss_data : sequence of 4 values
        ``(logt, log_mah_target, logt0, fixed_k)``. ``fixed_k`` is passed to
        the model but lives in ``loss_data``, so it is not differentiated
        when gradients are taken with respect to ``params``.

    Returns
    -------
    Scalar mean squared error of ``_u_rolling_plaw_vs_logt`` evaluated at
    ``logt`` versus ``log_mah_target``.
    """
    logt, log_mah_target, logt0, fixed_k = loss_data
    logm0, logtc, ue, ul = params

    model_log_mah = _u_rolling_plaw_vs_logt(
        logt, logt0, logm0, logtc, fixed_k, ue, ul
    )
    return _mse(model_log_mah, log_mah_target)


# Jitted callable returning ``(loss, grads)``. ``grads`` is the gradient of
# ``log_mah_mse_loss`` with respect to ``params``, its first positional
# argument, which is jax.value_and_grad's default ``argnums``.
log_mah_mse_loss_and_grads = jjit(value_and_grad(log_mah_mse_loss))


def fit_bacco_mah(p_init, loss_data, nstep=200, nt_min=5):
    """Fit one halo's MAH by running ``jax_adam_wrapper`` on the MSE loss.

    Parameters
    ----------
    p_init : array
        Initial ``(logm0, logtc, ue, ul)``, as produced by ``get_bacco_loss_data``.
    loss_data : sequence of 4 values
        ``(logt_target, log_mah_target, logt0, fixed_k)``. Only the size of
        ``logt_target`` is inspected here; the whole tuple goes to the optimizer.
    nstep : int
        Step count passed through to ``jax_adam_wrapper``.
    nt_min : int
        Minimum number of target time points needed to attempt a fit.

    Returns
    -------
    tuple of 5
        ``(p_best, loss_best, loss_arr, params_arr, fit_terminates)`` as
        returned by ``jax_adam_wrapper``. When ``logt_target`` has fewer than
        ``nt_min`` points, no fit is run and ``(None, None, None, None, 0)``
        is returned.
    """
    logt_target, _, _, _ = loss_data
    n_times = logt_target.size

    # Too few target points to constrain the fit: skip the optimizer entirely.
    if n_times < nt_min:
        return None, None, None, None, 0

    p_best, loss_best, loss_arr, params_arr, fit_terminates = jax_adam_wrapper(
        log_mah_mse_loss_and_grads, p_init, loss_data, nstep, n_warmup=1
    )
    return p_best, loss_best, loss_arr, params_arr, fit_terminates
