#!/usr/bin/env python3
"""
Annotation validation: compare Haritica VA's SnpEff consequence/impact calls to
an INDEPENDENT annotator (`bcftools csq`) on the *identical* called VCF. This
isolates the annotation step from variant calling.

For each completed aligner run:
  - parse SnpEff ANN from `snpeff_annotated.vcf` (consequence + impact per variant),
  - run `bcftools csq` on `called_variants.vcf.gz` (Ensembl GFF3, GSL-free bcftools),
  - map both vocabularies to coarse SO classes and measure per-variant agreement,
  - report the SnpEff impact distribution (HIGH/MODERATE/LOW/MODIFIER).

bcftools csq has strict Ensembl-GFF3 expectations; if it fails we still emit the
SnpEff distribution and say so (best-effort independent cross-check).

Usage: annotation_concordance.py [root] [run_name]
  run_name defaults to the first completed minimap2 run, or to the first
  completed run of any aligner when no minimap2 run exists (annotation is
  aligner-independent, so one run suffices for the SnpEff↔csq comparison).
"""
import sys, os, json, subprocess, re, pathlib, collections

# Positional CLI arguments: [root] [run_name] (see the module docstring).
ROOT = sys.argv[1] if len(sys.argv) > 1 else "/path/to/haritica-data/va_poscontrol"
ONLY = sys.argv[2] if len(sys.argv) > 2 else None  # run name to analyse; None -> main() auto-selects
# NOTE(review): the "/path/to/..." defaults look like placeholders — confirm they are set per machine.
BIN  = "/path/to/haritica/binaries/macos"
BCF  = f"{BIN}/bcftools"                              # bcftools binary used for `view` and `csq`
REF  = f"{ROOT}/ref/GRCh38.fa"                       # reference FASTA passed to `bcftools csq -f`
GFF  = f"{ROOT}/ref/Homo_sapiens.GRCh38.110.gff3.gz" # Ensembl GFF3 passed to `bcftools csq -g`
# Output directories are created at import time (a side effect of loading this module):
# WORK receives csq.vcf and annotation_agreement.json, FIG receives the PNG figures.
WORK = pathlib.Path(f"{ROOT}/annotation"); WORK.mkdir(parents=True, exist_ok=True)
FIG  = pathlib.Path("/path/to/haritica/docs/validation/variant-annotation-positive-control/figures")
FIG.mkdir(parents=True, exist_ok=True)

# SnpEff impact levels, in the order they are drawn in the impact bar chart.
IMPACTS = ["HIGH", "MODERATE", "LOW", "MODIFIER"]


def coarse(term):
    """Return the coarse SO class for a SnpEff or bcftools csq consequence term.

    SnpEff and csq spell consequences differently (e.g. ``missense_variant``
    vs ``missense``), so the term is lower-cased and tested for characteristic
    substrings. Rules are checked in priority order and the first match wins:
    for example ``non_synonymous`` is claimed by "missense/inframe" before the
    "synonymous" rule is reached. Terms matching no rule map to ``"other"``.
    """
    rules = (
        (("frameshift",), "frameshift"),
        (("stop_gain", "stop_lost", "start_lost"), "stop/start"),
        (("splice",), "splice"),
        (("missense", "non_synonymous", "inframe"), "missense/inframe"),
        (("synonymous",), "synonymous"),
        (("utr", "_prime"), "utr"),
        (("intron",), "intron"),
        (("upstream",), "upstream"),
        (("downstream",), "downstream"),
        (("intergenic",), "intergenic"),
        # "non_coding" already covers "non_coding_transcript".
        (("non_coding", "nc_transcript"), "non_coding"),
    )
    lowered = term.lower()
    return next(
        (label for needles, label in rules if any(n in lowered for n in needles)),
        "other",
    )


def parse_snpeff(vcf):
    """Map each ALT allele in *vcf* to its SnpEff consequence term and impact.

    Records are streamed through ``bcftools view -H`` (header suppressed).
    For every ALT allele of a record that carries an ``ANN`` INFO field, the
    first ANN entry whose ``Allele`` sub-field equals that ALT is used. This
    matters for multiallelic records, where ANN holds entries for several
    alleles. If no entry names the allele, the record's first usable ANN
    entry is used instead.

    Args:
        vcf: path to the SnpEff-annotated VCF, as a string.

    Returns:
        ``(res, impact_counts)``. ``res`` maps ``"chrom:pos:ref:alt"`` to
        ``(consequence_term, impact)``, keeping only the first ``&``-joined
        term. ``impact_counts`` is a ``collections.Counter`` over the impacts
        stored in ``res``, with one count per allele key, so its total equals
        ``len(res)``.

    Raises:
        subprocess.CalledProcessError: if ``bcftools view`` exits non-zero,
            instead of silently returning a truncated or empty map.
    """
    cmd = [BCF, "view", "-H", vcf]
    res = {}
    # Using Popen as a context manager closes the pipe and reaps bcftools even
    # if parsing raises part-way through.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True) as proc:
        for line in proc.stdout:
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 8:
                continue
            chrom, pos, ref, alts, info = (fields[0], fields[1], fields[3],
                                           fields[4], fields[7])
            per_allele = _snpeff_ann_by_allele(info)
            if not per_allele:
                continue
            fallback = per_allele[None]
            for alt in alts.split(","):
                res[f"{chrom}:{pos}:{ref}:{alt}"] = per_allele.get(alt, fallback)
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd)
    impact_counts = collections.Counter(impact for _cons, impact in res.values())
    return res, impact_counts


def _snpeff_ann_by_allele(info):
    """Index the ANN entries of one VCF INFO column by their Allele field.

    Returns ``{}`` when INFO has no non-empty ``ANN=`` key, or when no ANN
    entry has at least Allele|Annotation|Impact sub-fields. Otherwise it
    returns a dict mapping each Allele value to ``(consequence_term, impact)``
    from that allele's first entry, plus the key ``None`` for the first usable
    entry overall, which serves as the fallback.
    """
    ann = None
    # Exact key match: a regex search for "ANN=" would also hit keys such as "XANN=".
    for item in info.split(";"):
        if item.startswith("ANN="):
            ann = item[len("ANN="):]
            break
    if not ann:
        return {}
    by_allele = {}
    for entry in ann.split(","):
        parts = entry.split("|")
        if len(parts) < 3:
            continue
        call = (parts[1].split("&")[0], parts[2])
        by_allele.setdefault(None, call)
        by_allele.setdefault(parts[0], call)
    return by_allele


def run_csq(called_vcf):
    """Annotate *called_vcf* with ``bcftools csq`` and read back its BCSQ terms.

    Best-effort: failures are returned rather than raised, so the caller can
    still report the SnpEff side. This covers bcftools being missing or
    unlaunchable, and bcftools exiting non-zero.

    Args:
        called_vcf: path to the called (unannotated) VCF, as a string.

    Returns:
        On success, ``(res, None)``. ``res`` maps ``"chrom:pos:ref:alt"`` to
        the first ``&``-joined term of the record's first non-link BCSQ
        consequence, and the same term is assigned to every ALT of the
        record. On failure, ``(None, error_text)``, where ``error_text`` is
        the last 500 characters of stderr or the launch error.
    """
    out_vcf = WORK / "csq.vcf"
    # -l: local consequences, one VCF record at a time (no haplotype-aware
    # compound calling); --force: run even if some GFF3 sanity checks fail.
    cmd = [BCF, "csq", "-f", REF, "-g", GFF, "-l", "--force",
           "-O", "v", "-o", str(out_vcf), called_vcf]
    try:
        p = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # e.g. FileNotFoundError when the bcftools binary is absent; keep this best-effort.
        return None, f"could not launch bcftools: {exc}"
    if p.returncode != 0:
        return None, p.stderr[-500:]
    # NOTE(review): SnpEff lists ANN entries most-severe first, but csq's BCSQ
    # entries may not be severity-ordered; comparing first entries can
    # understate agreement — confirm against the csq output.
    res = {}
    with open(out_vcf) as fh:
        for line in fh:
            if line.startswith("#"):
                continue
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 8:
                continue
            cons = _first_bcsq_term(fields[7])
            if cons is None:
                continue
            chrom, pos, ref = fields[0], fields[1], fields[3]
            for alt in fields[4].split(","):
                res[f"{chrom}:{pos}:{ref}:{alt}"] = cons
    return res, None


def _first_bcsq_term(info):
    """Return the first consequence term of a VCF INFO column's BCSQ value, or None.

    BCSQ format: consequence|gene|transcript|biotype|strand|aa|nt, with entries
    separated by ",". Entries starting with "@" are links to the record that
    carries a compound consequence, not consequences themselves, so they are
    skipped. Previously stripping "@" left a bare position that was
    classified as "other".
    """
    for item in info.split(";"):
        if item.startswith("BCSQ="):
            value = item[len("BCSQ="):]
            break
    else:
        return None
    for entry in value.split(","):
        if not entry or entry.startswith("@"):
            continue
        term = entry.split("|")[0].split("&")[0]
        if term:
            return term
    return None


def make_figures(impact_counts, agreement):
    """Render the annotation-validation figures into ``FIG``.

    Always writes ``impact_distribution.png``, a bar chart of SnpEff impacts
    in ``IMPACTS`` order. Impacts absent from *impact_counts* plot as 0.

    When *agreement* is truthy and has a non-empty ``"matrix"``, it also
    writes ``annotation_agreement.png``. That figure is the SnpEff-class ×
    csq-class count matrix, shaded by row-normalised fraction and labelled
    with the raw non-zero counts.

    Args:
        impact_counts: mapping of impact label -> count.
        agreement: ``None`` or a dict with ``classes``, ``matrix``,
            ``agreement`` and ``n_shared`` keys, as assembled in ``main``.
    """
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: figures are only saved to disk
    import matplotlib.pyplot as plt
    import numpy as np

    # ---- SnpEff impact bar chart ------------------------------------------
    heights = [impact_counts.get(label, 0) for label in IMPACTS]
    bar_colors = ["#c1440e", "#e08a1e", "#0090c1", "#9bb7c4"]
    fig, ax = plt.subplots(figsize=(6.5, 4.5))
    ax.bar(IMPACTS, heights, color=bar_colors)
    for x, height in enumerate(heights):
        ax.text(x, height, str(height), ha="center", va="bottom", fontsize=9)
    ax.set_ylabel("variant count")
    ax.set_title("SnpEff impact distribution (GRCh38.99, chr20/21)")
    fig.tight_layout()
    fig.savefig(FIG / "impact_distribution.png", dpi=140)
    plt.close(fig)

    # ---- SnpEff vs csq coarse-class matrix --------------------------------
    if not (agreement and agreement.get("matrix")):
        return
    labels = agreement["classes"]
    ticks = range(len(labels))
    counts = np.array(agreement["matrix"], dtype=float)
    # Row-normalise; the tiny epsilon keeps all-zero rows from dividing by zero.
    frac = counts / (counts.sum(axis=1, keepdims=True) + 1e-9)

    fig, ax = plt.subplots(figsize=(8, 7))
    im = ax.imshow(frac, cmap="Blues", vmin=0, vmax=1)
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels)
    ax.set_xlabel("bcftools csq class")
    ax.set_ylabel("SnpEff class")
    # argwhere yields the non-zero cells in row-major order (row, col).
    for row, col in np.argwhere(counts):
        shade = frac[row, col]
        ax.text(col, row, int(counts[row, col]), ha="center", va="center",
                fontsize=7, color="black" if shade < 0.6 else "white")
    ax.set_title(f"SnpEff vs bcftools csq coarse-class agreement\n"
                 f"(diagonal agreement = {agreement['agreement']:.1%} of {agreement['n_shared']} shared variants)")
    fig.colorbar(im, fraction=0.046)
    fig.tight_layout()
    fig.savefig(FIG / "annotation_agreement.png", dpi=140)
    plt.close(fig)


def main():
    """Compare SnpEff and bcftools csq consequence calls for one completed run.

    Steps:
      1. Read ``{ROOT}/results/runs_manifest.json`` and select a run
         (see ``_select_run``).
      2. Parse the run's ``snpeff_annotated.vcf`` when present.
      3. Run ``bcftools csq`` on its ``called_variants.vcf.gz``; this is
         best-effort.
      4. Render the figures and write ``annotation_agreement.json`` into
         ``WORK``.

    Exits with status 1 when no completed run matches.
    """
    manifest_path = pathlib.Path(f"{ROOT}/results/runs_manifest.json")
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    run = _select_run(manifest, ONLY)
    if run is None:
        print("no completed run found", file=sys.stderr)
        sys.exit(1)
    out_dir = pathlib.Path(run["output_dir"])
    print(f"annotation concordance on run: {run['name']}")

    snpeff_vcf = out_dir / "snpeff_annotated.vcf"
    called_vcf = out_dir / "called_variants.vcf.gz"
    result = {"run": run["name"]}

    se_map, impact_counts = {}, collections.Counter()
    if snpeff_vcf.exists():
        se_map, impact_counts = parse_snpeff(str(snpeff_vcf))
        result["snpeff_impact_distribution"] = dict(impact_counts)
        result["snpeff_n"] = len(se_map)
        print("  SnpEff impacts:", dict(impact_counts))
    else:
        print("  no snpeff_annotated.vcf — SnpEff may have been skipped")
        result["snpeff"] = "missing"

    csq_map, err = run_csq(str(called_vcf))
    agreement = None
    if csq_map is None:
        reason = (err or "").strip()[:200]
        print("  bcftools csq failed (best-effort):", reason)
        result["csq"] = f"failed: {reason}"
    else:
        result["csq_n"] = len(csq_map)
        agreement, tallies = _coarse_agreement(se_map, csq_map)
        # Per-class tallies feed the report's "Haritica (SnpEff) vs Reference
        # (bcftools csq)" comparison table.
        result.update(tallies)
        result["snpeff_vs_csq"] = {k: agreement[k] for k in ("n_shared", "agreement")}
        if agreement["agreement"] is not None:
            print(f"  SnpEff↔csq coarse agreement: {agreement['agreement']:.1%} "
                  f"of {agreement['n_shared']} shared")
        else:
            print(f"  SnpEff↔csq: no shared variants to compare "
                  f"(SnpEff n={len(se_map)}, csq n={len(csq_map)})")

    make_figures(impact_counts, agreement)
    out_json = WORK / "annotation_agreement.json"
    out_json.write_text(json.dumps(result, indent=2), encoding="utf-8")
    print("wrote", out_json)


def _select_run(manifest, only):
    """Return the manifest entry to analyse, or None if nothing qualifies.

    Only entries with ``status == "completed"`` are considered. If *only* is
    set, the first completed run with that name is chosen. Otherwise the
    first completed minimap2 run is chosen, falling back to the first
    completed run of any aligner. Missing keys are tolerated via ``.get``.
    """
    completed = [r for r in manifest if r.get("status") == "completed"]
    if only:
        candidates = [r for r in completed if r.get("name") == only]
    else:
        candidates = [r for r in completed if r.get("aligner") == "minimap2"] or completed
    return candidates[0] if candidates else None


def _coarse_agreement(se_map, csq_map):
    """Cross-tabulate coarse classes for variant keys present in both maps.

    Returns ``(agreement, tallies)``.

    ``agreement`` has these keys:
      - ``n_shared``: number of variant keys present in both maps.
      - ``agreement``: diagonal fraction rounded to 4 decimals, or None when
        nothing is shared.
      - ``classes``: sorted list of coarse classes that occur.
      - ``matrix``: counts with rows = SnpEff class and columns = csq class.

    ``tallies`` holds ``classes``, the per-class row sums
    (``snpeff_class_counts``), column sums (``csq_class_counts``) and the
    diagonal (``class_diagonal``).
    """
    shared = [k for k in se_map if k in csq_map]
    pairs = [(coarse(se_map[k][0]), coarse(csq_map[k])) for k in shared]
    classes = sorted({cls for pair in pairs for cls in pair})
    idx = {cls: i for i, cls in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for se_cls, csq_cls in pairs:
        matrix[idx[se_cls]][idx[csq_cls]] += 1
    n_agree = sum(se_cls == csq_cls for se_cls, csq_cls in pairs)
    agreement = {
        "n_shared": len(shared),
        "agreement": round(n_agree / len(shared), 4) if shared else None,
        "classes": classes,
        "matrix": matrix,
    }
    tallies = {
        "classes": classes,
        "snpeff_class_counts": {c: sum(matrix[idx[c]]) for c in classes},
        "csq_class_counts": {c: sum(row[idx[c]] for row in matrix) for c in classes},
        "class_diagonal": {c: matrix[idx[c]][idx[c]] for c in classes},
    }
    return agreement, tallies


if __name__ == "__main__":
    # Run the analysis only when executed as a script; importing the module does not call main().
    main()
