import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from scipy.stats import gaussian_kde, spearmanr
import fnmatch

# =======================
# 1. Load IMPROVE dataset
# =======================
improve = xr.open_dataset('aerosol_IMPROVE.nc')

# Site coordinates: siteloc column 0 is read as longitude, column 1 as latitude (degrees)
lons = improve['siteloc'][:, 0].astype('float32').values
lats = improve['siteloc'][:, 1].astype('float32').values

# North America filter
in_na = (lats >= 15) & (lats <= 72) & (lons >= -167) & (lons <= -50)
lon_na = lons[in_na]
lat_na = lats[in_na]

# Decode site codes
site_codes = [''.join(c.astype(str)).strip() for c in improve['sitecode'][in_na].values]
site_indices = np.arange(improve.sizes['site'])[in_na]

# Build DataFrame
site_df = pd.DataFrame({
    'site_index': site_indices,
    'site_name': site_codes,
    'lat': lat_na,
    'lon': lon_na
})
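# One row per North American IMPROVE site: dataset index, decoded site code, and coordinates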

# =======================
# 2. Load Model Data
# =======================
climo_dir = '/Users/hass877/Work/data_analysis/'

eam_data = xr.open_dataset(climo_dir + 'SSA_Control_PD_ANN_201601_201612_climo.nc')
eam_lon = eam_data['lon'].values.copy()
eam_lon[eam_lon > 180.] -= 360.  # wrap longitudes from [0, 360) to [-180, 180) to match site coordinates
eam_lat = eam_data['lat'].values

eamxx_data = xr.open_dataset(climo_dir + 'ne256pg2_eamxx_pd_ANN_201901_202012_climo.nc')
eamxx_lon = eamxx_data['lon'].values.copy()
eamxx_lon[eamxx_lon > 180.] -= 360.  # wrap to [-180, 180)
eamxx_lat = eamxx_data['lat'].values

aer = 'so4'
avar = aer + '_a?'  # match all modal species of the aerosol, e.g. so4_a1, so4_a2, so4_a3
varlist = fnmatch.filter(list(eam_data.variables), avar)

# Surface concentration: sum all matched modes at the lowest model level
# (lev=71 for the 72-level EAM grid, lev=127 for the 128-level EAMxx grid).
# Note: the same varlist is applied to both datasets, assuming identical variable names.
aer_sfc_eam = eam_data[varlist].isel(lev=71).to_array().sum('variable')
aer_sfc_eamxx = eamxx_data[varlist].isel(lev=127).to_array().sum('variable')

# Conversion factors: based on dry-air density at standard pressure and 273.15 K,
# P * M_air / (R * T); factbb converts a kg/kg mixing ratio to µg/m³
factaa = 1.01325e5 / 8.31446261815324 / 273.15 * 28.9647 / 1.e9
factbb = factaa * 1.e15

# =======================
# 3. Nearest Grid Function
# =======================
def get_nearest_ncol(model_lon, model_lat, site_lon, site_lat):
    """Index of the model column nearest a site, using squared distance in degrees
    (no longitude wrapping or cos(latitude) weighting)."""
    distsq = (model_lon - site_lon)**2 + (model_lat - site_lat)**2
    return np.argmin(distsq)

# =======================
# 4. Extract Model Values at Site Locations
# =======================
eam_model_vals = []
eamxx_model_vals = []

for _, row in site_df.iterrows():
    site_lon, site_lat = row['lon'], row['lat']
    
    idx_eam = get_nearest_ncol(eam_lon, eam_lat, site_lon, site_lat)
    idx_eamxx = get_nearest_ncol(eamxx_lon, eamxx_lat, site_lon, site_lat)

    # Surface concentration at the nearest model column, converted to mass concentration (µg/m³)
    eam_model_vals.append(aer_sfc_eam.values[idx_eam] * factbb)
    eamxx_model_vals.append(aer_sfc_eamxx.values[idx_eamxx] * factbb)

site_df[f'{aer}_EAM_model'] = eam_model_vals
site_df[f'{aer}_EAMxx_model'] = eamxx_model_vals
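# site_df now holds the modeled surface concentrations for both simulations (converted to µg/m³ above)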

# =======================
# 5. Add Observations
# =======================
species = 'SO4f'
species_attrs = improve['conc'].attrs
species_names = [species_attrs.get(f'_speciesname{i}', f'species_{i}') for i in range(improve.sizes['species'])]
species_idx = species_names.index(species)
year_idx = 1  # second year

obs_vals = []
for site in site_df['site_index']:
    conc = improve['conc'][site, species_idx, year_idx, :, :].values
    conc = np.where(conc > 9e36, np.nan, conc)  # remove FillValue
    obs_vals.append(np.nanmean(conc))

site_df[f'obs_conc_{species}'] = obs_vals

# =======================
# 6. Observation vs Model Plot
# =======================
obs = site_df[f'obs_conc_{species}'].values
eam = site_df[f'{aer}_EAM_model'].values
eamxx = site_df[f'{aer}_EAMxx_model'].values

# Keep sites where all three datasets are finite and positive (required for the log axes and KDE)
mask = (~np.isnan(obs)) & (~np.isnan(eam)) & (~np.isnan(eamxx)) & (obs > 0) & (eam > 0) & (eamxx > 0)
obs_clean = obs[mask]
eam_clean = eam[mask]
eamxx_clean = eamxx[mask]

def shared_log_limits(*datasets, buffer=0.2):
    """Common log-scale axis limits spanning all positive data, padded by `buffer` decades."""
    all_data = np.concatenate([d[d > 0] for d in datasets])
    log_min = np.floor(np.log10(all_data.min()))
    log_max = np.ceil(np.log10(all_data.max()))
    return 10**(log_min - buffer), 10**(log_max + buffer)

lim_min, lim_max = shared_log_limits(obs_clean, eam_clean, eamxx_clean)
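# Identical axis limits for all three panels so the comparisons share one scale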

fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)

panels = [
    (obs_clean, eam_clean, 'EAM vs IMPROVE', 'IMPROVE', 'EAM'),
    (obs_clean, eamxx_clean, 'EAMxx vs IMPROVE', 'IMPROVE', 'EAMxx'),
    (eam_clean, eamxx_clean, 'EAMxx vs EAM', 'EAM', 'EAMxx')
]

for ax, (x_data, y_data, title, xname, yname) in zip(axes, panels):
    # Color points by local density (Gaussian KDE); plot the densest points last
    xy = np.vstack([x_data, y_data])
    z = gaussian_kde(xy)(xy)
    idx = z.argsort()

    ax.scatter(x_data[idx], y_data[idx], c=z[idx], s=40, cmap='viridis',
               edgecolors='k', linewidth=0.3)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)

    # 1:1 and factor-of-2 lines
    ax.add_line(mlines.Line2D([lim_min, lim_max], [lim_min, lim_max], color='k', linestyle='-', linewidth=1))
    ax.add_line(mlines.Line2D([2 * lim_min, lim_max], [lim_min, 0.5 * lim_max], color='k', linestyle='--', linewidth=1))
    ax.add_line(mlines.Line2D([lim_min, 0.5 * lim_max], [2 * lim_min, lim_max], color='k', linestyle='--', linewidth=1))

    ax.set_title(title, fontsize=14)
    ax.set_xlabel(f"{xname} [µg/m³]", fontsize=12)
    ax.set_ylabel(f"{yname} [µg/m³]", fontsize=12)

    # Stats: Spearman rank correlation and absolute normalized mean bias
    rho, _ = spearmanr(x_data, y_data)
    nmb = abs((y_data - x_data).sum() / x_data.sum()) * 100
    n = len(x_data)
    stat_text = f"ρ = {rho:.3f}\n|NMB| = {nmb:.1f}%\nN = {n}"
    ax.text(0.05, 0.95, stat_text, transform=ax.transAxes,
            fontsize=11, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

fig.suptitle(f'{aer.upper()} Comparison at IMPROVE Sites', fontsize=18, y=1.05)
plt.tight_layout()
# plt.savefig(f"{aer.upper()}_scatter_OBS_EAM_EAMxx.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
# plt.show()
