#!/usr/bin/env python3
"""
E3SM S2S2D Fig. 10-style land analysis using the S2D GP spinup cycle-6
land history as the reference reconstruction.

Designed for the following directory layout on NERSC Perlmutter:

Hindcasts:
  /global/cfs/cdirs/e3smdata/simulations/S2S2D/
    WCYCL20TR_ne30pg2_r05_IcoswISC30E3r5_JRA55_FOSIRL_YYYYMM0100/
      EN00/archive/lnd/hist/*.elm.h0.YYYY-MM.nc
      ...
      EN09/archive/lnd/hist/*.elm.h0.YYYY-MM.nc

Reference:
  /global/cfs/cdirs/e3sm/lvroekel/archive/lnd/hist/
    20251024_s2d_spinup.elm.h0.0001-01.nc
    ...
    20251024_s2d_spinup.elm.h0.0392-12.nc

Reference calendar mapping:
  The GP spinup repeats 1958-2023 for six cycles.
  Cycle 6 begins at model year 0331 = calendar year 1958.
  Therefore calendar_year = model_year - 331 + 1958.
  For example, model year 0353 = 1980, and 0392 = 2019.

Example:
  python e3sm_s2s2d_f10_tws_gp_ref.py \
    --base /global/cfs/cdirs/e3smdata/simulations/S2S2D \
    --ref-dir /global/cfs/cdirs/e3sm/lvroekel/archive/lnd/hist \
    --var TWS \
    --start-years 1980 2018 \
    --start-months 5 11 \
    --members 0 9 \
    --climy0 1981 --climy1 2017 \
    --skill-year0 1981 --skill-year1 2019 \
    --target-start-month 11 \
    --target-season JJA \
    --target-lead-month 19 \
    --output e3sm_s2s2d_f10_tws_gp_ref.png

Notes:
  * Because the GP reference currently ends at model year 0392-12 or 0393-01
    (calendar Dec 2019 or Jan 2020), lead/verification seasons that need later
    months, such as JJA 2020 from the 2018-11 start, are filled with NaN and
    excluded from the skill statistics.
  * This script uses native-grid scatter maps instead of conservative regridding.
"""

from __future__ import annotations

import argparse
import glob
import os
import re
from pathlib import Path
from typing import Iterable

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


# Centre month of each 3-month season -> season label (Dec-Jan-Feb is centred on Jan, etc.).
SEASON_BY_CENTER_MONTH = {1: "DJF", 4: "MAM", 7: "JJA", 10: "SON"}
# Both lookups below are derived from the table above so they can never disagree with it.
SEASON_CENTER_MONTHS = np.array(list(SEASON_BY_CENTER_MONTH))
SEASON_CENTER_MONTH = {label: month for month, label in SEASON_BY_CENTER_MONTH.items()}


def parse_args() -> argparse.Namespace:
    """
    Build the command-line parser and parse ``sys.argv``.

    Options are registered (and therefore listed by ``--help``) in five groups:
    hindcast layout, variable/cache, GP spinup reference, skill windows, figure.
    """
    parser = argparse.ArgumentParser(
        description="E3SM S2S2D Fig. 10-style TWS skill analysis with GP spinup reference"
    )
    add = parser.add_argument

    # --- Hindcast layout -------------------------------------------------
    add("--base",
        default="/global/cfs/cdirs/e3smdata/simulations/S2S2D",
        help="Root directory containing all S2S2D initialization case directories.")
    add("--case-prefix",
        default="WCYCL20TR_ne30pg2_r05_IcoswISC30E3r5_JRA55_FOSIRL",
        help="Hindcast case prefix before _YYYYMM0100.")
    add("--start-years", nargs=2, type=int, default=[1980, 2018],
        metavar=("YEAR0", "YEAR1"),
        help="Initialization year range, inclusive.")
    add("--start-months", nargs="+", type=int, default=[5, 11],
        help="Initialization months available in your experiment, e.g., 5 11.")
    add("--members", nargs=2, type=int, default=[0, 9],
        metavar=("MEM0", "MEM1"),
        help="Ensemble member range, inclusive.")
    add("--allow-missing-members", action="store_true",
        help="Skip unreadable/missing ensemble members instead of stopping.")
    add("--member-list", nargs="+", type=int, default=None,
        help=("Explicit ensemble member list, e.g. --member-list 0 1 2 3 5 6 7 8 9. "
              "If not set, use --members MEM0 MEM1."))
    add("--stream", default="h0", help="ELM history stream, normally h0.")
    add("--nlead-months", type=int, default=24,
        help="Number of monthly lead times to read from each hindcast.")

    # --- Variable and cache ----------------------------------------------
    add("--var", default="TWS", help="ELM variable to analyze.")
    add("--ref-var", default=None, help="Reference variable name. Default: same as --var.")
    add("--cache", default="./e3sm_s2s2d_cache",
        help="Directory for cached seasonal hindcast files.")
    add("--rebuild-cache", action="store_true",
        help="Rebuild cached seasonal hindcast files.")

    # --- GP spinup reference ---------------------------------------------
    add("--ref-dir", default="/global/cfs/cdirs/e3sm/lvroekel/archive/lnd/hist",
        help="Directory containing GP spinup ELM h0 reference files.")
    add("--ref-case", default="20251024_s2d_spinup",
        help="GP spinup case name used in reference file names.")
    add("--ref-pattern", default=None,
        help=("Optional explicit glob for reference files. If omitted, uses "
              "'REF_DIR/REF_CASE.elm.STREAM.*.nc'."))
    add("--ref-cycle-start-model-year", type=int, default=331,
        help="Model year corresponding to the start of cycle 6. Default: 331.")
    add("--ref-cycle-start-calendar-year", type=int, default=1958,
        help="Calendar year corresponding to ref-cycle-start-model-year. Default: 1958.")
    add("--ref-calendar-year0", type=int, default=1980,
        help="First calendar year to keep from reference.")
    add("--ref-calendar-year1", type=int, default=2019,
        help="Last full calendar year to keep from reference. Use 2019 if ref ends at 0392-12.")

    # --- Skill options ---------------------------------------------------
    add("--climy0", type=int, default=1981, help="First climatology year.")
    add("--climy1", type=int, default=2017, help="Last climatology year.")
    add("--skill-year0", type=int, default=1981,
        help="First verification calendar year used for skill.")
    add("--skill-year1", type=int, default=2019,
        help="Last verification calendar year used for skill. With GP ref ending 0392-12, use 2019.")

    # --- Figure options --------------------------------------------------
    add("--target-start-month", type=int, default=11,
        help="Start month used for map/time-series panels.")
    add("--target-season", default="JJA", choices=["DJF", "MAM", "JJA", "SON"],
        help="Target season for panels B-D.")
    add("--target-lead-month", type=int, default=19,
        help="Target lead month for map/time-series panel. For NOV-to-second-JJA use 19.")
    add("--region", nargs=4, type=float, default=[-120.0, -100.0, 22.0, 37.0],
        metavar=("LON_W", "LON_E", "LAT_S", "LAT_N"),
        help="Region box. Western longitudes can be negative.")
    add("--region-name", default="Southwest", help="Region label.")
    add("--output", default="e3sm_s2s2d_f10_tws_gp_ref.png", help="Output PNG file.")
    add("--dpi", type=int, default=150)

    return parser.parse_args()


def case_dir(base: str, prefix: str, year: int, month: int) -> str:
    """Return the initialization case directory ``BASE/PREFIX_YYYYMM0100``."""
    case_name = "{}_{:04d}{:02d}0100".format(prefix, year, month)
    return os.path.join(base, case_name)


def member_name(m: int) -> str:
    """Return the ensemble-member directory label, e.g. 3 -> ``EN03``."""
    return "EN" + format(m, "02d")


def parse_hindcast_yyyy_mm(path: str):
    """
    Extract ``(year, month)`` from a history file name ending in ``.YYYY-MM.nc``.

    Only the basename is inspected. Returns ``None`` if the name does not match.
    """
    match = re.search(r"\.(\d{4})-(\d{2})\.nc$", os.path.basename(path))
    if match is None:
        return None
    year_text, month_text = match.groups()
    return int(year_text), int(month_text)


def parse_gp_model_yyyy_mm(path: str):
    """
    Parse the model ``(year, month)`` from GP spinup files like ``elm.h0.0353-01.nc``.

    Returns ``None`` when the basename does not end in ``.YYYY-MM.nc``.
    """
    found = re.search(r"\.(\d{4})-(\d{2})\.nc$", os.path.basename(path))
    return tuple(map(int, found.groups())) if found else None


def model_year_to_calendar_year(model_year: int, start_model_year: int, start_calendar_year: int) -> int:
    """Map a spinup model year to a calendar year via a fixed offset (e.g. 353 -> 1980 for 331 == 1958)."""
    offset = start_calendar_year - start_model_year
    return model_year + offset


def find_monthly_files(base: str, prefix: str, year: int, month: int, mem: int, stream: str) -> list[str]:
    """
    List the monthly ELM history files of one hindcast member, sorted by name.

    Globs ``PREFIX_YYYYMM0100.ENmm.elm.STREAM.*.nc`` inside
    ``<case_dir>/ENmm/archive/lnd/hist``. It then keeps only paths ending in
    ``.elm.STREAM.YYYY-MM.nc``, so other files matching the glob are excluded.
    """
    en = member_name(mem)
    case_name = f"{prefix}_{year:04d}{month:02d}0100"
    hist_dir = os.path.join(case_dir(base, prefix, year, month), en, "archive", "lnd", "hist")
    candidates = sorted(glob.glob(os.path.join(hist_dir, f"{case_name}.{en}.elm.{stream}.*.nc")))
    monthly_name = re.compile(r"\.elm\." + re.escape(stream) + r"\.\d{4}-\d{2}\.nc$")
    return [path for path in candidates if monthly_name.search(path)]


def attach_static_coords(da: xr.DataArray, ds: xr.Dataset) -> xr.DataArray:
    """
    Copy grid-description variables from ``ds`` onto ``da`` as coordinates.

    Each of lat, lon, area, landfrac, LONGXY, LATIXY is attached when it exists
    in ``ds`` and is not already a coordinate of ``da``. Names are checked one at
    a time, after each assignment, which matches sequential ``assign_coords``.
    """
    static_names = ("lat", "lon", "area", "landfrac", "LONGXY", "LATIXY")
    for name in static_names:
        if name not in ds or name in da.coords:
            continue
        da = da.assign_coords({name: ds[name]})
    return da


def open_monthly_var(files: list[str], var: str, nlead: int) -> xr.DataArray:
    """
    Read one variable from a hindcast's monthly files into a single DataArray.

    Only the first ``nlead`` files are used. The ``time`` coordinate is rebuilt
    as mid-month (day 15) dates parsed from each file name's ``YYYY-MM``. This
    avoids depending on the file's own time encoding.

    Each file is opened in a ``with`` block and its slice is loaded before the
    file is closed. Otherwise every file handle stays open until garbage
    collection, and a full build touches thousands of files.

    Raises
    ------
    FileNotFoundError
        If ``files`` is empty.
    ValueError
        If a file name has no parseable ``YYYY-MM``.
    KeyError
        If ``var`` is missing from a file.
    """
    if len(files) == 0:
        raise FileNotFoundError("No monthly hindcast files found.")
    pieces = []
    for f in files[:nlead]:
        ym = parse_hindcast_yyyy_mm(f)
        if ym is None:
            raise ValueError(f"Cannot parse hindcast YYYY-MM from file name: {f}")
        y, m = ym
        with xr.open_dataset(f, decode_times=False) as ds:
            if var not in ds:
                raise KeyError(f"Variable {var!r} not found in {f}")
            da = ds[var]
            if "time" in da.dims:
                da = da.isel(time=0, drop=True)
            # Load data and coords now so nothing refers to the file after it closes.
            da = attach_static_coords(da, ds).load()
        da = da.expand_dims(time=[np.datetime64(f"{y:04d}-{m:02d}-15")])
        pieces.append(da)
    return xr.concat(pieces, dim="time")


def seasonal_means_from_monthly(da: xr.DataArray) -> xr.DataArray:
    """
    Convert a monthly series into centred 3-month seasonal means.

    Only windows centred on Jan/Apr/Jul/Oct (DJF/MAM/JJA/SON) are kept. The
    ``time`` dimension is renamed ``lead`` and re-indexed 0..N-1. The original
    centre dates, centre months and season labels are kept as ``lead``
    coordinates (``center_time``, ``center_month``, ``season``).
    """
    # min_periods=3 means partial windows at either end become NaN and are dropped.
    rolled = da.rolling(time=3, center=True, min_periods=3).mean()
    rolled = rolled.dropna("time", how="all")

    is_center = np.isin(rolled["time"].dt.month, SEASON_CENTER_MONTHS)
    centers = rolled.isel(time=is_center)

    stamps = centers["time"].values
    month_numbers = centers["time"].dt.month.values.astype(int)
    labels = [SEASON_BY_CENTER_MONTH[int(mm)] for mm in month_numbers]
    n_seasons = centers.sizes["time"]

    return centers.rename({"time": "lead"}).assign_coords(
        lead=np.arange(n_seasons),
        center_time=("lead", stamps),
        center_month=("lead", month_numbers),
        season=("lead", labels),
    )


def lead_month_for_time(init_year: int, init_month: int, center_time) -> int:
    """
    Return the lead, in months, from initialization to the first month of a season.

    ``center_time`` is the centre month of a 3-month season (Jan/Apr/Jul/Oct for
    DJF/MAM/JJA/SON). Every season starts one calendar month before its centre.
    For example, JJA centred on 2020-07 starts in 2020-06, so a 2018-11
    initialization has lead ``(2020 - 2018) * 12 + (6 - 11) = 19``.

    The date is handled with plain numpy month arithmetic rather than by
    building an xarray object per call. Any value ``numpy.datetime64``
    understands is accepted: datetime64 scalars of any unit,
    ``datetime.date``/``datetime.datetime``, ISO strings.

    Raises
    ------
    KeyError
        If ``center_time`` does not fall in a season centre month.
    """
    # Months since 1970-01; Python's floor division/modulo keep pre-1970 dates correct.
    months_since_epoch = int(np.datetime64(center_time, "M").astype(np.int64))
    center_month = months_since_epoch % 12 + 1
    if center_month not in (1, 4, 7, 10):
        raise KeyError(center_month)
    season_start = months_since_epoch - 1
    init_index = (init_year - 1970) * 12 + (init_month - 1)
    return season_start - init_index


def build_one_start_month(args: argparse.Namespace, start_month: int) -> xr.DataArray:
    """
    Build seasonal-mean hindcasts for every start year and member of one start month.

    For each start year in ``args.start_years`` (inclusive) and each member
    (``args.member_list`` if given, else the inclusive ``args.members`` range),
    the first ``args.nlead_months`` monthly files are read and turned into
    centred 3-month seasonal means. Each season is then tagged with its lead
    month, season name, and verification year/month.

    Returns
    -------
    DataArray with dims (start_year, member, lead, <spatial dims>). The
    ``lead_month`` and ``season`` coords on ``lead`` come from the first member
    read. It also carries ``valid_year``/``valid_month`` and the
    ``center_time``/``center_month`` coords set by ``seasonal_means_from_monthly``.

    Raises
    ------
    FileNotFoundError
        If a member has fewer than ``args.nlead_months`` files and
        ``--allow-missing-members`` is not set.
    RuntimeError
        If no member is usable for some start year.
    """
    y0, y1 = args.start_years
    years = range(y0, y1 + 1)
    # An explicit member list takes precedence over the MEM0..MEM1 range.
    if args.member_list is not None:
        members = list(args.member_list)
    else:
        m0, m1 = args.members
        members = list(range(m0, m1 + 1))

    by_year = []
    # Lead-month / season labels of the first member read; attached to the final array.
    lead_months_first = None
    seasons_first = None

    for year in years:
        by_member = []
        for mem in members:
            files = find_monthly_files(args.base, args.case_prefix, year, start_month, mem, args.stream)
            if len(files) < args.nlead_months:
                msg = f"{year}-{start_month:02d} {member_name(mem)} has only {len(files)} monthly files; expected {args.nlead_months}"
                if args.allow_missing_members:
                    print(f"WARNING: {msg}. Skipping this member for this start.")
                    continue
                raise FileNotFoundError(msg + " Use --allow-missing-members to skip it, or --member-list to exclude unavailable members.")
            da_mon = open_monthly_var(files, args.var, args.nlead_months)
            da_seas = seasonal_means_from_monthly(da_mon)

            # Label each seasonal mean by its lead month (init -> season start),
            # season name, and calendar verification year/month.
            times = da_seas["center_time"].values
            lead_months = [lead_month_for_time(year, start_month, tt) for tt in times]
            seasons = [SEASON_BY_CENTER_MONTH[int(xr.DataArray(tt).dt.month.values)] for tt in times]
            valid_year = [int(xr.DataArray(tt).dt.year.values) for tt in times]
            valid_month = [int(xr.DataArray(tt).dt.month.values) for tt in times]

            da_seas = da_seas.assign_coords(
                lead=np.arange(da_seas.sizes["lead"]),
                lead_month=("lead", lead_months),
                season=("lead", seasons),
                valid_year=("lead", valid_year),
                valid_month=("lead", valid_month),
            )

            if lead_months_first is None:
                lead_months_first = lead_months
                seasons_first = seasons

            by_member.append(da_seas.expand_dims(member=[mem]))

        if len(by_member) == 0:
            raise RuntimeError(f"No usable members for start {year}-{start_month:02d}. Check permissions/files.")
        # NOTE(review): when members were skipped, xr.concat presumably aligns on
        # `member` and pads absent members with NaN across years — verify.
        da_y = xr.concat(by_member, dim="member").expand_dims(start_year=[year])
        by_year.append(da_y)
        print(f"Built seasonal hindcast data for start {year}-{start_month:02d}")

    # valid_year and center_time differ between start years, so xr.concat's default
    # coords="different" is expected to promote them to (start_year, lead) coords.
    # Downstream code handles both the 1-D and the 2-D form.
    out = xr.concat(by_year, dim="start_year")
    out = out.assign_coords(
        lead_month=("lead", lead_months_first),
        season=("lead", seasons_first),
    )
    return out


def build_or_read_hindcasts(args: argparse.Namespace) -> dict[int, xr.DataArray]:
    """
    Return seasonal hindcasts keyed by start month, using an on-disk cache.

    For each month in ``args.start_months``, the cache file
    ``CACHE/<case_prefix>_startMM_<var>_seasonal.nc`` is read when it exists and
    ``--rebuild-cache`` is not set. Otherwise the data are built with
    ``build_one_start_month`` and written to that file.

    The cache file name does not encode the start-year range, the member set, or
    the number of lead months, so a cache built with other options would be
    reused silently. A warning is printed when the cached start years differ
    from ``--start-years``.
    """
    cache = Path(args.cache)
    cache.mkdir(parents=True, exist_ok=True)
    y0, y1 = args.start_years
    wanted_years = list(range(y0, y1 + 1))
    out = {}
    for sm in args.start_months:
        f = cache / f"{args.case_prefix}_start{sm:02d}_{args.var}_seasonal.nc"
        if f.exists() and not args.rebuild_cache:
            print(f"Reading cached hindcast seasonal file: {f}")
            da = xr.open_dataset(f)[args.var]
            cached_years = [int(y) for y in da["start_year"].values]
            if cached_years != wanted_years:
                span = f"{cached_years[0]}-{cached_years[-1]}" if cached_years else "none"
                print(f"WARNING: cached start years ({span}) in {f} do not match "
                      f"--start-years {y0}-{y1}; use --rebuild-cache to rebuild it.")
            out[sm] = da
        else:
            da = build_one_start_month(args, sm)
            da.to_dataset(name=args.var).to_netcdf(f)
            print(f"Wrote cached hindcast seasonal file: {f}")
            out[sm] = da
    return out


def read_gp_reference(args: argparse.Namespace) -> xr.DataArray:
    """
    Read GP spinup model-year files and return calendar-dated seasonal means.

    Model years are mapped to calendar years with the cycle-6 offset
    (``--ref-cycle-start-model-year`` == ``--ref-cycle-start-calendar-year``).
    Months from ``--ref-calendar-year0`` through December of
    ``--ref-calendar-year1`` are kept. December of year0-1 is also kept, when it
    still lies inside the cycle, so that the DJF season centred on January of
    year0 can be formed.

    Each monthly file is opened in a ``with`` block and loaded, so no file
    handles are left open. Missing months are filled with NaN before the
    3-month rolling mean, so seasons touching a gap become NaN instead of
    averaging non-adjacent months.

    Returns
    -------
    DataArray of centred 3-month means on a ``time`` axis, restricted to the
    DJF/MAM/JJA/SON centre months (Jan/Apr/Jul/Oct).

    Raises
    ------
    FileNotFoundError
        If no file matches the reference pattern.
    KeyError
        If the reference variable is missing from a kept file.
    RuntimeError
        If no file survives the calendar-year filter.
    ValueError
        If two kept files map to the same calendar month (e.g. the pattern
        matches more than one case).
    """
    var = args.ref_var or args.var
    pattern = args.ref_pattern
    if pattern is None:
        pattern = os.path.join(args.ref_dir, f"{args.ref_case}.elm.{args.stream}.*.nc")

    files = [f for f in sorted(glob.glob(pattern)) if parse_gp_model_yyyy_mm(f) is not None]
    if not files:
        raise FileNotFoundError(f"No GP reference monthly files found with pattern: {pattern}")

    # Inclusive (year, month) window. Dec of year0-1 completes the first DJF, but
    # only if it is still inside the cycle that the calendar mapping refers to.
    if args.ref_calendar_year0 - 1 >= args.ref_cycle_start_calendar_year:
        first_key = (args.ref_calendar_year0 - 1, 12)
    else:
        first_key = (args.ref_calendar_year0, 1)
    last_key = (args.ref_calendar_year1, 12)

    pieces = []
    kept = 0
    for f in files:
        model_year, month = parse_gp_model_yyyy_mm(f)
        cal_year = model_year_to_calendar_year(
            model_year,
            args.ref_cycle_start_model_year,
            args.ref_cycle_start_calendar_year,
        )
        if not (first_key <= (cal_year, month) <= last_key):
            continue

        with xr.open_dataset(f, decode_times=False) as ds:
            if var not in ds:
                raise KeyError(f"Reference variable {var!r} not found in {f}")
            da = ds[var]
            if "time" in da.dims:
                da = da.isel(time=0, drop=True)
            da = attach_static_coords(da, ds).load()
        da = da.expand_dims(time=[np.datetime64(f"{cal_year:04d}-{month:02d}-15")])
        pieces.append(da)
        kept += 1

    if not pieces:
        raise RuntimeError("No reference files were kept after calendar-year filtering. "
                           "Check --ref-cycle-start-model-year/year and --ref-calendar-year0/1.")

    ref = xr.concat(pieces, dim="time").sortby("time")

    # Make the monthly axis contiguous so the rolling window never spans a gap.
    months = ref["time"].values.astype("datetime64[M]")
    if np.unique(months).size != months.size:
        raise ValueError("Duplicate calendar months in the GP reference; "
                         "check --ref-pattern (it may match more than one case).")
    full_months = np.arange(months[0], months[-1] + 1)
    if full_months.size != months.size:
        print(f"WARNING: GP reference is missing {full_months.size - months.size} month(s); "
              "seasons touching them will be NaN.")
        full_times = (full_months.astype("datetime64[D]") + np.timedelta64(14, "D")).astype("datetime64[ns]")
        ref = ref.reindex(time=full_times)

    print(f"Read GP reference: {kept} monthly files from {str(ref.time.values[0])[:10]} "
          f"to {str(ref.time.values[-1])[:10]}")

    ref_seas = ref.rolling(time=3, center=True, min_periods=3).mean().dropna("time", how="all")
    ref_seas = ref_seas.isel(time=np.isin(ref_seas.time.dt.month, SEASON_CENTER_MONTHS))
    print(f"Built GP reference seasonal means: {ref_seas.sizes.get('time', 0)} seasons")
    return ref_seas


def get_spatial_dims(da: xr.DataArray) -> list[str]:
    """Return ``da``'s dims other than start_year/member/lead/time, in their original order."""
    non_spatial = {"start_year", "member", "lead", "time"}
    return [name for name in da.dims if name not in non_spatial]


def detrend_1d(y: np.ndarray) -> np.ndarray:
    """
    Remove a least-squares linear trend (in sample index) from a 1-D array.

    Non-finite entries are ignored in the fit and stay non-finite in the output.
    If fewer than 3 finite values are present, an all-NaN array is returned.
    """
    finite = np.isfinite(y)
    if np.count_nonzero(finite) < 3:
        return y * np.nan
    index = np.arange(y.size, dtype=float)
    slope_intercept = np.polyfit(index[finite], y[finite], 1)
    trend = np.polyval(slope_intercept, index)
    return y - trend


def detrend_dim(da: xr.DataArray, dim: str) -> xr.DataArray:
    """Linearly detrend ``da`` along ``dim`` point by point; the result keeps ``da``'s dim order."""
    original_order = da.dims
    detrended = xr.apply_ufunc(
        detrend_1d,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
    )
    # apply_ufunc moves the core dim last; restore the caller's ordering.
    return detrended.transpose(*original_order)


def remove_hindcast_drift(da: xr.DataArray, climy0: int, climy1: int) -> xr.DataArray:
    """
    Subtract the lead-dependent hindcast climatology (model drift).

    The climatology is the mean over members and over start years
    ``climy0..climy1`` (inclusive). It keeps every other dim, including ``lead``.
    """
    clim_period = da.sel(start_year=slice(climy0, climy1))
    drift = clim_period.mean(("start_year", "member"), skipna=True)
    return da - drift


def ref_anomaly_at_valid_times(ref: xr.DataArray, hind: xr.DataArray, climy0: int, climy1: int) -> xr.DataArray:
    """
    Match continuous reference seasonal means to each hindcast valid year/month.

    Parameters
    ----------
    ref : reference seasonal means on a ``time`` axis.
    hind : hindcast array with ``start_year`` and ``lead`` dims and
        ``valid_month``/``valid_year`` coords. Those coords may be 1-D (lead) or
        2-D (start_year, lead). In the 1-D case, valid_year is shifted by each
        start year's offset from the first start year.
    climy0, climy1 : inclusive calendar-year range of the reference climatology.

    Returns
    -------
    DataArray with dims (lead, start_year, <ref spatial dims>) holding
    ``ref(valid_year, valid_month) - clim(valid_month)``. Where the reference has
    no matching season (e.g. JJA 2020 when the reference ends in 2019) the
    value is NaN. The coords ``lead_month``, ``season`` and ``valid_month`` are
    copied from ``hind``.
    """
    ref_years = ref["time"].dt.year.values.astype(int)
    ref_months = ref["time"].dt.month.values.astype(int)

    # (year, month) -> first matching time index in ref.
    time_index = {}
    for itime, key in enumerate(zip(ref_years.tolist(), ref_months.tolist())):
        time_index.setdefault(key, itime)

    # NaN placeholder built from ref itself, so found and missing pieces carry
    # identical coordinates and concatenate cleanly.
    missing = xr.full_like(ref.isel(time=0, drop=True), np.nan, dtype=float)

    start_years = [int(sy) for sy in hind["start_year"].values]
    first_start = start_years[0]
    valid_month = hind["valid_month"]
    valid_year = hind["valid_year"]

    def _target_year(isy: int, ilead: int) -> int:
        if "start_year" in valid_year.dims:
            return int(valid_year.isel(start_year=isy, lead=ilead).values)
        # Older/cache format: valid_year varies only by lead.
        return int(valid_year.isel(lead=ilead).values) + (start_years[isy] - first_start)

    pieces = []
    for ilead in range(hind.sizes["lead"]):
        # valid_month is the same for every start year; take the first if 2-D.
        if "start_year" in valid_month.dims:
            vm0 = int(valid_month.isel(start_year=0, lead=ilead).values)
        else:
            vm0 = int(valid_month.isel(lead=ilead).values)

        # Reference climatology for this target month.
        in_clim = (ref_years >= climy0) & (ref_years <= climy1) & (ref_months == vm0)
        clim = ref.isel(time=np.flatnonzero(in_clim)).mean("time", skipna=True)

        vals = []
        for isy in range(len(start_years)):
            itime = time_index.get((_target_year(isy, ilead), vm0))
            vals.append(missing if itime is None else ref.isel(time=itime, drop=True))

        lead_ref = xr.concat(vals, dim="start_year", coords="minimal", compat="override")
        pieces.append(lead_ref - clim)

    out = xr.concat(pieces, dim="lead", coords="minimal", compat="override")
    out = out.assign_coords(start_year=hind.start_year, lead=hind.lead)

    for cname in ["lead_month", "season", "valid_month"]:
        if cname in hind.coords:
            if "start_year" in hind[cname].dims:
                out = out.assign_coords({cname: hind[cname].isel(start_year=0)})
            else:
                out = out.assign_coords({cname: hind[cname]})

    return out

def _corr_pvalue(corr: xr.DataArray, n: xr.DataArray) -> xr.DataArray:
    """
    Two-sided Student-t p-value of Pearson correlation ``corr`` from ``n`` pairs.

    The p-value is NaN where ``n <= 2`` or the t statistic is not finite. It is
    all-NaN when SciPy is unavailable. The computation is vectorized over whole
    arrays rather than calling SciPy once per grid point.
    """
    try:
        from scipy import stats as scipy_stats
    except ImportError:
        # SciPy is optional: without it, report no significance information.
        return xr.full_like(corr, np.nan, dtype=float)

    dof = (n - 2).astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        t = corr * np.sqrt(dof / (1 - corr ** 2))
    pval = xr.apply_ufunc(
        lambda tt, dd: 2.0 * scipy_stats.t.sf(np.abs(tt), np.where(dd > 0, dd, np.nan)),
        t, dof,
        dask="parallelized", output_dtypes=[float],
    )
    return pval.where((n > 2) & np.isfinite(t))


def corr_pval_rmse(fore: xr.DataArray, obs: xr.DataArray, dim="start_year") -> xr.Dataset:
    """
    ACC, p-value and normalized RMSE of forecast vs. observation after detrending along ``dim``.

    ``fore`` is averaged over ``member`` first, if present. Only pairs where
    both the forecast and the observation are finite are used. Both series are
    then linearly detrended along ``dim``.

    Returns
    -------
    Dataset with
      corr : Pearson correlation of the detrended series,
      rmse : RMS of the detrended difference divided by the std of the detrended obs,
      pval : two-sided p-value of ``corr`` (NaN if SciPy is unavailable),
      n    : number of valid pairs along ``dim``.
    """
    f = fore.mean("member", skipna=True) if "member" in fore.dims else fore
    o = obs

    valid_pair = np.isfinite(f) & np.isfinite(o)
    f = f.where(valid_pair)
    o = o.where(valid_pair)
    n = valid_pair.astype(int).sum(dim)

    f_dt = detrend_dim(f, dim)
    o_dt = detrend_dim(o, dim)

    f_an = f_dt - f_dt.mean(dim, skipna=True)
    o_an = o_dt - o_dt.mean(dim, skipna=True)

    corr = (f_an * o_an).mean(dim, skipna=True) / (
        f_an.std(dim, skipna=True) * o_an.std(dim, skipna=True)
    )
    rmse = np.sqrt(((f_dt - o_dt) ** 2).mean(dim, skipna=True)) / o_dt.std(dim, skipna=True)

    pval = _corr_pvalue(corr, n)

    return xr.Dataset({"corr": corr, "rmse": rmse, "pval": pval, "n": n})


def lon_to_360(lon):
    """Shift negative longitudes by +360 so a [-180, 180) input lands in [0, 360)."""
    shifted = lon + 360
    return xr.where(lon < 0, shifted, lon)


def get_lon_lat(da: xr.DataArray):
    """
    Return ``(lon, lat)`` coordinate arrays of ``da``, with negative longitudes shifted by +360.

    ``lon``/``lat`` are preferred; ELM's ``LONGXY``/``LATIXY`` are the fallback.

    Raises
    ------
    KeyError
        If neither naming is present among ``da.coords``.
    """
    coords = da.coords
    lon_name = next((name for name in ("lon", "LONGXY") if name in coords), None)
    lat_name = next((name for name in ("lat", "LATIXY") if name in coords), None)
    if lon_name is None or lat_name is None:
        raise KeyError("Could not find lon/lat coordinates. Expected lon/lat or LONGXY/LATIXY.")
    return lon_to_360(da[lon_name]), da[lat_name]


def region_mask(da: xr.DataArray, region: Iterable[float]) -> xr.DataArray:
    """
    Boolean mask of grid points inside ``region = (lon_w, lon_e, lat_s, lat_n)``.

    Box longitudes may be negative; they are compared in the 0..360 convention
    used by ``get_lon_lat``. A box whose west edge lies east of its east edge
    (after conversion) is treated as wrapping through the 0/360 meridian.
    Bounds are inclusive.
    """
    lon_w, lon_e, lat_s, lat_n = region
    west = lon_w + 360 if lon_w < 0 else lon_w
    east = lon_e + 360 if lon_e < 0 else lon_e
    lon, lat = get_lon_lat(da)
    if west > east:
        in_lon = (lon >= west) | (lon <= east)
    else:
        in_lon = (lon >= west) & (lon <= east)
    in_lat = (lat >= lat_s) & (lat <= lat_n)
    return in_lon & in_lat


def spatial_weight(da: xr.DataArray) -> xr.DataArray:
    """
    Return averaging weights for ``da``: its ``area`` coordinate if attached, else cos(latitude).

    Latitude is looked up only when it is actually needed. Grids that carry
    ``area`` but no lon/lat coordinates therefore no longer raise ``KeyError``.
    """
    if "area" in da.coords:
        return da["area"]
    _, lat = get_lon_lat(da)
    return np.cos(np.deg2rad(lat))


def regional_mean(da: xr.DataArray, region: Iterable[float]) -> xr.DataArray:
    """
    Area-weighted mean of ``da`` over the points inside ``region``.

    All spatial dims (see ``get_spatial_dims``) are reduced. NaN data are
    skipped, and weights outside the region are zeroed, because
    ``DataArray.weighted`` does not accept NaN weights.
    """
    inside = region_mask(da, region).fillna(False)
    weights = spatial_weight(da).where(inside, 0.0).fillna(0.0)
    reduce_dims = get_spatial_dims(da)
    masked = da.where(inside)
    return masked.weighted(weights).mean(reduce_dims, skipna=True)

def select_lead(obj: xr.DataArray | xr.Dataset, target_season: str, target_lead_month: int) -> int:
    """
    Return the first ``lead`` index whose ``season`` and ``lead_month`` match the target.

    Raises
    ------
    ValueError
        If no lead matches; the message lists the available (season, lead_month) pairs.
    """
    season_values = obj["season"].values
    lead_month_values = obj["lead_month"].values
    hits = np.flatnonzero((season_values == target_season) & (lead_month_values == target_lead_month))
    if hits.size:
        return int(hits[0])
    raise ValueError(
        f"No lead found for season={target_season}, lead_month={target_lead_month}. "
        f"Available: {list(zip(obj['season'].values, obj['lead_month'].values))}"
    )


def _valid_years_at_lead(hind: xr.DataArray, ilead: int) -> np.ndarray:
    """
    Return the verification year of every start year at lead index ``ilead``.

    ``hind.valid_year`` may be 2-D (start_year, lead), e.g. after concatenation
    across start years, or 1-D (lead). In the 1-D case the year is shifted by
    each start year's offset from the first start year.
    """
    start_years = hind["start_year"].values.astype(int)
    valid_year = hind["valid_year"]
    if "start_year" in valid_year.dims:
        return valid_year.isel(lead=ilead).values.astype(int)
    return int(valid_year.isel(lead=ilead).values) + (start_years - start_years[0])


def plot(args, hind_by_sm, skill_by_sm, ref_by_sm, hind_anom_by_sm, ref_seas):
    """
    Draw the four-panel Fig. 10-style figure and save it to ``args.output``.

    Panels:
      a. Gridded ACC for the target start month / season / lead on the native
         grid, with the analysis region outlined.
      b. Regional-mean anomalies (ensemble mean, member range, GP reference)
         against verification year.
      c. Regional ACC vs lead month for the target season and every start month,
         with lag-1/lag-2 persistence of the detrended reference as benchmarks.
      d. Regional normalized RMSE vs lead month for the target season.

    Parameters
    ----------
    args : parsed command-line namespace (see ``parse_args``).
    hind_by_sm, hind_anom_by_sm, ref_by_sm, skill_by_sm : dicts keyed by start
        month holding raw hindcasts, hindcast anomalies, reference anomalies and
        the gridded skill Dataset from ``corr_pval_rmse``.
    ref_seas : reference seasonal means on a ``time`` axis.

    Raises
    ------
    ValueError
        If the target start month was not built, or the target season/lead
        combination does not exist.
    """
    sm = args.target_start_month
    if sm not in hind_by_sm:
        raise ValueError(f"target start month {sm} is not in --start-months {args.start_months}")

    hind = hind_by_sm[sm]
    skill = skill_by_sm[sm]
    ref_anom = ref_by_sm[sm]
    hind_anom = hind_anom_by_sm[sm]
    ilead = select_lead(hind, args.target_season, args.target_lead_month)

    reg_hind = regional_mean(hind_anom, args.region)
    reg_ref = regional_mean(ref_anom, args.region)

    fig = plt.figure(figsize=(18, 13))

    # Panel a: skill map/scatter on native grid.
    ax1 = fig.add_subplot(2, 2, 1)

    corr = skill["corr"].isel(lead=ilead)
    lon, lat = get_lon_lat(corr)
    lon_plot = xr.where(lon > 180, lon - 360, lon)

    cc = corr.values

    # Handle both 1D lat/lon regular grid and 2D lat/lon mesh/grid.
    if lon_plot.ndim == 1 and lat.ndim == 1 and cc.ndim == 2:
        xx, yy = np.meshgrid(lon_plot.values, lat.values)
    else:
        xx = lon_plot.values
        yy = lat.values

    xx = np.asarray(xx).ravel()
    yy = np.asarray(yy).ravel()
    cc = np.asarray(cc).ravel()

    finite = np.isfinite(cc) & np.isfinite(xx) & np.isfinite(yy)

    sc = ax1.scatter(
        xx[finite], yy[finite], c=cc[finite],
        s=4, vmin=-1, vmax=1, cmap="RdBu_r"
    )

    lon_w, lon_e, lat_s, lat_n = args.region
    # Draw the box in the -180..180 convention used by the scatter. Edges are
    # normalized like region_mask, so 0..360 input and dateline-crossing boxes
    # are placed correctly.
    west = lon_w + 360 if lon_w < 0 else lon_w
    east = lon_e + 360 if lon_e < 0 else lon_e
    box_width = east - west if east >= west else east - west + 360
    box_x0 = west - 360 if west > 180 else west
    ax1.add_patch(Rectangle((box_x0, lat_s), box_width, lat_n - lat_s,
                            fill=False, edgecolor="blue", linewidth=2))
    ax1.set_xlim(-180, 180)
    ax1.set_ylim(-60, 85)
    ax1.set_xlabel("Longitude")
    ax1.set_ylabel("Latitude")
    ax1.set_title(f"a. {args.var} ACC, start {sm:02d}, {args.target_season} lead {args.target_lead_month}")
    fig.colorbar(sc, ax=ax1, label="ACC")

    # Panel b: regional time series.
    ax2 = fig.add_subplot(2, 2, 2)
    f = reg_hind.isel(lead=ilead).mean("member")
    fmin = reg_hind.isel(lead=ilead).min("member")
    fmax = reg_hind.isel(lead=ilead).max("member")
    o = reg_ref.isel(lead=ilead)

    # Plot by verification year rather than start year (handles 1-D or 2-D valid_year).
    valid_years = _valid_years_at_lead(hind, ilead)
    ax2.plot(valid_years, o, color="k", label="GP cycle-6 reference")
    ax2.plot(valid_years, f, color="b", label=f"S2S2D start {sm:02d} ensemble mean")
    ax2.fill_between(valid_years, fmin.values, fmax.values, color="b", alpha=0.2, label="member range")
    ax2.grid(True)
    ax2.legend()
    ax2.set_title(f"b. {args.var} ({args.region_name} {args.target_season}, lead {args.target_lead_month})")
    ax2.set_ylabel(args.var)
    ax2.set_xlabel("Verification year")

    # Panel c/d: target-season ACC and normalized RMSE vs lead month.
    ax3 = fig.add_subplot(2, 2, 3)
    ax4 = fig.add_subplot(2, 2, 4)
    colors = {5: "g", 11: "k", 2: "r", 8: "b"}
    n_starts = args.start_years[1] - args.start_years[0] + 1
    for sm0 in args.start_months:
        h0 = hind_by_sm[sm0]
        reg_skill = corr_pval_rmse(
            regional_mean(hind_anom_by_sm[sm0], args.region),
            regional_mean(ref_by_sm[sm0], args.region),
        )
        is_target = h0["season"].values == args.target_season
        lm = h0["lead_month"].values[is_target]
        acc = reg_skill["corr"].values[is_target]
        rmse = reg_skill["rmse"].values[is_target]
        n = reg_skill["n"].values[is_target]

        lab = f"start {sm0:02d}"
        ax3.plot(lm, acc, marker="s", linestyle="none", color=colors.get(sm0, None), label=lab)
        ax4.plot(lm, rmse, marker="s", linestyle="none", color=colors.get(sm0, None), label=lab)

        # Add small n labels when samples are reduced by missing reference data.
        for x, y, nn in zip(lm, acc, n):
            if np.isfinite(y) and nn < n_starts:
                ax3.text(x, y, f" n={int(nn)}", fontsize=8)

    # Persistence benchmark from GP reference regional target-season series.
    ref_series = regional_mean(ref_seas, args.region)
    ref_target = ref_series.where(
        ref_series.time.dt.month == SEASON_CENTER_MONTH[args.target_season], drop=True
    )
    ref_target = ref_target.sel(time=slice(str(args.skill_year0), str(args.skill_year1)))
    ref_dt = detrend_dim(ref_target, "time")
    ac1 = xr.corr(ref_dt, ref_dt.shift(time=1), dim="time").values.item()
    ac2 = xr.corr(ref_dt, ref_dt.shift(time=2), dim="time").values.item()
    ax3.axhline(ac1, color="0.4", linestyle="--", label="reference persistence lag-1")
    ax3.axhline(ac2, color="0.6", linestyle="--", label="reference persistence lag-2")

    ax3.set_title(f"c. {args.var} ({args.region_name} {args.target_season}) ACC")
    ax3.set_xlabel("Lead month")
    ax3.set_ylim(-0.3, 1.0)
    ax3.grid(True)
    ax3.legend()

    ax4.set_title(f"d. {args.var} ({args.region_name} {args.target_season}) normalized RMSE")
    ax4.set_xlabel("Lead month")
    ax4.set_ylim(0.0, 2.0)
    ax4.grid(True)
    ax4.legend()

    fig.tight_layout()
    fig.savefig(args.output, dpi=args.dpi)
    # Release the figure; pyplot otherwise keeps it alive for the session.
    plt.close(fig)
    print(f"Wrote {args.output}")


def main():
    """
    Command-line entry point.

    Builds (or reads cached) seasonal hindcasts and the GP reference. For each
    start month it removes hindcast drift and forms reference anomalies at the
    verification times. Both are masked to verification years
    ``--skill-year0..--skill-year1``, skill is scored, and the figure is drawn.
    """
    args = parse_args()
    args.ref_var = args.ref_var or args.var

    print("Reading/building S2S2D hindcasts...")
    hind_by_sm = build_or_read_hindcasts(args)

    print("Reading GP spinup cycle-6 reference...")
    ref_seas = read_gp_reference(args)

    hind_anom_by_sm, ref_anom_by_sm, skill_by_sm = {}, {}, {}

    for sm, hind in hind_by_sm.items():
        print(f"Computing anomalies and skill for start month {sm:02d}")

        hind_anom = remove_hindcast_drift(hind, args.climy0, args.climy1)
        ref_anom = ref_anomaly_at_valid_times(ref_seas, hind, args.climy0, args.climy1)

        if "valid_year" in hind.coords:
            verif_years = hind.valid_year
            if "start_year" not in verif_years.dims:
                # 1-D valid_year(lead): broadcast to (start_year, lead) by adding
                # each start year's offset from the first start year.
                starts = hind.start_year.values
                offsets = np.array([int(sy) for sy in starts]) - int(starts[0])
                per_lead = np.array([int(verif_years.isel(lead=il).values)
                                     for il in range(hind.sizes["lead"])])
                verif_years = xr.DataArray(
                    offsets[:, None] + per_lead[None, :],
                    dims=("start_year", "lead"),
                    coords={"start_year": hind.start_year, "lead": hind.lead},
                )

            in_window = (verif_years >= args.skill_year0) & (verif_years <= args.skill_year1)
            hind_anom = hind_anom.where(in_window)
            ref_anom = ref_anom.where(in_window)

        hind_anom_by_sm[sm] = hind_anom
        ref_anom_by_sm[sm] = ref_anom
        skill_by_sm[sm] = corr_pval_rmse(hind_anom, ref_anom)

    plot(args, hind_by_sm, skill_by_sm, ref_anom_by_sm, hind_anom_by_sm, ref_seas)
   
# Run the analysis only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
