#!/usr/bin/env python3
"""
Concordance of Haritica VA's RNA-seq variant calls against the GIAB HG001
GRCh38 v4.2.1 DNA truth set, on chr20 + chr21.

For each aligner run in runs_manifest.json:
  - locate `called_variants.vcf.gz` (raw bcftools calls, already chr20/21 only),
  - apply the app's quality floor (QUAL>=20, INFO/DP>=10) to get the delivered set,
  - normalize (split multiallelics, left-align) with `bcftools norm -m -any -f ref`,
  - restrict to (GIAB high-confidence BED  ∩  chr20/21) SNVs,
  - `bcftools isec` against the identically-restricted GIAB truth,
  - PRECISION = TP / (TP + FP)   [headline; recall is intrinsically low for RNA-seq],
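  - RNA-editing signature: A>G / T>C enrichment among false positives, plus precision
    recomputed with those A>G / T>C false positives excluded,
  - Ti/Tv via `bcftools stats` and a 12-class substitution spectrum, both computed on the
    normalized delivered set (before the BED restriction), independently of VA; VA's
    self-reported Ti/Tv is read from variant_summary.json for cross-checking.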

Chromosome-name harmonization: VA/Ensembl use bare `20`/`21`; GIAB uses UCSC
`chr20`/`chr21`. We rename GIAB -> Ensembl so the whole toolchain agrees.

Uses ONLY the bundled, GSL-free bcftools (no GPL, no VEP). Figures via matplotlib.

Usage: reference_concordance.py [root]
"""
import sys, os, json, subprocess, csv, re, pathlib, collections

# Data root: first CLI argument, else the default positive-control directory.
ROOT = sys.argv[1] if len(sys.argv) > 1 else "/path/to/haritica-data/va_poscontrol"

# Bundled toolchain.
BIN  = "/path/to/haritica/binaries/macos"
BCF  = BIN + "/bcftools"

# Inputs resolved under ROOT: reference FASTA for `bcftools norm` and the GIAB truth set.
REF = ROOT + "/ref/GRCh38.fa"
GIAB_VCF = ROOT + "/giab/HG001_GRCh38_1_22_v4.2.1_benchmark.vcf.gz"
GIAB_BED = ROOT + "/giab/HG001_GRCh38_1_22_v4.2.1_benchmark.bed"

# Output locations; both directories are created when the module is loaded.
WORK = pathlib.Path(ROOT + "/concordance")
WORK.mkdir(parents=True, exist_ok=True)
FIG  = pathlib.Path("/path/to/haritica/docs/validation/variant-annotation-positive-control/figures")
FIG.mkdir(parents=True, exist_ok=True)

# Transitions as (ref, alt) base pairs; any other SNV class is a transversion.
TS = {(ref, alt) for ref, alt in ("AG", "GA", "CT", "TC")}


def sh(cmd, **kw):
    """Run the argv list *cmd*, capturing stdout/stderr as text.

    Extra keyword arguments are forwarded to subprocess.run. Returns the
    CompletedProcess; a non-zero exit raises subprocess.CalledProcessError.
    """
    completed = subprocess.run(
        cmd,
        **kw,
        capture_output=True,
        text=True,
        check=True,
    )
    return completed


def bcf(args, out=None):
    """Run the bundled bcftools with the argument list *args*.

    If *out* is given, bcftools' stdout is streamed (binary) into that file and
    None is returned. Otherwise output is captured as text via sh() and the
    CompletedProcess is returned.

    Raises:
        RuntimeError: bcftools exited non-zero. The message names the subcommand
            (first two args) and carries up to 400 characters of bcftools' stderr,
            in both modes, so failures are diagnosable without re-running.
    """
    cmd = [BCF] + list(args)

    def _what():
        # computed lazily: only needed when reporting a failure
        return " ".join(map(str, args[:2]))

    if out:
        with open(out, "wb") as fh:
            p = subprocess.run(cmd, stdout=fh, stderr=subprocess.PIPE)
        if p.returncode:
            err = p.stderr.decode(errors="replace")
            raise RuntimeError(f"bcftools {_what()} failed: {err[:400]}")
        return None
    try:
        return sh(cmd)
    except subprocess.CalledProcessError as e:
        # CalledProcessError's message omits stderr, which is where bcftools explains itself
        raise RuntimeError(f"bcftools {_what()} failed: {(e.stderr or '')[:400]}") from e


def count_vcf(vcf):
    """Return how many data records *vcf* holds (header lines excluded).

    Counts newline-terminated lines of `bcftools view -H` output.
    """
    records_text = sh([BCF, "view", "-H", vcf]).stdout
    return records_text.count("\n")


def prepare_giab():
    """Prepare the GIAB HG001 truth set that every run is compared against.

    Renames GIAB chr20/chr21 -> 20/21, subsets to those contigs, normalizes against
    REF (split multiallelics, left-align), keeps SNVs only and restricts to the GIAB
    high-confidence BED (itself rewritten with Ensembl contig names).

    Returns:
        (giab_eval_vcf, ensembl_bed) as strings.

    Outputs are cached in WORK. The cache is reused only when the eval VCF, its CSI
    index and the renamed BED all exist. The index is built only after the eval VCF
    has been written successfully, so its presence marks a completed build; checking
    the VCF alone would accept a file truncated by an interrupted run.
    """
    eval_vcf = WORK / "giab.eval.vcf.gz"
    eval_idx = WORK / "giab.eval.vcf.gz.csi"   # name `bcftools index -f` produces (CSI)
    bed_ens  = WORK / "giab_highconf_20_21.ensembl.bed"
    if eval_vcf.exists() and eval_idx.exists() and bed_ens.exists():
        return str(eval_vcf), str(bed_ens)
    if eval_idx.exists():
        # stale index from an earlier build: remove it so an interrupted rebuild
        # cannot be mistaken for a finished one on the next invocation
        eval_idx.unlink()

    ucsc_to_ens = {"chr20": "20", "chr21": "21"}
    # contig-name map consumed by `bcftools annotate --rename-chrs`
    rename = WORK / "giab2ensembl.txt"
    rename.write_text("".join(f"{ucsc}\t{ens}\n" for ucsc, ens in ucsc_to_ens.items()))
    # BED -> Ensembl names, chr20/chr21 only; only the first (contig) column is rewritten
    with open(GIAB_BED) as fh, open(bed_ens, "w") as out:
        for ln in fh:
            chrom, sep, rest = ln.partition("\t")
            if chrom in ucsc_to_ens:
                out.write(ucsc_to_ens[chrom] + sep + rest)

    # refresh the GIAB index in place (downloaded .tbi is older than the data file)
    bcf(["index", "-f", "-t", GIAB_VCF])

    # VCF: subset chr20/21 -> rename -> norm -> SNVs -> restrict to BED
    tmp1 = WORK / "giab.2021.vcf.gz"
    bcf(["view", "-r", "chr20,chr21", "-Oz", "-o", str(tmp1), GIAB_VCF])
    bcf(["index", "-f", str(tmp1)])
    tmp2 = WORK / "giab.2021.ens.vcf.gz"
    bcf(["annotate", "--rename-chrs", str(rename), "-Oz", "-o", str(tmp2), str(tmp1)])
    bcf(["index", "-f", str(tmp2)])
    tmp3 = WORK / "giab.norm.vcf.gz"
    bcf(["norm", "-m", "-any", "-f", REF, "-Oz", "-o", str(tmp3), str(tmp2)])
    bcf(["index", "-f", str(tmp3)])
    bcf(["view", "-R", str(bed_ens), "-v", "snps", "-Oz", "-o", str(eval_vcf), str(tmp3)])
    # written last: doubles as the "build completed" marker checked above
    bcf(["index", "-f", str(eval_vcf)])
    return str(eval_vcf), str(bed_ens)


def substitution_spectrum(vcf):
    """Tally SNVs in *vcf* by "REF>ALT" label (12 possible classes).

    Only records whose REF and ALT are each a single A/C/G/T base
    (case-insensitive) and differ are counted, so a multi-allelic ALT field is
    skipped; the input is meant to be a normalized VCF. Returns a Counter.
    """
    bases = "ACGT"
    tally = collections.Counter()
    text = sh([BCF, "view", "-H", "-v", "snps", vcf]).stdout
    for record in text.splitlines():
        cols = record.split("\t")
        if len(cols) < 5:
            continue
        ref = cols[3].upper()
        alt = cols[4].upper()
        if len(ref) != 1 or len(alt) != 1:
            continue
        if ref not in bases or alt not in bases or ref == alt:
            continue
        tally[f"{ref}>{alt}"] += 1
    return tally


def edit_fraction(vcf):
    """Return (fraction, n_edit_like, n_snvs) for A>G + T>C SNVs in *vcf*.

    A>G / T>C is treated as the canonical ADAR RNA-edit signature. The counts come
    from substitution_spectrum(); fraction is 0.0 when no countable SNVs exist.
    """
    counts = substitution_spectrum(vcf)
    n_snvs = sum(counts.values())
    n_edit = sum(counts.get(label, 0) for label in ("A>G", "T>C"))
    if not n_snvs:
        return 0.0, n_edit, n_snvs
    return n_edit / n_snvs, n_edit, n_snvs


def titv_from_stats(vcf):
    """Return (ts/tv, n_transitions, n_transversions) from `bcftools stats` on *vcf*.

    Parses the first `TSTV` row whose id column is 0. The ratio is bcftools' own
    printed value, not recomputed. Returns (None, None, None) when no such row is found.
    """
    # TSTV <id> <ts> <tv> <ts/tv> ...
    tstv_row = re.compile(r"TSTV\t0\t(\d+)\t(\d+)\t([\d.]+)")
    report = sh([BCF, "stats", vcf]).stdout
    for ln in report.splitlines():
        m = tstv_row.match(ln)
        if m:
            return float(m.group(3)), int(m.group(1)), int(m.group(2))
    return None, None, None


def _read_variant_summary(out_dir):
    """Read VA's self-reported Ti/Tv and variant total from out_dir/variant_summary.json.

    Returns {} when the file is absent, {"titv_app", "app_total_variants"} when it
    parses to a JSON object, and {"variant_summary_error": ...} when it cannot be
    read or parsed (rather than silently ignoring it).
    """
    vs = out_dir / "variant_summary.json"
    if not vs.exists():
        return {}
    try:
        j = json.loads(vs.read_text())
    except (OSError, ValueError) as e:   # ValueError covers JSONDecodeError / UnicodeDecodeError
        return {"variant_summary_error": f"{type(e).__name__}: {e}"}
    if not isinstance(j, dict):
        return {"variant_summary_error": f"expected a JSON object, got {type(j).__name__}"}
    # fall back to passed_qc only when total_variants is missing (a real 0 is kept)
    total = j.get("total_variants")
    if total is None:
        total = j.get("passed_qc")
    return {"titv_app": j.get("ti_tv_ratio"), "app_total_variants": total}


def process_run(run, giab_eval, bed_ens):
    """Score one aligner run's delivered SNVs against the prepared GIAB truth.

    Args:
        run: manifest entry; reads "output_dir", "name", "aligner", "preset" and
            optionally "elapsed_s".
        giab_eval: path of the indexed GIAB eval VCF returned by prepare_giab().
        bed_ens: path of the Ensembl-named high-confidence BED from prepare_giab().

    Returns:
        A result dict. Without a called_variants.vcf.gz it holds only the identity
        fields plus "error". Otherwise it holds raw/eval counts, TP/FP/FN, precision,
        region_recall, the A>G/T>C fractions among FP and TP, precision_excl_edits,
        bcftools Ti/Tv and spectrum, and VA's own figures from variant_summary.json
        (or "variant_summary_error" if that file is unreadable).

    Intermediate files go under WORK/<name>. bcftools failures propagate as
    exceptions raised by bcf()/sh().
    """
    out_dir = pathlib.Path(run["output_dir"])
    called = out_dir / "called_variants.vcf.gz"
    name = run["name"]
    rec = {"name": name, "aligner": run["aligner"], "preset": run["preset"]}
    if not called.exists():
        rec["error"] = f"no called_variants.vcf.gz in {out_dir}"
        return rec

    rd = WORK / name
    rd.mkdir(exist_ok=True)
    # (re)build the index on the raw calls; -f overwrites any existing index
    bcf(["index", "-f", str(called)])
    rec["raw_calls"] = count_vcf(str(called))

    # app's delivered set: QUAL>=20 & INFO/DP>=10, then normalized
    filt = rd / "app.filt.vcf.gz"
    bcf(["view", "-i", "QUAL>=20 && INFO/DP>=10", "-Oz", "-o", str(filt), str(called)])
    bcf(["index", "-f", str(filt)])
    norm = rd / "app.norm.vcf.gz"
    bcf(["norm", "-m", "-any", "-f", REF, "-Oz", "-o", str(norm), str(filt)])
    bcf(["index", "-f", str(norm)])

    # restrict to GIAB high-conf BED ∩ 20/21, SNVs only (same restriction as the truth)
    appeval = rd / "app.eval.vcf.gz"
    bcf(["view", "-R", bed_ens, "-v", "snps", "-Oz", "-o", str(appeval), str(norm)])
    bcf(["index", "-f", str(appeval)])

    # isec -p: 0000 = app-only, 0001 = truth-only, 0002 = shared (app-side records)
    isec = rd / "isec"
    isec.mkdir(exist_ok=True)
    bcf(["isec", "-p", str(isec), str(appeval), giab_eval])
    app_only = str(isec / "0000.vcf")
    shared = str(isec / "0002.vcf")
    fp = count_vcf(app_only)
    fn = count_vcf(str(isec / "0001.vcf"))   # truth-only (region recall denominator)
    tp = count_vcf(shared)

    rec["eval_snvs"] = tp + fp
    rec["TP"] = tp
    rec["FP"] = fp
    rec["FN"] = fn
    rec["precision"] = round(tp / (tp + fp), 4) if (tp + fp) else None
    rec["region_recall"] = round(tp / (tp + fn), 4) if (tp + fn) else None

    # RNA-editing signature: A>G/T>C fraction among FP vs TP
    fp_frac, fp_edits, _ = edit_fraction(app_only)
    tp_frac, _, _ = edit_fraction(shared)
    rec["fp_AG_TC_frac"] = round(fp_frac, 4)
    rec["tp_AG_TC_frac"] = round(tp_frac, 4)

    # precision EXCLUDING canonical RNA-edit FPs (drop A>G/T>C app-only calls)
    tp_plus_nonedit_fp = tp + (fp - fp_edits)
    rec["precision_excl_edits"] = round(tp / tp_plus_nonedit_fp, 4) if tp_plus_nonedit_fp else None

    # Ti/Tv + spectrum on the normalized delivered set (chr20/21, pre-BED)
    titv, _, _ = titv_from_stats(str(norm))
    rec["titv_bcftools"] = titv
    rec["spectrum"] = dict(substitution_spectrum(str(norm)))

    # cross-check VA's self-reported Ti/Tv
    rec.update(_read_variant_summary(out_dir))
    rec["elapsed_s"] = run.get("elapsed_s")
    return rec


def make_figures(results):
    """Render the three comparison figures into FIG for runs that have a precision.

    Writes concordance_precision.png, titv_spectrum.png and rna_edit_signature.png.
    Runs whose precision is None (error records, or no evaluable SNVs) are left out;
    if none remain, a notice is printed and nothing is written. matplotlib (Agg
    backend) and numpy are imported here rather than at module level.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np

    ok = [r for r in results if r.get("precision") is not None]
    if not ok:
        print("no runs with precision; skipping figures")
        return
    # one display label per run, shared by every figure
    labels = [r["name"].replace("minimap2_", "mm2 ").replace("_", " ") for r in ok]
    runs_x = np.arange(len(ok))   # one x slot per run (figures 1 and 3)

    # 1. precision bar (with/without RNA-edit FPs)
    fig, ax = plt.subplots(figsize=(7.5, 5))
    w = 0.38
    ax.bar(runs_x - w/2, [r["precision"] for r in ok], w, label="precision", color="#0090c1")
    ax.bar(runs_x + w/2, [r.get("precision_excl_edits") or 0 for r in ok], w,
           label="precision excl. RNA-edit FPs", color="#38aecc")
    for i, r in enumerate(ok):
        ax.text(i - w/2, r["precision"] + 0.01, f"{r['precision']:.3f}", ha="center", fontsize=8)
    ax.set_xticks(runs_x); ax.set_xticklabels(labels, rotation=12)
    ax.set_ylim(0, 1.05); ax.set_ylabel("SNV precision vs GIAB (chr20/21, high-conf)")
    ax.set_title("VA RNA-seq SNV precision vs GIAB HG001 truth, by aligner")
    ax.legend(); ax.grid(axis="y", alpha=0.3)
    fig.tight_layout(); fig.savefig(FIG / "concordance_precision.png", dpi=140); plt.close(fig)

    # 2. substitution spectrum (per aligner) with transitions highlighted
    classes = ["A>C", "A>G", "A>T", "C>A", "C>G", "C>T",
               "G>A", "G>C", "G>T", "T>A", "T>C", "T>G"]
    fig, ax = plt.subplots(figsize=(10, 5))
    cls_x = np.arange(len(classes))
    w = 0.8 / len(ok)   # runs share 0.8 of each class slot
    for i, (r, lab) in enumerate(zip(ok, labels)):
        sp = r["spectrum"]
        tot = sum(sp.values()) or 1   # avoid division by zero for an empty spectrum
        vals = [sp.get(c, 0) / tot for c in classes]
        ax.bar(cls_x + i*w - 0.4 + w/2, vals, w, label=lab)
    for j, c in enumerate(classes):
        if (c[0], c[2]) in TS:
            ax.axvspan(j-0.45, j+0.45, color="#fff3cd", alpha=0.4, zorder=0)
    ax.set_xticks(cls_x); ax.set_xticklabels(classes, rotation=45)
    ax.set_ylabel("fraction of SNVs"); ax.set_title("12-class substitution spectrum (yellow = transitions; A>G/T>C = RNA editing)")
    ax.legend(); fig.tight_layout(); fig.savefig(FIG / "titv_spectrum.png", dpi=140); plt.close(fig)

    # 3. RNA-edit signature: A>G+T>C fraction, FP vs TP
    fig, ax = plt.subplots(figsize=(7.5, 5))
    ax.bar(runs_x - 0.2, [r["fp_AG_TC_frac"] for r in ok], 0.4,
           label="false positives (app-only)", color="#c1440e")
    ax.bar(runs_x + 0.2, [r["tp_AG_TC_frac"] for r in ok], 0.4,
           label="true positives (shared w/ GIAB)", color="#046e8f")
    ax.set_xticks(runs_x); ax.set_xticklabels(labels, rotation=12)
    ax.set_ylabel("A>G + T>C fraction"); ax.set_ylim(0, 1.0)
    ax.set_title("RNA-editing signature: ADAR (A>G/T>C) enriched among false positives")
    ax.legend(); ax.grid(axis="y", alpha=0.3)
    fig.tight_layout(); fig.savefig(FIG / "rna_edit_signature.png", dpi=140); plt.close(fig)
    print("figures ->", FIG)


def main():
    """Run the concordance for every completed run in ROOT/results/runs_manifest.json.

    Exits with status 1 when the manifest lists no completed runs. Writes
    WORK/concordance_summary.json, FIG.parent/aligner_comparison.csv and the figures.
    A run whose bcftools steps fail is recorded with an "error" entry (the same
    convention process_run uses for a missing call set) so the remaining runs are
    still scored; after all outputs are written the script exits with status 1 if
    any run failed that way.
    """
    manifest = json.loads(pathlib.Path(f"{ROOT}/results/runs_manifest.json").read_text())
    manifest = [r for r in manifest if r.get("status") == "completed"]
    if not manifest:
        print("no completed runs in manifest")
        sys.exit(1)

    print("Preparing GIAB truth (rename, norm, SNVs, high-conf BED)...")
    giab_eval, bed_ens = prepare_giab()
    print("  GIAB eval SNVs (chr20/21, high-conf):", count_vcf(giab_eval))

    shown = ("precision", "TP", "FP", "FN", "precision_excl_edits",
             "fp_AG_TC_frac", "titv_bcftools", "titv_app", "error")
    results = []
    failed = []
    for run in manifest:
        print(f"\n=== concordance: {run['name']} ===")
        try:
            rec = process_run(run, giab_eval, bed_ens)
        except (RuntimeError, subprocess.CalledProcessError) as e:
            msg = str(e)
            stderr = getattr(e, "stderr", None)   # CalledProcessError carries bcftools' stderr
            if stderr:
                msg += f"; stderr: {str(stderr)[:400]}"
            rec = {"name": run["name"], "aligner": run["aligner"],
                   "preset": run["preset"], "error": msg}
            failed.append(run["name"])
        print("  ", {k: rec[k] for k in shown if k in rec})
        results.append(rec)

    out = WORK / "concordance_summary.json"
    out.write_text(json.dumps(results, indent=2))
    # CSV table
    csv_path = FIG.parent / "aligner_comparison.csv"
    cols = ["name", "aligner", "preset", "raw_calls", "eval_snvs", "TP", "FP", "FN",
            "precision", "region_recall", "precision_excl_edits",
            "fp_AG_TC_frac", "tp_AG_TC_frac", "titv_bcftools", "titv_app", "elapsed_s"]
    with open(csv_path, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=cols, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(results)
    make_figures(results)
    print("\nwrote", out, "and", csv_path)
    if failed:
        # keep a non-zero exit so automation still notices bcftools failures
        print("bcftools failed for:", ", ".join(failed))
        sys.exit(1)


if __name__ == "__main__":
    # CLI entry point; the optional [root] argument is read from sys.argv at import time (see ROOT).
    main()
