import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.colors as colors
from matplotlib.backends.backend_pdf import PdfPages

# Set up paths and variables
datapath = "./"
varnames = ["AODVIS", "colccn.3"]
nvar = len(varnames)
# Exactly two runs are compared: filenames[0] is the control ("Ctrl") run and
# filenames[1] the test run; the plotting loop stores them in rows 0 and 1 and
# computes test - ctrl from those rows.
filenames = ["v2_ndg_cdnc_ssat_diag_simple_pi.eam.h0.2011-01.nc", 
             "v2_ndg_cdnc_ssat_diag_simple_pd.eam.h0.2011-01.nc"]
nfile = len(filenames)
if nfile != 2:
    # Any other count would silently produce meaningless panels (e.g. an
    # all-zero "Test" row), so stop before the output PDF is created.
    raise SystemExit(f"Error: expected exactly 2 files (ctrl, test), got {nfile}. Abort.")
# Panels per variable: ctrl, test, difference, and two relative differences.
nPlotColumn = 5

# Set up PDF for multiple plots (one page is saved per variable)
pdf = PdfPages('./fig_compare_two_runs_Python_mj_test.pdf')

# Get lat/lon info from the first file and check that every file has the same
# number of grid columns, so the columns can be compared element-wise.
lat, lon = None, None
ncol = None

for ifile, filename in enumerate(filenames):
    # The context manager closes the netCDF file once the arrays are read
    # (.values loads the data into memory before the file is closed).
    with xr.open_dataset(os.path.join(datapath, filename)) as ds:
        if ifile == 0:
            area = ds['area'].values
            ncol = len(area)
            lon = ds['lon'].values  # These are 1D arrays for unstructured grid
            lat = ds['lat'].values
        else:
            ncol1 = len(ds['area'].values)
            if ncol1 != ncol:
                # SystemExit(msg) prints msg to stderr and exits with status 1;
                # the previous bare exit() reported success (status 0) and
                # depends on the interactive-only `site` builtin.
                raise SystemExit("Error: two files have different numbers of grid columns. Abort.")

# Set up plotting configurations.
# Numeric sentinel (-999.0) for values that cannot be computed. It is a plain
# number: neither numpy nor matplotlib masks it automatically when plotting.
FillValue = -999.0

# Function to create symmetric colorbar limits
def sym_lim(data, default=1.0):
    """Return colour limits ``(-vmax, vmax)`` centred on zero.

    ``vmax`` is the largest absolute *finite* value in ``data``; NaN and
    +/-inf entries are ignored.

    Args:
        data: array-like of numbers (e.g. one panel's values).
        default: half-width used when ``data`` has no finite non-zero value
            (empty, all-NaN, or all zeros, e.g. two identical runs). Without
            it the limits would be NaN, or vmin == vmax == 0, which
            matplotlib's ``Normalize`` maps to the bottom of the colormap, so
            "no difference" would be drawn in the extreme colour rather than
            the neutral centre.

    Returns:
        Tuple ``(vmin, vmax)`` with ``vmin == -vmax`` and ``vmax > 0``.
    """
    magnitudes = np.abs(np.asarray(data, dtype=float))
    magnitudes = magnitudes[np.isfinite(magnitudes)]
    vmax = magnitudes.max() if magnitudes.size else 0.0
    if vmax == 0.0:
        vmax = float(default)
    return -vmax, vmax

# Panel titles and colormaps are identical for every variable: build them once.
titles = ["Ctrl", "Test",
          "Difference: test - ctrl",
          "Relative difference: (test - ctrl)/ctrl",
          "Relative difference: (test - ctrl)*2/(test + ctrl)"]
# Sequential colormap for the raw fields, diverging one for signed differences.
cmaps = ["inferno", "inferno", "RdBu_r", "RdBu_r", "RdBu_r"]


def _relative_difference(numerator, denominator):
    """Return ``numerator / denominator`` element-wise, with NaN where denominator == 0.

    NaN is used instead of a numeric sentinel such as FillValue. matplotlib does
    not mask a sentinel, so -999 would be drawn as real data and would force the
    symmetric colour range to +/-999. NaN is skipped by ``np.nanmin``/``np.nanmax``
    and drawn in the colormap's "bad" colour. A genuine denominator of -999 is
    also no longer mistaken for zero.
    """
    result = np.full(np.shape(numerator), np.nan)
    np.divide(numerator, denominator, out=result, where=(denominator != 0))
    return result


try:
    for varname in varnames:
        # Rows: 0 ctrl, 1 test, 2 difference, 3 and 4 relative differences.
        array4plotting = np.zeros((nPlotColumn, ncol))

        for ifile, filename in enumerate(filenames):
            # Close each file after reading (.values loads the data first).
            with xr.open_dataset(os.path.join(datapath, filename)) as ds:
                # ravel flattens length-1 leading dimensions (presumably the
                # single time record of these monthly h0 files; confirm) so
                # the field fills exactly one row of ncol values.
                array4plotting[ifile, :] = np.ravel(ds[varname].values)

        ctrl = array4plotting[0, :]
        test = array4plotting[1, :]
        diff = test - ctrl
        # Difference: test - ctrl
        array4plotting[2, :] = diff
        # Relative difference: (test - ctrl)/ctrl
        array4plotting[3, :] = _relative_difference(diff, ctrl)
        # Relative difference: (test - ctrl)*2/(test + ctrl)
        array4plotting[4, :] = _relative_difference(diff, 0.5 * (ctrl + test))

        # One figure (one PDF page) per variable
        fig = plt.figure(figsize=(18, 4))
        fig.suptitle(varname, fontsize=14)

        for i in range(nPlotColumn):
            ax = fig.add_subplot(1, nPlotColumn, i + 1, projection=ccrs.Mollweide())
            field = array4plotting[i, :]

            if i >= 2:
                # Difference panels: limits symmetric about zero so that zero
                # sits at the neutral centre of the diverging colormap.
                vmin, vmax = sym_lim(field)
            else:
                vmin, vmax = np.nanmin(field), np.nanmax(field)
            norm = colors.Normalize(vmin=vmin, vmax=vmax)

            # Unstructured grid: draw each column as a point at its (lon, lat).
            sc = ax.scatter(lon, lat,
                            c=field,
                            s=2,  # marker size; adjust as needed
                            transform=ccrs.PlateCarree(),
                            cmap=cmaps[i],
                            norm=norm)
            fig.colorbar(sc, ax=ax, orientation='horizontal', pad=0.05, shrink=0.8)

            ax.coastlines()
            ax.gridlines(linestyle='--')
            ax.set_global()
            ax.set_title(titles[i], fontsize=10)

        fig.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)  # close this figure explicitly, not whatever is "current"
finally:
    # Finalize the multi-page PDF even if reading or plotting fails part-way.
    pdf.close()

print("Plots saved to ./fig_compare_two_runs_Python_mj_test.pdf")
