#!/usr/bin/env python3
"""
Dump Haritica's OWN bundled gene-set collections (the exact data the app's ORA
runs against) into per-collection TERM2GENE TSVs, so the R reference
(reference_enrichment.R) tests the identical gene sets — apples-to-apples, the
same way the DE positive control ran R DESeq2 on Haritica's own count matrix.

Source of truth: server/data/permissive_genesets/{organism}-202605.json.gz,
loaded via analyses/_ora (the same loader the production ORA uses).

For each collection short code present in the bundle (e.g. go_bp, go_mf, go_cc,
wiki, reactome, hallmark), emits refdata/term2gene_{organism}_{code}.tsv with a
header row and the columns:
    term<TAB>gene
where `term` is the FULL bundle key string (e.g. "apoptotic process (GO:0006915)")
— identical to the dict key the app iterates over — so both engines key rows on
the exact same term identity and a join is unambiguous.

Also writes refdata/collections_{organism}.tsv (columns: code, identifier,
n_terms, n_genes_union) for the report's "datasets used" table; n_genes_union
counts distinct genes after trimming whitespace and upper-casing.

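Usage:  build_term2gene_from_bundle.py [organism] [out_dir]
        organism defaults to "human"; out_dir defaults to the refdata/
        directory beside this script's parent directory (<script dir>/../refdata).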
"""
import sys
import os
import pathlib

# Prepend the repository root to sys.path so that `from analyses import _ora`
# resolves no matter which directory the script is launched from. The root is
# taken to be five directory levels above this file (parents[4]); relocating
# the script within the repository requires adjusting that index.
REPO = pathlib.Path(__file__).resolve().parents[4]
sys.path[0:0] = [str(REPO)]

from analyses import _ora  # noqa: E402

# Positional CLI arguments: [organism] [out_dir]. Without an explicit output
# directory, files go to the `refdata` folder beside this script's parent dir.
if len(sys.argv) > 1:
    ORGANISM = sys.argv[1]
else:
    ORGANISM = "human"

if len(sys.argv) > 2:
    OUT_DIR = pathlib.Path(sys.argv[2]).resolve()
else:
    OUT_DIR = pathlib.Path(__file__).resolve().parent.parent / "refdata"

# Created eagerly so main() can open files in it without further checks.
OUT_DIR.mkdir(parents=True, exist_ok=True)


def _tsv_field(value: str, what: str) -> str:
    """Return *value* unchanged, or raise if it cannot be a single TSV field.

    A tab, newline, or carriage return inside a field would split one row
    across columns or lines, silently changing the gene sets that the R
    reference reads back.

    Raises:
        ValueError: if *value* contains a tab, newline, or carriage return.
    """
    if any(ch in value for ch in "\t\n\r"):
        raise ValueError(
            f"{what} {value!r} contains a tab or line break and cannot be "
            "written as a single TSV field"
        )
    return value


def _write_term2gene(
    out_path: pathlib.Path, gene_sets: dict[str, list[str]]
) -> tuple[int, int]:
    """Write one collection as a ``term<TAB>gene`` TSV with a header row.

    Terms (the full bundle keys) and genes are written verbatim; gene
    entries that are empty or whitespace-only are skipped.

    Returns:
        ``(n_genes_union, n_pairs)``: the number of distinct genes after
        trimming whitespace and upper-casing, and the number of term/gene
        rows written.

    Raises:
        ValueError: if a term or gene contains a tab or line break.
    """
    union: set[str] = set()
    n_pairs = 0
    # newline="\n" keeps the output byte-identical across platforms
    # (text mode would otherwise emit CRLF on Windows).
    with open(out_path, "w", encoding="utf-8", newline="\n") as f:
        f.write("term\tgene\n")
        for term, genes in gene_sets.items():
            _tsv_field(term, "term")
            for g in genes:
                key = g.strip().upper() if g else ""
                # Whitespace-only entries are noise: writing them would add a
                # blank gene row and count "" as a gene in the union.
                if not key:
                    continue
                f.write(f"{term}\t{_tsv_field(g, 'gene')}\n")
                union.add(key)
                n_pairs += 1
    return len(union), n_pairs


def _collection_identifier(code: str) -> str:
    """Return the database identifier for collection *code*, or *code* itself.

    The lookup only labels the summary table, so it is best-effort: any
    failure is reported on stderr and the short code is used instead.
    """
    try:
        return str(_ora.database_identifier(ORGANISM, code))
    except Exception as exc:  # best-effort label; never abort the export
        print(
            f"  warning: no database identifier for {code!r} ({exc}); "
            "using the short code",
            file=sys.stderr,
        )
        return code


def main() -> None:
    """Export every collection in the ORGANISM bundle into OUT_DIR.

    Writes one ``term2gene_{organism}_{code}.tsv`` per collection plus a
    ``collections_{organism}.tsv`` summary, printing progress to stdout.
    """
    bundle = _ora._load_organism(ORGANISM)  # {code: {term: [genes]}}
    codes = list(bundle.keys())
    print(f"Organism {ORGANISM!r}: collections = {codes}")

    coll_rows = []
    for code, gene_sets in bundle.items():
        identifier = _collection_identifier(code)
        out_path = OUT_DIR / f"term2gene_{ORGANISM}_{code}.tsv"
        n_union, n_pairs = _write_term2gene(out_path, gene_sets)
        coll_rows.append((code, identifier, len(gene_sets), n_union))
        print(
            f"  {code:9s} {identifier:24s} terms={len(gene_sets):6d} "
            f"genes(union)={n_union:6d} pairs={n_pairs:8d} -> {out_path.name}"
        )

    coll_tsv = OUT_DIR / f"collections_{ORGANISM}.tsv"
    with open(coll_tsv, "w", encoding="utf-8", newline="\n") as f:
        f.write("code\tidentifier\tn_terms\tn_genes_union\n")
        f.writelines(
            f"{_tsv_field(code, 'code')}\t{_tsv_field(identifier, 'identifier')}"
            f"\t{n_terms}\t{n_union}\n"
            for code, identifier, n_terms, n_union in coll_rows
        )
    print(f"\nWrote collection summary -> {coll_tsv}")
    print(f"All TERM2GENE TSVs in {OUT_DIR}")


if __name__ == "__main__":
    main()
