#!/usr/bin/env python3
"""
Cloud-FASTQ concordance: the CLOUD end-to-end run (raw FASTQ -> HISAT2 ->
featureCounts -> DESeq2 -> ORA, on AWS Batch) vs the gene-list run (published
airway DESeq2 genes -> ORA). Both go through the SAME bundled ORA engine, so
any divergence isolates the FASTQ->DE stage.

Shows: the same glucocorticoid/vascular pathways re-emerge from raw reads, the
per-term GO:BP p-values track the gene-list analysis, the two independent
DESeq2 gene lists overlap, and the canonical airway genes survive end-to-end.

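Emits docs/validation/enrichment-positive-control/figures/deg_concordance.png
(relative to the repo root) and prints the headline numbers.
Usage: cloud_fastq_concordance.py [cloud_fastq_enrichment.json]
       (the argument is optional; defaults to /tmp/cloud_fastq_enrichment.json)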
"""
import sys, json, os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

# Root of the repository checkout. NOTE(review): the value looks like a
# placeholder and presumably has to be edited to a real local path - confirm.
REPO = "/path/to/haritica"
# The repo root must be on sys.path *before* the next line so that the
# project-local `analyses` package (which provides the ORA engine) resolves.
sys.path.insert(0, REPO)
from analyses import _ora

# Cloud-run enrichment JSON: first CLI argument if given, else a /tmp default.
CLOUD_JSON = sys.argv[1] if len(sys.argv) > 1 else "/tmp/cloud_fastq_enrichment.json"
# Destination of the concordance scatter plot written by main().
OUT_PNG = f"{REPO}/docs/validation/enrichment-positive-control/figures/deg_concordance.png"
# CSV whose first column is read by main() as the published gene list.
GENES_CSV = f"{REPO}/tests/test_data/airway_significant_genes.csv"
# Passed to _ora.run_ora_across_collections as min_overlap / max_set_size.
MIN_OVERLAP, MAX_SET = 3, 500

# Prefix of the engine's database label -> short collection code.
# Insertion order is the order in which code_of() tries the prefixes.
DB2CODE = {
    "GO_BP": "go_bp",
    "GO_MF": "go_mf",
    "GO_CC": "go_cc",
    "WikiPathways": "wiki",
    "Reactome": "reactome",
    "Hallmark": "hallmark",
}
# Gene symbols reported one by one in the "canonical airway genes" listing.
CANON = [
    "CRISPLD2", "KLF15", "FKBP5", "DUSP1", "GPX3", "PER1",
    "ZBTB16", "SPARCL1", "TSC22D3", "CEACAM5", "STEAP2", "MAOA",
]


def code_of(db_label: str) -> str:
    """Translate a database label into its short collection code.

    The label is matched by prefix against the keys of ``DB2CODE`` and the
    code of the first matching prefix is returned.  A label that matches no
    prefix is returned unchanged.
    """
    matching_codes = (
        short for prefix, short in DB2CODE.items() if db_label.startswith(prefix)
    )
    return next(matching_codes, db_label)


def _collect_genes(gene_column):
    """Return the set of gene symbols named anywhere in a ``genes`` column.

    Each entry is either a list of symbols or a ``;``-delimited string.
    Missing entries (``None`` / NaN) are skipped so that they are not counted
    as a gene literally called "None" or "nan".  Symbols are stripped of
    surrounding whitespace and empty symbols are dropped.
    """
    symbols = set()
    for entry in gene_column:
        if isinstance(entry, list):
            members = entry
        elif entry is None or (isinstance(entry, float) and np.isnan(entry)):
            continue
        else:
            members = str(entry).split(";")
        symbols.update(str(m).strip() for m in members if m)
    symbols.discard("")
    return symbols


def _jaccard(left, right):
    """Return the Jaccard index of two sets, or NaN when both are empty."""
    union = left | right
    return len(left & right) / len(union) if union else float("nan")


def main():
    """Compare the cloud FASTQ enrichment with the gene-list enrichment.

    Steps:
      1. Load the cloud run's enrichment JSON (``CLOUD_JSON``).
      2. Run ORA on the gene list in ``GENES_CSV`` via the bundled ``_ora``
         engine, using the same six collections.
      3. Print the overlap of the genes annotated in either result, the
         presence of the ``CANON`` genes, and the per-collection overlap of
         significant terms (q < 0.05).
      4. For GO:BP terms present in both results, print the Pearson/Spearman
         correlation of -log10(p), a keyword-based biology check, and write
         the scatter plot to ``OUT_PNG``.  Step 4 needs at least 3 shared
         terms and is skipped (with a message) otherwise.
    """
    # ---- cloud FASTQ enrichment --------------------------------------------
    # Context manager so the file handle is closed deterministically.
    with open(CLOUD_JSON, encoding="utf-8") as fh:
        cloud = json.load(fh)
    c_df = pd.DataFrame(cloud["results"])
    c_df["code"] = c_df["database"].map(code_of)
    c_sig = c_df[c_df["qvalue"] < 0.05]
    # Genes that appear in any cloud term: a proxy for the cloud DEG list.
    cloud_genes = _collect_genes(c_df["genes"])

    # ---- gene-list ORA on published airway DESeq2 genes --------------------
    genes = pd.read_csv(GENES_CSV).iloc[:, 0].dropna().astype(str).str.strip().tolist()
    genes = [g for g in genes if g]
    codes = ["go_bp", "go_mf", "go_cc", "wiki", "reactome", "hallmark"]
    p1 = _ora.run_ora_across_collections(
        genes, organism="human", short_codes=codes,
        min_overlap=MIN_OVERLAP, max_set_size=MAX_SET)
    # Without a 'database' column no collection code can be derived; the
    # per-collection comparisons below then simply find no gene-list terms.
    p1["code"] = p1["database"].map(code_of) if "database" in p1.columns else None
    p1_sig = p1[p1["qvalue"] < 0.05]
    p1_genes = _collect_genes(p1["genes"])

    print(f"gene-list ({len(genes)} published DEGs): "
          f"{len(p1_sig)} sig terms, {len(p1_genes)} annotated genes")
    print(f"cloud FASTQ ({cloud['summary']['total_genes']} DEGs): "
          f"{len(c_sig)} sig terms, {len(cloud_genes)} annotated genes\n")

    # ---- DEG overlap (two independent DESeq2 runs) -------------------------
    shared_genes = p1_genes & cloud_genes
    print("=== DEG (annotated) overlap: published-DESeq2 vs cloud-FASTQ-DESeq2 ===")
    print(f"  gene-list annotated genes : {len(p1_genes)}")
    print(f"  cloud FASTQ annotated genes : {len(cloud_genes)}")
    # _jaccard yields NaN instead of raising ZeroDivisionError on empty sets.
    print(f"  shared                  : {len(shared_genes)}  "
          f"(Jaccard {_jaccard(p1_genes, cloud_genes):.3f})")
    print("  canonical airway genes present:")
    for g in CANON:
        in_list = "P1" if g in p1_genes else "  "
        in_cloud = "P2" if g in cloud_genes else "  "
        mark = "OK" if g in shared_genes else ".."
        print(f"    {mark} {g:10s} [{in_list} {in_cloud}]")

    # ---- per-collection term overlap gene-list <-> cloud FASTQ -------------
    print("\n=== Significant-term overlap (q<0.05) gene-list vs cloud FASTQ ===")
    print(f"  {'collection':10s} {'GL':>5s} {'CL':>5s} {'shared':>7s} {'Jaccard':>8s}")
    for code in codes:
        gl_terms = set(p1_sig[p1_sig["code"] == code]["term"])
        cl_terms = set(c_sig[c_sig["code"] == code]["term"])
        if not gl_terms and not cl_terms:
            continue  # collection absent from both runs: nothing to report
        print(f"  {code:10s} {len(gl_terms):5d} {len(cl_terms):5d} "
              f"{len(gl_terms & cl_terms):7d} {_jaccard(gl_terms, cl_terms):8.3f}")

    # ---- GO:BP per-term p-value concordance gene-list vs cloud FASTQ -------
    gl_bp = p1[p1["code"] == "go_bp"][["term", "pvalue"]].rename(columns={"pvalue": "p1"})
    cl_bp = c_df[c_df["code"] == "go_bp"][["term", "pvalue"]].rename(columns={"pvalue": "p2"})
    j = gl_bp.merge(cl_bp, on="term")
    print(f"\n=== GO:BP shared terms gene-list vs cloud FASTQ: {len(j)} ===")
    if len(j) < 3:
        # Say so explicitly rather than silently producing no figure.
        print("  fewer than 3 shared terms; skipping correlation and figure")
        return

    # Clip before the log so p == 0 does not become infinity.
    lh = -np.log10(np.clip(j["p1"].to_numpy(float), 1e-300, 1))
    lr = -np.log10(np.clip(j["p2"].to_numpy(float), 1e-300, 1))
    pear = pearsonr(lr, lh)[0]
    spear = spearmanr(lr, lh)[0]
    print(f"  -log10(p) Pearson r = {pear:.4f}  Spearman rho = {spear:.4f}")

    # biology recovery in the cloud FASTQ run: keyword search over the
    # lower-cased names of the significant cloud GO:BP terms
    blob = " ".join(c_sig[c_sig["code"] == "go_bp"]["term"].str.lower())
    axes = {
        "hormone/steroid/glucocorticoid": ["hormone", "steroid", "glucocorticoid"],
        "vasculature/circulatory": ["vasculature", "vascular", "angiogenesis", "blood vessel"],
        "extracellular matrix": ["extracellular matrix", "extracellular structure"],
        "muscle/contraction": ["muscle", "contraction", "actin"],
    }
    print("  cloud FASTQ biology recovery (GO:BP sig):")
    for nm, kws in axes.items():
        hit = [k for k in kws if k in blob]
        print(f"    {'OK ' if hit else '-- '}{nm:30s} {hit}")

    # figure
    fig, ax = plt.subplots(figsize=(6.2, 6.0), dpi=130)
    # Fall back to 1.0 when every shared p-value is 1 (all -log10(p) == 0),
    # which would otherwise give identical lower and upper axis limits.
    lim = max(lh.max(), lr.max()) * 1.03 or 1.0
    ax.plot([0, lim], [0, lim], color="#94a3b8", lw=1, ls="--", zorder=1)
    ax.scatter(lh, lr, s=11, alpha=0.45, color="#0090c1", edgecolors="none", zorder=2)
    ax.set_xlabel("gene-list  −log₁₀(p)", fontsize=11)
    ax.set_ylabel("cloud FASTQ  −log₁₀(p)", fontsize=11)
    ax.set_title(f"GO:BP enrichment: raw-FASTQ cloud vs published gene list\n"
                 f"n = {len(j)} shared terms · Pearson r = {pear:.3f}", fontsize=11.5)
    ax.set_xlim(0, lim)
    ax.set_ylim(0, lim)
    ax.spines[["top", "right"]].set_visible(False)
    for s in ("left", "bottom"):
        ax.spines[s].set_color("#022f40")
    fig.tight_layout()
    # savefig does not create missing directories, so make sure it exists.
    out_dir = os.path.dirname(OUT_PNG)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(OUT_PNG, bbox_inches="tight", facecolor="white")
    plt.close(fig)  # release the figure once it is on disk
    print(f"\nWrote {OUT_PNG}")


# Run the concordance report only when this file is executed as a script.
if __name__ == "__main__":
    main()
