#!/usr/bin/env python3
"""
pull.py — incrementally clone NERSC IRIS usage_* into Parquet partitions.

Layout:
  $USAGE_STORE/jobs/year=YYYY/month=MM/day=DD/part.parquet
  $USAGE_STORE/watermark.json   {"last_pulled_day": "YYYY-MM-DD", ...}

Usage:
  pull.py refresh                 # re-pull last N days (default 14)
  pull.py refresh --days 30
  pull.py day 2026-03-15
  pull.py range 2026-01-01 2026-01-31
"""

import argparse
import json
import os
import sys
import tempfile
from datetime import date, datetime, timedelta, timezone
from pathlib import Path
from urllib.parse import quote

import requests
import pyarrow as pa
import pyarrow.parquet as pq

# -------- config --------
# Root of the local store (partitions under jobs/, plus watermark.json);
# override with $USAGE_STORE.
STORE      = Path(os.environ.get("USAGE_STORE", Path.home() / "nersc-usage"))
# Base URL that the console-proxy endpoints below are built from.
KB_BASE    = "https://iris.nersc.gov/kb"
# Sent as the kbn-version header; presumably has to match the server's
# version — TODO confirm.
KB_VERSION = "9.1.2"
# User name placed in the session's auth tuple; override with $ELASTIC_USER.
USER       = os.environ.get("ELASTIC_USER", "admin")
# File whose stripped contents make_session() uses as the password;
# override with $ELASTIC_PASSWORD_FILE.
PW_FILE    = Path(os.environ.get("ELASTIC_PASSWORD_FILE", Path.home() / "kbpw"))
# Sent as "time_zone" with every day query.
TZ         = "US/Pacific"
# Sent as "fetch_size" with every day query.
FETCH_SIZE = 5000
# Default for `refresh --days`.
DEFAULT_REFRESH_DAYS = 14

# The target path is percent-encoded in full (safe='') because it travels
# inside the proxy's own `path` query parameter.
PROXY_URL = f"{KB_BASE}/api/console/proxy?path={quote('/_sql?format=json', safe='')}&method=POST"
CLOSE_URL = f"{KB_BASE}/api/console/proxy?path={quote('/_sql/close', safe='')}&method=POST"
# Headers that make_session() installs on every session.
HEADERS = {
    "Content-Type": "application/json",
    "kbn-xsrf": "true",
    "kbn-version": KB_VERSION,
    "x-elastic-internal-origin": "Kibana",
    "elastic-api-version": "1",
}

def index_for(d: date) -> str:
    """Return the index name for *d*: ``usage_`` followed by its year."""
    year = d.year
    return "usage_{}".format(year)

# Columns we cannot SELECT because they're arrays in the mapping
# (ES SQL refuses: "Arrays (returned by [X]) are not supported").
EXCLUDE_COLS = {"nodes", "nodes.keyword"}

# Per-process memo for selectable_columns(): index name -> the column names
# it returned for that index.
_cols_cache: dict[str, list[str]] = {}

def selectable_columns(s, idx: str) -> list[str]:
    """Return the column names of index *idx* that may be SELECTed.

    Issues ``DESCRIBE`` for the index through session *s* the first time an
    index is seen and memoizes the answer in ``_cols_cache``; later calls for
    the same index return the cached list without a request. Names present
    in ``EXCLUDE_COLS`` are left out.
    """
    try:
        return _cols_cache[idx]
    except KeyError:
        pass
    described = post(s, PROXY_URL, {"query": f'DESCRIBE "{idx}"'})
    names = []
    for row in described["rows"]:
        # The first cell of each DESCRIBE row is the column name.
        name = row[0]
        if name in EXCLUDE_COLS:
            continue
        names.append(name)
    _cols_cache[idx] = names
    return names

# -------- HTTP --------
def make_session() -> requests.Session:
    """Build a ``requests`` session carrying the credentials and ``HEADERS``.

    The password is the whitespace-stripped content of ``PW_FILE``; reading
    that file may raise ``OSError``.
    """
    session = requests.Session()
    password = PW_FILE.read_text().strip()
    session.auth = (USER, password)
    session.headers.update(HEADERS)
    return session

def post(s, url, body) -> dict:
    """POST *body* (JSON-encoded) to *url* via session *s*; return the decoded reply.

    On a non-OK HTTP status the status code and the first 500 characters of
    the response text are written to stderr before ``raise_for_status()``
    raises. A reply that decodes to a dict with an ``"error"`` key and
    neither ``"columns"`` nor ``"rows"`` raises ``RuntimeError``.
    """
    payload = json.dumps(body)
    reply = s.post(url, data=payload, timeout=120)
    if not reply.ok:
        sys.stderr.write(f"HTTP {reply.status_code}: {reply.text[:500]}\n")
        reply.raise_for_status()
    parsed = reply.json()
    # ES SQL returns 200 with an error envelope on logical/SQL errors.
    is_error_envelope = (
        isinstance(parsed, dict)
        and "error" in parsed
        and not ("columns" in parsed or "rows" in parsed)
    )
    if is_error_envelope:
        raise RuntimeError(f"ES SQL error: {json.dumps(parsed['error'])[:500]}")
    return parsed

# -------- one-day pull --------
def build_day_sql(idx: str, cols: list[str], start: date, end: date) -> str:
    """Build the SELECT that fetches *cols* from *idx* for ``start <= Start < end``.

    Column and index names are wrapped in double quotes and interpolated
    verbatim (no escaping); both bounds are midnight of the given dates,
    written as ``CAST('<date>T00:00:00' AS DATETIME)``.
    """
    projection = ", ".join('"{}"'.format(name) for name in cols)
    lower = start.isoformat()
    upper = end.isoformat()
    pieces = [
        f'SELECT {projection} FROM "{idx}" ',
        f"WHERE \"Start\" >= CAST('{lower}T00:00:00' AS DATETIME) ",
        f"  AND \"Start\" <  CAST('{upper}T00:00:00' AS DATETIME)",
    ]
    return "".join(pieces)

def pull_day(s, d: date) -> pa.Table:
    """All rows whose Start is in [d, d+1) Pacific.

    Runs the day query against the index chosen by ``index_for(d)`` and
    follows the returned ``cursor`` page by page until a reply carries none.
    Column names come from the first page's ``columns``; every page's
    ``rows`` are accumulated in memory and turned into one table.

    The most recent cursor is closed on a best-effort basis whether
    pagination finishes or a page fetch raises, so an interrupted pull does
    not leave the cursor open on the server. A failed close is reported on
    stderr and never masks the original error.

    Raises whatever ``post`` raises for a failed page fetch.
    """
    idx = index_for(d)
    cols_to_pull = selectable_columns(s, idx)
    body = {
        "query": build_day_sql(idx, cols_to_pull, d, d + timedelta(days=1)),
        "time_zone": TZ,
        "fetch_size": FETCH_SIZE,
    }
    cols, rows = None, []
    last_cursor = None
    try:
        while True:
            resp = post(s, PROXY_URL, body)
            if cols is None:
                # Only the first page is relied on for column metadata.
                cols = [c["name"] for c in resp["columns"]]
            rows.extend(resp.get("rows", []))
            cursor = resp.get("cursor")
            if not cursor:
                break
            last_cursor = cursor
            # Follow-up pages are requested with the cursor alone.
            body = {"cursor": cursor}
    finally:
        if last_cursor:
            try:
                post(s, CLOSE_URL, {"cursor": last_cursor})
            except Exception as exc:
                # Best-effort cleanup: report it, but do not fail the pull.
                sys.stderr.write(f"warning: could not close SQL cursor for {d}: {exc}\n")
    # Transpose the row lists into one list per column name.
    columns = {c: [r[i] for r in rows] for i, c in enumerate(cols or [])}
    return pa.table(columns)

# -------- atomic partition write --------
def partition_dir(d: date) -> Path:
    """Return ``STORE/jobs/year=YYYY/month=MM/day=DD`` for day *d* (zero-padded)."""
    year_part = "year={:04d}".format(d.year)
    month_part = "month={:02d}".format(d.month)
    day_part = "day={:02d}".format(d.day)
    return STORE.joinpath("jobs", year_part, month_part, day_part)

def write_day(d: date, table: pa.Table) -> int:
    """Write *table* as the Parquet file for day *d*, replacing any previous one.

    The data goes to a uniquely named temporary file inside the partition
    directory first and is then moved over ``part.parquet`` with
    ``os.replace``, so the final path is swapped in a single step. If the
    write or the rename fails, the temporary file is removed before the
    error propagates.

    Returns the number of rows in *table*.
    """
    out_dir = partition_dir(d)
    out_dir.mkdir(parents=True, exist_ok=True)
    final = out_dir / "part.parquet"
    # Only reserve a unique name here. The handle is closed before pyarrow
    # reopens the path and before the rename, since renaming a still-open
    # file fails on Windows.
    with tempfile.NamedTemporaryFile(dir=out_dir, prefix=".part.", suffix=".parquet", delete=False) as tmp:
        tmp_name = tmp.name
    try:
        pq.write_table(table, tmp_name, compression="zstd")
        os.replace(tmp_name, final)
    except BaseException:
        # Do not leave orphaned .part.*.parquet files behind on failure.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
    return table.num_rows

# -------- watermark --------
# Watermark file at the store root; cmd_range() rewrites it after each day
# it pulls.
WM_FILE = STORE / "watermark.json"

def read_wm() -> dict:
    """Return the parsed contents of ``WM_FILE``, or ``{}`` if it does not exist.

    A file that exists but is not valid JSON raises ``json.JSONDecodeError``.
    """
    if not WM_FILE.exists():
        return {}
    raw = WM_FILE.read_text()
    return json.loads(raw)

def write_wm(wm: dict) -> None:
    """Persist *wm* to ``WM_FILE`` as indented JSON, replacing it in one step.

    The JSON is serialized up front and written to a sibling ``.tmp`` file,
    which is then moved over ``WM_FILE`` with ``os.replace``. An interrupted
    run therefore cannot leave a half-written watermark that a later
    ``read_wm`` would fail to parse. Values ``json`` cannot encode natively
    are stringified (``default=str``).
    """
    STORE.mkdir(parents=True, exist_ok=True)
    # Serialize before touching the filesystem so an encoding error leaves
    # no file behind.
    payload = json.dumps(wm, indent=2, default=str)
    tmp_path = WM_FILE.with_name(WM_FILE.name + ".tmp")
    tmp_path.write_text(payload)
    os.replace(tmp_path, WM_FILE)

# -------- orchestration --------
def daterange(a: date, b: date):
    """Yield every day from *a* through *b*, both inclusive, in ascending order.

    Yields nothing when *a* is later than *b*.
    """
    one_day = timedelta(days=1)
    day = a
    while not day > b:
        yield day
        day = day + one_day

def cmd_day(args):
    """Handle ``day``: pull the single ISO date ``args.date`` and write its partition.

    Prints the row count and the written file path. The watermark is not
    touched by this command.
    """
    day = date.fromisoformat(args.date)
    with make_session() as session:
        fetched = pull_day(session, day)
        row_count = write_day(day, fetched)
    target = partition_dir(day) / "part.parquet"
    print(f"{day}: {row_count} rows -> {target}")

def cmd_range(args):
    """Handle ``range``: pull every day from ``args.start`` through ``args.end``.

    Both bounds are ISO dates and inclusive. Each day is pulled, written to
    its partition and reported on stdout; the watermark file is rewritten
    after every day so progress survives an interrupted run. The closing
    summary reports the number of days actually processed, so a reversed
    range reports 0 days (and a warning on stderr) rather than a negative
    count.

    NOTE(review): ``last_pulled_day`` is set to the day just processed, not
    the maximum ever pulled, so backfilling an older range moves it
    backwards — confirm that is intended.
    """
    a, b = date.fromisoformat(args.start), date.fromisoformat(args.end)
    if a > b:
        sys.stderr.write(f"warning: start {a} is after end {b}; nothing to pull\n")
    wm = read_wm()
    total = 0
    days_done = 0
    with make_session() as s:
        for d in daterange(a, b):
            t = pull_day(s, d)
            n = write_day(d, t)
            print(f"{d}: {n} rows")
            total += n
            days_done += 1
            wm["last_pulled_day"] = d.isoformat()
            wm["last_pulled_at"]  = datetime.now(timezone.utc).isoformat()
            write_wm(wm)
    print(f"total: {total} rows across {days_done} day(s)")

def cmd_refresh(args):
    """Handle ``refresh``: re-pull the last ``args.days`` days, ending today.

    Computes the window, stores it on *args* as ISO ``start`` / ``end``
    strings and delegates to ``cmd_range``.

    Raises ``SystemExit`` when ``args.days`` is below 1, because such a
    window would start after today and silently pull nothing.

    NOTE(review): ``date.today()`` uses the host's local time zone while
    day queries are sent with ``TZ`` — confirm the host clock is Pacific.
    """
    if args.days < 1:
        raise SystemExit(f"refresh: --days must be at least 1 (got {args.days})")
    today = date.today()
    # A window of N days ending today starts N-1 days back.
    start = today - timedelta(days=args.days - 1)
    args.start, args.end = start.isoformat(), today.isoformat()
    cmd_range(args)

def main():
    """Parse the command line and dispatch to the chosen sub-command handler.

    Sub-commands: ``day DATE``, ``range START END`` and
    ``refresh [--days N]``. Each stores its handler as ``fn``, which is then
    called with the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    day_cmd = commands.add_parser("day")
    day_cmd.add_argument("date")
    day_cmd.set_defaults(fn=cmd_day)

    range_cmd = commands.add_parser("range")
    range_cmd.add_argument("start")
    range_cmd.add_argument("end")
    range_cmd.set_defaults(fn=cmd_range)

    refresh_cmd = commands.add_parser("refresh")
    refresh_cmd.add_argument("--days", type=int, default=DEFAULT_REFRESH_DAYS)
    refresh_cmd.set_defaults(fn=cmd_refresh)

    parsed = parser.parse_args()
    parsed.fn(parsed)

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
