#!/usr/bin/env Rscript
# ---------------------------------------------------------------------------
# Reference over-representation analysis with R/Bioconductor clusterProfiler,
# run on HARITICA'S OWN bundled gene sets (the TERM2GENE TSVs dumped by
# build_term2gene_from_bundle.py). This is the apples-to-apples external
# reference for the Enrichment positive control — the same role R DESeq2 played
# for the DE positive control: a different language + a different, canonical
# implementation, fed the identical gene sets and gene list.
#
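# clusterProfiler::enricher() runs the hypergeometric test (== one-sided Fisher
# "greater", == Haritica's scipy.stats.fisher_exact) + Benjamini-Hochberg, so on
# identical sets/background the per-term p-values must match to numerical
# precision. We deliberately set pvalueCutoff=qvalueCutoff=1 to return ALL tested
# terms (no p/q-value filtering) so the concordance scatter spans the full
# distribution. Note that enricher() only tests terms that contain at least one
# query gene and fall within [minGSSize, maxGSSize], and applies BH across that
# tested set, so adjusted p-values agree only if Haritica adjusts over the same set.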
#
# Outputs:
#   reference_results.csv     all collections, columns: collection, term, GeneRatio,
#                             BgRatio, pvalue, p.adjust, qvalue, Count, geneID
#   figures/ref_dotplot.png ref_barplot.png ref_cnetplot.png ref_emapplot.png
#           ref_treeplot.png ref_upsetplot.png ref_heatplot.png   (from GO:BP)
#
# Usage:
#   Rscript reference_enrichment.R <query_csv> <refdata_dir> <out_dir> [de_results_csv]
# ---------------------------------------------------------------------------

suppressPackageStartupMessages({
  library(clusterProfiler)
  library(enrichplot)
  library(ggplot2)
})

# ---- command-line arguments: positional, each with a repo-relative fallback ----
args <- commandArgs(trailingOnly = TRUE)
arg_or_default <- function(position, default) {
  # `default` is a lazy promise, so it is only evaluated when actually needed.
  if (length(args) < position) default else args[position]
}
query_csv   <- arg_or_default(1, "tests/test_data/airway_significant_genes.csv")
refdata_dir <- arg_or_default(2, "docs/validation/enrichment-positive-control/refdata")
out_dir     <- arg_or_default(3, "docs/validation/enrichment-positive-control")
de_csv      <- arg_or_default(4, file.path(dirname(query_csv), "airway_de_results.csv"))

# All figures are written beneath <out_dir>/figures (created if absent).
fig_dir <- file.path(out_dir, "figures")
dir.create(fig_dir, recursive = TRUE, showWarnings = FALSE)

# Gene-set size bounds handed to enricher() as minGSSize / maxGSSize.
MIN_GS <- 3L
MAX_GS <- 500L

# Read a tab-separated file with a header row, every column as character.
# Uses data.table::fread when available (faster), otherwise base read.delim.
# Quoting is disabled in BOTH branches so a literal `"` in a field is kept
# verbatim and the parsed table does not depend on which backend is installed.
#
# @param path Path to the TSV file.
# @return A base data.frame with character columns.
read_tsv_fast <- function(path) {
  if (requireNamespace("data.table", quietly = TRUE)) {
    as.data.frame(data.table::fread(path, sep = "\t", header = TRUE, quote = "",
                                    colClasses = "character", showProgress = FALSE))
  } else {
    read.delim(path, sep = "\t", header = TRUE, colClasses = "character",
               quote = "", stringsAsFactors = FALSE)
  }
}

# ---- query genes (uppercase to match Haritica's case-insensitive matching) ----
# The first column of the query CSV is the gene list: symbols are trimmed,
# upper-cased and de-duplicated, with NA/blank entries dropped.
q <- read.csv(query_csv, stringsAsFactors = FALSE)
genes <- toupper(trimws(as.character(q[[1]])))
genes <- unique(genes[!is.na(genes) & nzchar(genes)])
cat(sprintf("Query genes: %d\n", length(genes)))
# With no genes every enricher() call below comes back empty and the script
# would only fail later with an unrelated-looking error; fail fast instead.
if (length(genes) == 0L) {
  stop(sprintf("No gene symbols found in the first column of %s", query_csv),
       call. = FALSE)
}

# ---- optional fold changes (for cnet/heat colouring) ----
# `fc` becomes a named numeric vector (log2FoldChange keyed by upper-cased gene
# symbol), or stays NULL when the DE table is absent, lacks the expected
# columns, or has no usable rows.
fc <- NULL
if (file.exists(de_csv)) {
  de <- read.csv(de_csv, stringsAsFactors = FALSE)
  if (all(c("gene_symbol", "log2FoldChange") %in% names(de))) {
    de$gene_symbol <- toupper(trimws(de$gene_symbol))
    # Drop rows without a usable symbol or fold change *before* de-duplicating,
    # so an NA row can neither become an NA-named entry nor shadow a later
    # valid row for the same symbol.
    usable <- !is.na(de$gene_symbol) & nzchar(de$gene_symbol) &
      !is.na(de$log2FoldChange)
    de <- de[usable, , drop = FALSE]
    de <- de[!duplicated(de$gene_symbol), , drop = FALSE]
    if (nrow(de) > 0L) {
      fc <- setNames(de$log2FoldChange, de$gene_symbol)
    }
  } else {
    cat(sprintf("Note: %s lacks gene_symbol/log2FoldChange; figures use no fold changes\n",
                de_csv))
  }
}

# ---- collections present in refdata ----
# One TERM2GENE TSV per collection, named term2gene_human_<code>.tsv; <code>
# (e.g. go_bp) labels that collection's rows in the output.
t2g_files <- list.files(refdata_dir, pattern = "^term2gene_human_.*\\.tsv$", full.names = TRUE)
# Without any gene-set files nothing can be tested; say so explicitly instead of
# crashing later when the (empty) results are combined.
if (length(t2g_files) == 0L) {
  stop(sprintf("No term2gene_human_*.tsv files found in %s", refdata_dir),
       call. = FALSE)
}
codes <- sub("^term2gene_human_(.*)\\.tsv$", "\\1", basename(t2g_files))
cat(sprintf("Collections: %s\n", paste(codes, collapse = ", ")))

all_res <- list()
go_bp_er <- NULL

# One enricher() run per collection against the shared query gene list. With
# pvalueCutoff = qvalueCutoff = 1 every returned term is kept; each result table
# is tagged with its collection code, and the GO:BP enrichResult object is kept
# for the figures further down.
for (i in seq_along(t2g_files)) {
  code <- codes[i]
  t2g <- read_tsv_fast(t2g_files[i])
  names(t2g)[1:2] <- c("term", "gene")
  t2g$gene <- toupper(trimws(t2g$gene))
  # nzchar(NA) is TRUE, so NA cells (e.g. blank fields that fread may read as
  # NA) must be excluded explicitly or they would reach enricher() as genes.
  usable <- !is.na(t2g$term) & !is.na(t2g$gene) & nzchar(t2g$gene)
  t2g <- t2g[usable, c("term", "gene")]

  cat(sprintf("  enricher(%s): %d terms, %d pairs ... ", code,
              length(unique(t2g$term)), nrow(t2g)))
  er <- tryCatch(
    enricher(gene = genes, TERM2GENE = t2g,
             pvalueCutoff = 1, qvalueCutoff = 1, pAdjustMethod = "BH",
             minGSSize = MIN_GS, maxGSSize = MAX_GS),
    # A failing collection is reported and skipped rather than aborting the run.
    error = function(e) { cat(sprintf("ERROR %s\n", conditionMessage(e))); NULL }
  )
  # enricher() can return NULL (e.g. when no query gene maps to any term).
  df <- if (is.null(er)) NULL else as.data.frame(er)
  if (is.null(df) || nrow(df) == 0) { cat("0 terms\n"); next }
  df$collection <- code
  all_res[[code]] <- df
  cat(sprintf("%d result terms\n", nrow(df)))
  if (code == "go_bp") go_bp_er <- er
}

# Combine the per-collection tables into one long table for concordance.py.
# If every collection failed or returned nothing, do.call(rbind, ...) yields
# NULL and the column subset below would error obscurely, so stop clearly.
if (length(all_res) == 0L) {
  stop("enricher() produced no terms for any collection; nothing to write",
       call. = FALSE)
}
ref <- do.call(rbind, all_res)
rownames(ref) <- NULL
# Canonical column subset + rename to match concordance.py expectations.
keep <- c("collection", "ID", "GeneRatio", "BgRatio", "pvalue", "p.adjust",
          "qvalue", "Count", "geneID")
keep <- keep[keep %in% names(ref)]
# drop = FALSE keeps a data.frame even if only one column survives.
ref_out <- ref[, keep, drop = FALSE]
names(ref_out)[names(ref_out) == "ID"] <- "term"
out_csv <- file.path(out_dir, "reference_results.csv")
write.csv(ref_out, out_csv, row.names = FALSE)
cat(sprintf("Wrote %s (%d rows across %d collections)\n", out_csv, nrow(ref_out),
            length(unique(ref_out$collection))))

# ---------------------------------------------------------------------------
# enrichplot figures from GO:BP (the canonical collection; mirrors the NCI-BTEP
# clusterProfiler tutorial on this exact airway dataset). showCategory=20 to
# match the app's topN default.
# ---------------------------------------------------------------------------
# Save plot `p` as a white-background PNG at <fig_dir>/<name>.
#
# @param p    Plot object handed to ggsave().
# @param name Output file name inside fig_dir.
# @param w,h  Target size in pixels; rendered at 110 dpi, so the size passed to
#             ggsave() is pixels / 110.
# Logs the written path to stdout.
save_png <- function(p, name, w = 1100, h = 900) {
  res <- 110
  target <- file.path(fig_dir, name)
  ggsave(target,
         plot = p,
         width = w / res,
         height = h / res,
         dpi = res,
         bg = "white",
         limitsize = FALSE)
  cat(sprintf("  figure -> %s\n", target))
}

# Render the enrichplot figure set from the GO:BP result (only when GO:BP
# produced results). Each figure is wrapped in tryCatch so one failing plot is
# reported and does not stop the others.
if (!is.null(go_bp_er)) {
  er <- go_bp_er
  # Restrict to significant terms for the figures (cleaner; q<0.05). which()
  # drops rows whose qvalue is NA (q-values may be NA, e.g. if q-value
  # estimation failed); plain logical indexing would turn those rows into
  # all-NA rows that then show up in every plot.
  sig <- er
  sig@result <- er@result[which(er@result$qvalue < 0.05), , drop = FALSE]
  n_sig <- nrow(sig@result)
  cat(sprintf("GO:BP significant terms (q<0.05): %d\n", n_sig))

  tryCatch(save_png(dotplot(sig, showCategory = 20) +
                      ggtitle("clusterProfiler GO:BP (dotplot)"),
                    "ref_dotplot.png"),
           error = function(e) cat("dotplot failed:", conditionMessage(e), "\n"))

  tryCatch(save_png(barplot(sig, showCategory = 20) +
                      ggtitle("clusterProfiler GO:BP (barplot)"),
                    "ref_barplot.png"),
           error = function(e) cat("barplot failed:", conditionMessage(e), "\n"))

  # Try the color.params form of cnetplot first; if that call errors, retry
  # with the foldChange= argument form.
  tryCatch(save_png(cnetplot(sig, showCategory = 8, color.params = list(foldChange = fc)),
                    "ref_cnetplot.png", 1200, 1000),
           error = function(e) {
             tryCatch(save_png(cnetplot(sig, showCategory = 8, foldChange = fc),
                               "ref_cnetplot.png", 1200, 1000),
                      error = function(e2) cat("cnetplot failed:", conditionMessage(e2), "\n"))
           })

  # emap + tree need a term-similarity matrix; cap to top 40 for speed/legibility.
  sig_top <- sig
  ord <- order(sig@result$pvalue)
  sig_top@result <- sig@result[head(ord, 40), , drop = FALSE]
  ts <- tryCatch(pairwise_termsim(sig_top), error = function(e) NULL)
  if (!is.null(ts)) {
    tryCatch(save_png(emapplot(ts, showCategory = 30) +
                        ggtitle("clusterProfiler GO:BP (enrichment map)"),
                      "ref_emapplot.png", 1100, 1000),
             error = function(e) cat("emapplot failed:", conditionMessage(e), "\n"))
    tryCatch(save_png(treeplot(ts, showCategory = 30),
                      "ref_treeplot.png", 1300, 900),
             error = function(e) cat("treeplot failed:", conditionMessage(e), "\n"))
  }

  tryCatch(save_png(upsetplot(sig, n = 10), "ref_upsetplot.png", 1300, 800),
           error = function(e) cat("upsetplot failed:", conditionMessage(e), "\n"))

  tryCatch(save_png(heatplot(sig, showCategory = 15, foldChange = fc) +
                      theme(axis.text.x = element_text(size = 5)),
                    "ref_heatplot.png", 1500, 700),
           error = function(e) cat("heatplot failed:", conditionMessage(e), "\n"))
}

cat("DONE reference_enrichment.R\n")
