#!/usr/bin/env python3
"""Stacked daily Perlmutter node-hour utilization by allocation_type.

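Reproduces the reference NERSC plot, stacking three allocation types (bottom
to top): ALCC Awards, DOE Mission Science, Director Reserve Allocations.
Only jobs with allocation_pool = 'DOE' are counted, and every other
allocation_type (including "Overhead" and NULL) is left out of the bars --
one reason, alongside idle nodes, that the stacks stay below 100%.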

Run:
    module load python/3.13-26.1.0
    USAGE_STORE=$SCRATCH/nersc-usage python3 plot_util_stacked.py \
        --start 2026-01-21 --end 2026-05-14 \
        --out   $SCRATCH/nersc-usage/util-2026-stacked.png
"""
import argparse
import os
from collections import defaultdict
from datetime import date

import duckdb
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# Root of the usage parquet store; the USAGE_STORE environment variable wins
# over the built-in default path.
STORE = os.getenv("USAGE_STORE", "/pscratch/sd/w/wbhimji/nersc-usage")

# Bar colour for each plotted allocation type, chosen by eye to match the
# reference plot.
COLORS = {
    "DOE Mission Science": "#2e9d3e",           # green
    "ALCC Awards": "#3498db",                   # blue
    "Director Reserve Allocations": "#f39c12",  # orange
}

# Drawing order of the stacked series, first entry at the bottom of each bar.
# Any allocation type not listed here is never drawn.
STACK_ORDER = [
    "ALCC Awards",                   # bottom (blue)
    "DOE Mission Science",           # middle (green)
    "Director Reserve Allocations",  # top (orange)
]


# Default denominator node counts used when --capacity is not given
# (same values as advertised in the --capacity help text).
_DEFAULT_CAPACITY = {"gpu": 1792, "cpu": 3072}


def _parse_args():
    """Parse and validate the command line; return the argparse namespace.

    ``args.start`` / ``args.end`` are returned as ``datetime.date`` objects.
    Exits with a usage error when a date is malformed, when ``--end`` precedes
    ``--start``, or when ``--capacity`` is negative.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    # Parse dates at the argparse layer so a malformed value produces a usage
    # error rather than a traceback.
    ap.add_argument("--start", required=True, type=date.fromisoformat,
                    help="First day to plot (YYYY-MM-DD, US/Pacific day).")
    ap.add_argument("--end",   required=True, type=date.fromisoformat,
                    help="Last day to plot, inclusive (YYYY-MM-DD, US/Pacific day).")
    ap.add_argument("--resource", choices=["gpu", "cpu"], default="gpu")
    ap.add_argument("--capacity", type=int, default=None,
                    help="Fixed node count for the denominator (defaults: gpu=1792, cpu=3072). "
                         "Set to 0 to use per-day MAX(ZoneSize) instead.")
    ap.add_argument("--out",   required=True)
    args = ap.parse_args()
    # The recursive `days` CTE always emits its anchor row, so an inverted
    # range would otherwise silently plot just the --start day.
    if args.end < args.start:
        ap.error(f"--end ({args.end}) is before --start ({args.start})")
    if args.capacity is not None and args.capacity < 0:
        ap.error("--capacity must be >= 0 (0 selects per-day MAX(ZoneSize))")
    return args


def _build_sql(start, end, zone, host, capacity_override):
    """Return the DuckDB query computing daily utilization by allocation type.

    The query yields one row per (day, atype) with columns
    ``day, atype, rh, capacity, day_hours, pct``, where:

    * ``day`` runs from ``start`` to ``end`` inclusive;
    * ``atype`` is the job's allocation_type, or ``'Other'`` for NULL types
      and for days with no matching jobs;
    * ``rh`` is the RawHours prorated into that day's Pacific-time window;
    * ``pct`` is ``100 * rh / (capacity * day_hours)``. It is NULL on days
      with no matching jobs, because ``capacity`` comes from a per-day
      aggregate over jobs.

    ``capacity_override > 0`` fixes the denominator node count. ``0`` uses the
    largest ZoneSize seen among that day's jobs.

    Every interpolated value is trusted. ``start``/``end`` are ``date``
    objects, ``zone``/``host`` come from a fixed argparse ``choices`` list, and
    ``capacity_override`` is an ``int``.
    """
    return f"""
    WITH RECURSIVE
    -- One row per calendar day, start..end inclusive.
    days(d) AS (
      SELECT DATE '{start}'
      UNION ALL SELECT d + INTERVAL 1 DAY FROM days WHERE d < DATE '{end}'
    ),
    -- Pacific-local midnight-to-midnight window for each day.
    bounds AS (
      SELECT d AS day,
             CAST(d AS TIMESTAMP)                AT TIME ZONE 'US/Pacific' AS win_start,
             CAST(d + INTERVAL 1 DAY AS TIMESTAMP) AT TIME ZONE 'US/Pacific' AS win_end
      FROM days
    ),
    -- Jobs whose [Start, Start + ElapsedSecs) run intersects a day window;
    -- the stored Start is interpreted as UTC.
    overlapped AS (
      SELECT
        b.day,
        j."QOS"             AS qos,
        j."ReservationId"   AS rid,
        j."JobID"           AS jid,
        j."RawHours"        AS rawhours,
        j."ElapsedSecs"     AS elapsed,
        j."ZoneSize"        AS zonesize,
        j."allocation_type" AS atype,
        CAST(j."Start" AS TIMESTAMP) AT TIME ZONE 'UTC'                                     AS start_ts,
        CAST(j."Start" AS TIMESTAMP) AT TIME ZONE 'UTC' + INTERVAL (j."ElapsedSecs") SECOND AS end_ts,
        b.win_start, b.win_end
      FROM jobs j JOIN bounds b
        ON  CAST(j."Start" AS TIMESTAMP) AT TIME ZONE 'UTC' < b.win_end
        AND CAST(j."Start" AS TIMESTAMP) AT TIME ZONE 'UTC' + INTERVAL (j."ElapsedSecs") SECOND > b.win_start
      WHERE j."ZoneName" = '{zone}' AND j."hostname" = '{host}'
        AND j."ElapsedSecs" > 0
        AND j."allocation_pool" = 'DOE'
    ),
    -- Seconds of each job's run that fall inside the day window.
    prorated AS (
      SELECT day, qos, rid, jid, zonesize, rawhours, elapsed, atype,
             date_diff('second', GREATEST(start_ts, win_start), LEAST(end_ts, win_end)) AS overlap_secs,
             win_start, win_end
      FROM overlapped
    ),
    -- Scale RawHours by the fraction of the run inside the window.
    prorated_hours AS (
      SELECT day, qos, rid, jid, zonesize, atype,
             rawhours * overlap_secs / elapsed AS prorated_rh,
             win_start, win_end
      FROM prorated WHERE overlap_secs > 0
    ),
    -- Non-reservation: sum by allocation_type
    nonres AS (
      SELECT day, atype, SUM(prorated_rh) AS rh
      FROM prorated_hours WHERE qos <> 'RESERVE'
      GROUP BY day, atype
    ),
    -- Reservation: dedup per (rid|jid), keep max, then sum by atype
    res_groups AS (
      SELECT day, atype, COALESCE(NULLIF(rid, ''), jid) AS gk, MAX(prorated_rh) AS prorated_rh
      FROM prorated_hours WHERE qos = 'RESERVE'
      GROUP BY day, atype, gk
    ),
    res AS (
      SELECT day, atype, SUM(prorated_rh) AS rh FROM res_groups GROUP BY day, atype
    ),
    combined AS (
      SELECT day, atype, rh FROM nonres
      UNION ALL
      SELECT day, atype, rh FROM res
    ),
    by_type AS (
      SELECT day, atype, SUM(rh) AS rh FROM combined GROUP BY day, atype
    ),
    cap AS (
      SELECT day,
             CASE WHEN {capacity_override} > 0
                  THEN CAST({capacity_override} AS BIGINT)
                  ELSE MAX(zonesize) END AS capacity
      FROM prorated_hours GROUP BY day
    ),
    -- Window length in hours (not always 24: the window follows Pacific time).
    win AS (
      SELECT day, (epoch(win_end) - epoch(win_start)) / 3600.0 AS hours FROM bounds
    )
    SELECT b.day, COALESCE(t.atype, 'Other') AS atype,
           COALESCE(t.rh, 0)                                          AS rh,
           cap.capacity                                                AS capacity,
           win.hours                                                   AS day_hours,
           100.0 * COALESCE(t.rh, 0) / (cap.capacity * win.hours)      AS pct
    FROM bounds b
    LEFT JOIN by_type t USING (day)
    LEFT JOIN cap       USING (day)
    LEFT JOIN win       USING (day)
    ORDER BY b.day, atype
    """


def _query_rows(sql):
    """Run ``sql`` against the parquet usage store and return all result rows.

    A ``jobs`` view over ``{STORE}/jobs/**/*.parquet`` (hive partitioned) is
    registered, and the session TimeZone is set to US/Pacific, before ``sql``
    runs. The connection is closed on exit, including on error.
    """
    # Double any single quote so an unusual store path can't break the literal.
    glob = f"{STORE}/jobs/**/*.parquet".replace("'", "''")
    with duckdb.connect() as con:
        con.execute(f"""
            CREATE OR REPLACE VIEW jobs AS
            SELECT * FROM read_parquet('{glob}', hive_partitioning=true)
        """)
        con.execute("SET TimeZone = 'US/Pacific'")
        return con.execute(sql).fetchall()


def _pct_by_day(rows):
    """Reshape query rows into ``{day: {atype: pct}}``, mapping NULL pct to 0.0."""
    per_day = defaultdict(dict)
    for day, atype, _rh, _cap, _hrs, pct in rows:
        per_day[day][atype] = float(pct) if pct is not None else 0.0
    return per_day


def _plot(per_day, days, resource, start, end, out):
    """Draw STACK_ORDER as stacked daily bars (first entry at the bottom) and save to ``out``."""
    fig, ax = plt.subplots(figsize=(16, 6))
    bottom = [0.0] * len(days)
    for atype in STACK_ORDER:
        vals = [per_day[d].get(atype, 0.0) for d in days]
        # On a date axis the bar width is measured in days.
        ax.bar(days, vals, bottom=bottom, color=COLORS[atype], label=atype, width=0.85)
        # The next series is stacked on top of everything drawn so far.
        bottom = [b + v for b, v in zip(bottom, vals)]

    # Fixed 0-100% axis: totals above 100% (e.g. --capacity below the real
    # node count) would be clipped.
    ax.set_ylim(0, 100)
    ax.set_ylabel("Node Hour Utilization (%)")
    ax.set_xlabel("Date")
    ax.set_title(f"Daily Perlmutter {resource.upper()} Node-Hour Utilization by Allocation Type, "
                 f"{start} → {end}")
    ax.grid(axis="y", alpha=0.3)
    ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %d"))
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=3, frameon=True)
    fig.autofmt_xdate()
    fig.tight_layout()
    fig.savefig(out, dpi=120)
    plt.close(fig)


def _print_summary(per_day, days, out):
    """Print the output path and the mean daily percentages over ``days``.

    ``days`` is never empty: the query emits at least one row per day in the
    validated (start <= end) range.
    """
    totals = [sum(per_day[d].get(a, 0.0) for a in STACK_ORDER) for d in days]
    print(f"wrote {out}")
    print(f"days: {len(days)}  mean stacked total: {sum(totals)/len(totals):.2f}%")
    for atype in STACK_ORDER:
        vals = [per_day[d].get(atype, 0.0) for d in days]
        print(f"  mean {atype:32s}: {sum(vals)/len(vals):.2f}%")


def main():
    """Plot stacked daily Perlmutter node-hour utilization by allocation type.

    Steps:

    1. Read the CLI arguments.
    2. Query the parquet usage store through DuckDB for per-day, per-type
       prorated hours as a percentage of capacity.
    3. Write the stacked bar chart to ``--out``.
    4. Print mean daily percentages.
    """
    args = _parse_args()
    start, end = args.start, args.end
    # ZoneName equals the resource name; hostname is "perlmutter <resource>".
    zone = args.resource
    host = f"perlmutter {zone}"
    if args.capacity is None:
        capacity_override = _DEFAULT_CAPACITY[args.resource]
    else:
        capacity_override = args.capacity

    rows = _query_rows(_build_sql(start, end, zone, host, capacity_override))
    per_day = _pct_by_day(rows)
    days = sorted(per_day)

    _plot(per_day, days, args.resource, start, end, args.out)
    _print_summary(per_day, days, args.out)


if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0 exactly as falling off the end of the module would.
    raise SystemExit(main())
