import matplotlib.pylab as plt
from mpl_toolkits.basemap import Basemap
import matplotlib.colors as mcolors
import numpy as np
from mpl_toolkits.basemap import Basemap
# plt.style.use('dark_background')
from astropy import units as u
from astropy.coordinates import SkyCoord

def calculate_and_plot_hmf(mass_arrays, labels, volume, num_bins=10, bin_range=None, plot_label=r'$dN/d\log_{10}M$ [1/Mpc$^3$]'):
    """
    Calculate and plot the Halo Mass Function (HMF) for multiple datasets.

    Every dataset is histogrammed in log10(M) on one shared set of equal-width
    bins. The counts are divided by (bin width * volume) and drawn with
    sqrt(N) error bars on a logarithmic y-axis.

    Parameters:
    - mass_arrays (list of array-like): One collection of halo mass values per dataset.
                                   Non-positive or non-finite masses have no finite
                                   log10 and are excluded; a message is printed when
                                   any are dropped.
    - labels (list of str): List of labels for each dataset, used in the plot legend.
    - volume (float): Survey or simulation volume in Mpc^3.
    - num_bins (int, optional): Number of logarithmic mass bins. Default is 10.
    - bin_range (tuple, optional): Tuple specifying (log10(M_min), log10(M_max)).
                                   If None, it is determined from the data with
                                   0.5 dex of padding on each side.
    - plot_label (str, optional): Label for the y-axis in the plot.

    Returns:
    - None. Displays a plot of the HMF.

    Raises:
    - ValueError: If the number of mass arrays and labels differ, or if
                  bin_range is None and no dataset contains a positive,
                  finite mass from which a range could be derived.
    """

    if len(mass_arrays) != len(labels):
        raise ValueError("The number of mass arrays must match the number of labels.")

    # Convert to log10(M) once, dropping values whose logarithm is undefined.
    # Leaving them in would turn the automatic bin edges into -inf/NaN.
    log_mass_arrays = []
    for mass, label in zip(mass_arrays, labels):
        mass = np.asarray(mass, dtype=float)
        valid_mask = np.isfinite(mass) & (mass > 0)
        num_invalid = mass.size - np.count_nonzero(valid_mask)
        if num_invalid > 0:
            print(f"Dataset '{label}': Excluding {num_invalid} non-positive or non-finite mass value(s).")
        log_mass_arrays.append(np.log10(mass[valid_mask]))

    # Determine the global log mass range if not provided
    if bin_range is None:
        if not log_mass_arrays or sum(arr.size for arr in log_mass_arrays) == 0:
            raise ValueError(
                "Cannot determine the mass bin range: no positive, finite mass values were provided."
            )
        all_log_masses = np.concatenate(log_mass_arrays)
        log_m_min = all_log_masses.min() - 0.5  # padding
        log_m_max = all_log_masses.max() + 0.5  # padding
    else:
        log_m_min, log_m_max = bin_range

    # Define logarithmic mass bins (equal width, so one delta serves all bins)
    bins = np.linspace(log_m_min, log_m_max, num_bins + 1)
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    delta_logM = bins[1] - bins[0]

    plt.figure(figsize=(8, 6))

    for log_masses, label in zip(log_mass_arrays, labels):
        # Counts per bin; values outside the bin range are ignored
        N_bin, _ = np.histogram(log_masses, bins=bins)

        # Number density per unit log10(M): counts / (bin width * volume)
        normalisation = delta_logM * volume
        hmf = N_bin / normalisation  # [1/Mpc^3]

        # Poisson uncertainty on the counts, scaled the same way
        hmf_err = np.sqrt(N_bin) / normalisation

        # Plot the HMF
        plt.errorbar(bin_centers, hmf, yerr=hmf_err, fmt='o-', label=label, capsize=5)

    # Configure plot
    plt.xlabel(r'$\log_{10}(M / M_\odot)$', fontsize=14)
    plt.ylabel(plot_label, fontsize=14)
    plt.title('Halo Mass Function', fontsize=16)
    plt.yscale('log')  # Log scale for better visualization
    plt.xlim(log_m_min, log_m_max)
    plt.legend()
    plt.tight_layout()
    plt.show()
    
    
    
def calculate_and_plot_dN_dz(
    redshift_arrays,
    labels,
    volume,
    num_bins=10,
    bin_range=None,
    plot_label=r'$dN/dz$ [1/Mpc$^3$]',
    exclude_non_physical=True
):
    """
    Calculate and plot the halo number density as a function of redshift (dN/dz)
    for multiple datasets.

    Every dataset is histogrammed on one shared set of equal-width redshift
    bins. The counts are divided by (bin width * volume) and drawn with
    sqrt(N) error bars on a logarithmic y-axis.

    Parameters:
    - redshift_arrays (list of array-like): One collection of redshift values per dataset.
    - labels (list of str): List of labels for each dataset, used in the plot legend.
    - volume (float): Survey or simulation volume in Mpc^3.
    - num_bins (int, optional): Number of redshift bins. Default is 10.
    - bin_range (tuple, optional): Tuple specifying (z_min, z_max).
                                   If None, it is determined from the data with
                                   0.1 of padding on each side; when
                                   exclude_non_physical is True the lower edge is
                                   not allowed to drop below z = 0.
    - plot_label (str, optional): Label for the y-axis in the plot.
    - exclude_non_physical (bool, optional): If True, excludes non-physical redshift values
                                             (e.g., z < 0). A message is printed for every
                                             dataset that loses values this way.

    Returns:
    - None. Displays a plot of dN/dz.

    Raises:
    - ValueError: If the number of redshift arrays and labels differ, or if
                  exclude_non_physical is True and a dataset has no
                  non-negative redshift left.
    """

    if len(redshift_arrays) != len(labels):
        raise ValueError("The number of redshift arrays must match the number of labels.")

    cleaned_redshift_arrays = []

    for redshift, label in zip(redshift_arrays, labels):
        # Coerce first so plain lists support the element-wise comparison below
        redshift = np.ravel(np.asarray(redshift, dtype=float))
        if exclude_non_physical:
            # Identify valid (non-negative) redshift values; NaN compares False
            valid_mask = redshift >= 0
            num_invalid = redshift.size - np.count_nonzero(valid_mask)
            if num_invalid > 0:
                print(f"Dataset '{label}': Excluding {num_invalid} non-physical redshift value(s).")
            # Apply the mask to filter out non-physical redshifts
            redshift = redshift[valid_mask]
            if redshift.size == 0:
                raise ValueError(f"All redshift values in dataset '{label}' are non-physical.")
        cleaned_redshift_arrays.append(redshift)

    # Determine the global redshift range if not provided
    if bin_range is None:
        all_redshifts = np.concatenate(cleaned_redshift_arrays)
        z_min = all_redshifts.min() - 0.1  # padding
        z_max = all_redshifts.max() + 0.1  # padding
        if exclude_non_physical:
            # With z < 0 excluded, a first bin reaching below zero would cover
            # a range that cannot contain data and so dilute its density.
            z_min = max(z_min, 0.0)
    else:
        z_min, z_max = bin_range

    # Define redshift bins (equal width, so one delta serves all bins)
    bins = np.linspace(z_min, z_max, num_bins + 1)
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    delta_z = bins[1] - bins[0]

    plt.figure(figsize=(8, 6))

    for redshift, label in zip(cleaned_redshift_arrays, labels):
        # Counts per bin; values outside the bin range are ignored
        N_bin, _ = np.histogram(redshift, bins=bins)

        # Number density per unit redshift: counts / (bin width * volume)
        normalisation = delta_z * volume
        dN_dz = N_bin / normalisation  # [1/Mpc^3]

        # Poisson uncertainty on the counts, scaled the same way
        dN_dz_err = np.sqrt(N_bin) / normalisation

        # Plot the dN/dz
        plt.errorbar(bin_centers, dN_dz, yerr=dN_dz_err, fmt='o-', label=label, capsize=5)

    # Configure plot
    plt.xlabel(r'Redshift ($z$)', fontsize=14)
    plt.ylabel(plot_label, fontsize=14)
    plt.title('Halo Number Density as a Function of Redshift', fontsize=16)
    plt.yscale('log')  # Log scale for better visualization
    plt.xlim(z_min, z_max)
    plt.legend()
    plt.tight_layout()
    plt.show()
    

def basemap_plot(ra_sky, dec_sky, color, n_gal=2048*8, seed=None):
    """
    Scatter a random subsample of sky positions on a Mollweide all-sky map.

    Parameters:
    - ra_sky (array-like): Right ascension of each object; passed to Basemap as
                           the longitude coordinate.
    - dec_sky (array-like): Declination of each object; passed to Basemap as the
                            latitude coordinate. Must have the same length as ra_sky.
    - color: Colour specification forwarded unchanged to the scatter call.
    - n_gal (int, optional): Maximum number of objects to draw. Default is 2048*8.
                             If the catalogue has fewer objects, all of them are drawn.
    - seed (optional): Seed for the random subsample, forwarded to
                       np.random.default_rng. Default None gives a fresh draw
                       on every call.

    Returns:
    - None. Displays the map with plt.show().

    Raises:
    - ValueError: If ra_sky and dec_sky differ in length or are empty.
    """
    ra_sky = np.atleast_1d(np.asarray(ra_sky))
    dec_sky = np.atleast_1d(np.asarray(dec_sky))

    n_total = dec_sky.shape[0]
    if ra_sky.shape[0] != n_total:
        raise ValueError("ra_sky and dec_sky must have the same length.")
    if n_total == 0:
        raise ValueError("ra_sky and dec_sky must contain at least one position.")

    # Only plotting a random subset. Sampling without replacement keeps every
    # plotted point unique, so no position is overdrawn at alpha < 1.
    if n_total <= n_gal:
        ra_plot, dec_plot = ra_sky, dec_sky
    else:
        rng = np.random.default_rng(seed)
        random_gal_indices = rng.choice(n_total, size=n_gal, replace=False)
        ra_plot = ra_sky[random_gal_indices]
        dec_plot = dec_sky[random_gal_indices]

    # NOTE(review): `color` is not subsampled; a per-object colour array would no
    # longer line up with the plotted points - confirm callers pass a single colour.

    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(111)

    # Mollweide projection with the central meridian at longitude 45
    m = Basemap(projection='moll', lat_0=0, lon_0=45, resolution='c', ax=ax)
    # Convert RA, Dec to x, y map coordinates for plotting
    x, y = m(ra_plot, dec_plot)

    # Plot the sky distribution
    m.scatter(x, y, s=1, c=color, alpha=0.5, edgecolors=None, linewidth=1)

    # Draw parallels and meridians every 45 degrees
    m.drawparallels(np.arange(-90.,90.,45), color='yellow', textcolor='yellow', linewidth=2)
    m.drawmeridians(np.arange(0.,360.,45), color='yellow', textcolor='yellow', linewidth=2)
    m.drawmapboundary(fill_color='black')

    plt.suptitle('Sky Distribution of Galaxies in full sky', fontsize=20)
    plt.show()