#!/usr/bin/env python3
"""
Concordance: Haritica's ORA engine vs R/Bioconductor clusterProfiler::enricher(),
on the airway / Himes-2014 significant genes, over Haritica's OWN bundled gene
sets. Mirrors the DE positive control's "pyDESeq2 vs R DESeq2 on the same counts".

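Writes the GO:BP scatter to <out_png> (default figures/concordance.png), writes
concordance_summary.csv one directory above that figure, and prints the headline numbers: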
  - per-collection term-set Jaccard, Pearson r & Spearman rho of -log10(p) over
    the shared terms, exact overlap-count agreement, top-20 rank agreement
  - app==engine check (the app's results CSV == analyses/_ora output)
  - biology recovery of the published glucocorticoid/steroid/ECM/vasculature axes

Usage (every argument is optional; defaults point into the repo's docs/validation tree):
  concordance.py [reference_results.csv [out_png [app_gobp_results.csv]]]
"""
import sys, os, math
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

# Root of the Haritica checkout. Set HARITICA_REPO to point at a real checkout
# instead of editing this placeholder; it is put first on sys.path so the
# `analyses` package below is importable.
REPO = os.environ.get("HARITICA_REPO", "/path/to/haritica")
sys.path.insert(0, REPO)
from analyses import _ora

def _argv_or(index, default):
    """Return positional command-line argument ``index``, or ``default`` if absent."""
    try:
        return sys.argv[index]
    except IndexError:
        return default


_VALIDATION_DIR = f"{REPO}/docs/validation/enrichment-positive-control"

# Positional CLI arguments (all optional): reference CSV, output PNG, app GO:BP CSV.
REF_CSV = _argv_or(1, f"{_VALIDATION_DIR}/reference_results.csv")
OUT_PNG = _argv_or(2, f"{_VALIDATION_DIR}/figures/concordance.png")
APP_CSV = _argv_or(3, "")  # empty string disables the app==engine check
GENES_CSV = f"{REPO}/tests/test_data/airway_significant_genes.csv"

# ORA filters handed to Haritica; MIN_OVERLAP also trims the reference table
# so both sides are compared on the same k >= MIN_OVERLAP term set.
MIN_OVERLAP = 3
MAX_SET = 500
# Gene-set collections compared, in report order.
COLLECTIONS = ["go_bp", "go_mf", "go_cc", "wiki", "reactome", "hallmark"]


def load_genes(path=None):
    """Return the query gene identifiers from the first column of a CSV.

    Parameters
    ----------
    path : str or file-like, optional
        CSV to read (first row is the header). Defaults to ``GENES_CSV``,
        the bundled airway significant-gene list.

    Returns
    -------
    list of str
        Whitespace-stripped, non-empty values of the first column, in file
        order; missing cells are dropped.
    """
    if path is None:
        path = GENES_CSV
    # dtype=str keeps identifiers verbatim: without it a numeric ID column
    # containing a gap is parsed as float and stringifies as "7157.0", and
    # leading zeros are lost.
    first_col = pd.read_csv(path, dtype=str).iloc[:, 0]
    cleaned = first_col.dropna().astype(str).str.strip()
    return [g for g in cleaned.tolist() if g]


def haritica_ora(genes, code):
    """Run Haritica's ORA engine on ``genes`` for the single collection ``code``.

    Thin wrapper over ``_ora.run_ora_across_collections`` for human, using the
    script-wide MIN_OVERLAP / MAX_SET filters so every call is configured the
    same way. Returns that function's result unchanged (callers use it as a
    pandas DataFrame: ``.empty`` and column access).
    """
    options = dict(
        organism="human",
        short_codes=[code],
        min_overlap=MIN_OVERLAP,
        max_set_size=MAX_SET,
    )
    return _ora.run_ora_across_collections(genes, **options)


TOP_N = 20  # length of the most-significant-term lists compared per collection


def _neglog10(p):
    """Return -log10(p) as a float array; p is clipped to [1e-300, 1] so p == 0 stays finite."""
    return -np.log10(np.clip(np.asarray(p, dtype=float), 1e-300, 1))


def _top_n_overlap(har, ref, n=TOP_N):
    """Return the fraction of the top-k terms (smallest raw p) shared by both tools.

    k = min(n, len(har), len(ref)) so short result lists are compared on equal
    footing. Returns NaN when either side is empty.
    """
    k = min(n, len(har), len(ref))
    if k == 0:
        return float("nan")

    def top_terms(df):
        # Coerce to numeric so an object-typed column is not sorted as text;
        # mergesort is stable so ties keep file order, NaN p-values sort last.
        p = pd.to_numeric(df["pvalue"], errors="coerce")
        ranked = df.assign(_p=p).sort_values("_p", kind="mergesort")
        return set(ranked.head(k)["term"])

    return len(top_terms(har) & top_terms(ref)) / k


def _collection_concordance(code, har, ref):
    """Compare one collection's Haritica vs reference results and print its table row.

    ``har`` and ``ref`` must already be restricted to k >= MIN_OVERLAP terms.
    Returns ``(summary_row, joined)``; ``joined`` is the inner merge on ``term``
    with ``pvalue_har`` / ``pvalue_ref`` columns.
    """
    har_terms = set(har["term"])
    ref_terms = set(ref["term"])
    shared = har_terms & ref_terms
    union = har_terms | ref_terms
    jacc = len(shared) / len(union) if union else float("nan")

    joined = har.merge(ref[["term", "pvalue", "Count"]], on="term",
                       suffixes=("_har", "_ref"))
    if len(joined) >= 3:  # too few shared terms makes the correlations meaningless
        lh = _neglog10(joined["pvalue_har"])
        lr = _neglog10(joined["pvalue_ref"])
        pear = pearsonr(lh, lr)[0]
        spear = spearmanr(lh, lr)[0]
        # Haritica reports the overlap as gene_count, the reference as Count.
        cnt_exact = float((joined["gene_count"].astype(int)
                           == joined["Count"].astype(int)).mean())
    else:
        pear = spear = cnt_exact = float("nan")
    top = _top_n_overlap(har, ref)

    print(f"{code:10s} {len(har_terms):9d} {len(ref_terms):9d} {len(shared):7d} "
          f"{jacc:8.4f} {pear:10.4f} {spear:11.4f} {cnt_exact*100:9.1f}% "
          f"{top*100:6.1f}%")
    row = dict(collection=code, n_har=len(har_terms), n_ref=len(ref_terms),
               shared=len(shared), jaccard=jacc, p_pearson=pear,
               p_spearman=spear, count_exact=cnt_exact, top_n_overlap=top)
    return row, joined


def _app_engine_check(app_csv, har_gobp):
    """Print max |Δp| between the app's exported GO:BP CSV and the engine's GO:BP run.

    Does nothing when ``app_csv`` is empty; a missing file or an empty term
    overlap is reported instead of being skipped silently.
    """
    if not app_csv:
        return
    if not os.path.exists(app_csv):
        print(f"\napp==engine: skipped, {app_csv} does not exist")
        return
    app = pd.read_csv(app_csv)
    m = app.merge(har_gobp[["term", "pvalue"]], on="term", suffixes=("_app", "_eng"))
    if m.empty:
        print("\napp==engine: no shared GO:BP terms (app CSV vs analyses/_ora)")
        return
    dp = np.abs(m["pvalue_app"].to_numpy(float) - m["pvalue_eng"].to_numpy(float)).max()
    print(f"\napp==engine: {len(m)} shared GO:BP terms, max|Δp| = {dp:.2e} "
          f"(app CSV vs analyses/_ora)")


def _biology_recovery(har_gobp):
    """Print which published airway biology axes appear among significant GO:BP terms.

    A term is significant at q < 0.05. An axis is recovered when any of its
    keywords is a substring of a single lower-cased significant term name
    (matching per term avoids hits spanning two adjacent term names).
    """
    axes = {
        "hormone/steroid/glucocorticoid": ["hormone", "steroid", "glucocorticoid"],
        "extracellular matrix": ["extracellular matrix", "extracellular structure"],
        "vasculature/circulatory": ["vasculature", "vascular", "circulatory", "blood vessel", "angiogenesis"],
        "muscle/fat/differentiation": ["muscle", "fat cell", "adipo", "differentiation"],
    }
    if har_gobp.empty:
        sig_terms = []  # an empty result may lack the qvalue/term columns entirely
    else:
        sig_terms = har_gobp.loc[har_gobp["qvalue"] < 0.05, "term"].str.lower().tolist()
    print(f"\nBiology recovery (GO:BP, {len(sig_terms)} significant terms q<0.05):")
    for name, kws in axes.items():
        hit = [k for k in kws if any(k in t for t in sig_terms)]
        print(f"  {'OK ' if hit else '-- '}{name:32s} {hit}")


def _plot_gobp_scatter(gobp_join, out_png):
    """Save the GO:BP per-term -log10(p) scatter (reference on x, Haritica on y).

    Returns the Pearson r shown in the title.
    """
    lh = _neglog10(gobp_join["pvalue_har"])
    lr = _neglog10(gobp_join["pvalue_ref"])
    r = pearsonr(lh, lr)[0]
    lim = float(np.nanmax(np.concatenate([lh, lr]))) * 1.03
    if not (np.isfinite(lim) and lim > 0):
        # All p == 1 (or all NaN) would give a zero-width / NaN axis range.
        lim = 1.0
    fig, ax = plt.subplots(figsize=(6.2, 6.0), dpi=130)
    try:
        ax.plot([0, lim], [0, lim], color="#94a3b8", lw=1, ls="--", zorder=1)
        ax.scatter(lr, lh, s=10, alpha=0.45, color="#0090c1", edgecolors="none", zorder=2)
        ax.set_xlabel("clusterProfiler  −log₁₀(p)", fontsize=11)
        ax.set_ylabel("Haritica ORA  −log₁₀(p)", fontsize=11)
        ax.set_title(f"GO:BP per-term p-value concordance\n"
                     f"n = {len(gobp_join)} terms · Pearson r = {r:.4f}", fontsize=12)
        ax.set_xlim(0, lim)
        ax.set_ylim(0, lim)
        ax.spines[["top", "right"]].set_visible(False)
        for side in ("left", "bottom"):
            ax.spines[side].set_color("#022f40")
        fig.tight_layout()
        fig.savefig(out_png, bbox_inches="tight", facecolor="white")
    finally:
        plt.close(fig)  # release the figure even if saving fails
    return r


def main():
    """Run the full concordance report.

    Prints the per-collection comparison table, the app==engine check (when an
    app CSV is given) and the GO:BP biology-recovery summary. It then writes the
    GO:BP scatter to OUT_PNG and the per-collection summary to
    concordance_summary.csv one directory above OUT_PNG.
    """
    genes = load_genes()
    ref_all = pd.read_csv(REF_CSV)
    ref_all["Count"] = pd.to_numeric(ref_all["Count"], errors="coerce")

    # The summary CSV path goes through OUT_PNG's directory ("<dir>/.."), so the
    # directory must exist even when no scatter gets drawn.
    out_dir = os.path.dirname(OUT_PNG)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    ora_cache = {}

    def ora(code):
        # Each collection runs once; GO:BP is reused by the app and biology checks.
        if code not in ora_cache:
            ora_cache[code] = haritica_ora(genes, code)
        return ora_cache[code]

    har_hdr = f"har(k>={MIN_OVERLAP})"
    ref_hdr = f"ref(k>={MIN_OVERLAP})"
    top_hdr = f"top{TOP_N}"
    print(f"Airway query genes: {len(genes)}\n")
    print(f"{'collection':10s} {har_hdr:>9s} {ref_hdr:>9s} {'shared':>7s} "
          f"{'Jaccard':>8s} {'p_Pearson':>10s} {'p_Spearman':>11s} {'cnt_exact':>10s} "
          f"{top_hdr:>7s}")

    rows = []
    gobp_join = None
    for code in COLLECTIONS:
        har = ora(code)
        if har.empty:
            continue
        # clusterProfiler returns all k>=1 terms; restrict to k>=MIN_OVERLAP to
        # match Haritica's min_overlap filter for an apples-to-apples set.
        ref = ref_all[(ref_all["collection"] == code) & (ref_all["Count"] >= MIN_OVERLAP)]
        row, joined = _collection_concordance(code, har, ref)
        rows.append(row)
        if code == "go_bp":
            gobp_join = joined
    summ = pd.DataFrame(rows)

    har_gobp = ora("go_bp")
    _app_engine_check(APP_CSV, har_gobp)
    _biology_recovery(har_gobp)

    if gobp_join is not None and len(gobp_join) >= 3:
        r = _plot_gobp_scatter(gobp_join, OUT_PNG)
        print(f"\nWrote {OUT_PNG}  (GO:BP n={len(gobp_join)}, Pearson r={r:.4f})")

    summary_csv = os.path.join(os.path.dirname(OUT_PNG), "..", "concordance_summary.csv")
    summ.to_csv(summary_csv, index=False)
    print(f"Wrote {summary_csv}")


if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 on
    # success; an uncaught exception still propagates as usual.
    raise SystemExit(main())
