#!/usr/bin/env Rscript
# ─────────────────────────────────────────────────────────────────────────────
# Canonical R/Bioconductor DESeq2 reference for the DE positive-control report.
#
# Runs on Haritica's OWN count matrix (counts_matrix.csv produced by the app's
# FASTQ→minimap2→featureCounts path) so the comparison isolates the DE *engine*
# (R DESeq2 vs the app's pyDESeq2) — exactly mirroring the WGCNA report, which
# ran the original R WGCNA package on the identical expression CSV.
#
# Design:  ~condition,  reference level = control (DMSO)  ⇒ log2FC = CPT / DMSO
# Cutoffs: padj < 0.05  &  |log2FC| > 1   (= fold change > 2 in either direction, the published filter)
#
# Usage:  Rscript reference_deseq2.R <counts_matrix.csv> <outdir>
# Emits:  <outdir>/r_deseq2_results.csv        (gene-level stats for concordance)
#         <outdir>/ref_volcano.png ref_ma.png ref_dispersion.png ref_pca.png
#         <outdir>/ref_heatmap_samples.png ref_heatmap_genes.png
# ─────────────────────────────────────────────────────────────────────────────
suppressPackageStartupMessages({
  library(DESeq2); library(ggplot2); library(pheatmap)
})

# ---- command-line arguments ---------------------------------------------------
# Two positional arguments are required: the input count matrix and the output
# directory (created, including parents, if it does not exist yet). Fail early
# with a usage line instead of an opaque "subscript out of bounds".
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 2L) {
  stop("Usage: Rscript reference_deseq2.R <counts_matrix.csv> <outdir>",
       call. = FALSE)
}
counts_csv <- args[[1]]
outdir     <- args[[2]]
if (!file.exists(counts_csv)) {
  stop("Count matrix not found: ", counts_csv, call. = FALSE)
}
dir.create(outdir, showWarnings = FALSE, recursive = TRUE)

# ---- load count matrix (gene index + per-sample columns) --------------------
cnt <- read.csv(counts_csv, row.names = 1, check.names = FALSE)
# The app's counts_matrix.csv is: gene index + seqnames,pos metadata + the
# per-sample count columns (Control_1..N, Treatment_1..N). Keep ONLY the sample
# count columns. Non-numeric annotation (e.g. seqnames) is dropped first. A
# numeric column is then kept only if its name matches a condition pattern, so
# a numeric `pos` column is never mistaken for a sample.
# The selection pattern is the union of the control/treatment patterns used to
# assign conditions from column names. Every column kept here can therefore be
# classified by name. Short tokens (wt, mut) are anchored to the start of the
# name or to an underscore, so unrelated names that merely contain those
# letters are not picked up.
sample_pattern <- "control|dmso|wildtype|^wt|_wt|ctrl|treat|cpt|mutant|^mut|_mut"
num <- cnt[, vapply(cnt, is.numeric, logical(1)), drop = FALSE]
samp <- grepl(sample_pattern, tolower(colnames(num)))
if (!any(samp)) {  # fallback: drop known metadata, keep the rest
  samp <- !grepl("^pos$|^position$|seqnames|length|gene_name|chrom|start|end|strand|biotype|symbol",
                 tolower(colnames(num)))
}
if (!any(samp)) {
  stop("No sample count columns found in ", counts_csv,
       " (numeric columns: ", paste(colnames(num), collapse = ", "), ")",
       call. = FALSE)
}
counts <- as.matrix(round(num[, samp, drop = FALSE]))
mode(counts) <- "integer"
cat("Count columns used:", paste(colnames(counts), collapse = ", "), "\n")

# ---- assign condition from sample-column names ------------------------------
# Each column is classified by name. Control-like tokens give "control" and
# treatment-like tokens give "treatment". A name matching BOTH patterns (e.g.
# "wt_cpt") is ambiguous, so it is left unassigned rather than silently
# labelled control.
cn <- tolower(colnames(counts))
is_ctrl  <- grepl("control|dmso|wildtype|^wt|_wt|ctrl", cn)
is_treat <- grepl("treat|cpt|mutant|^mut|_mut", cn)
cond <- rep(NA_character_, length(cn))
cond[is_ctrl & !is_treat] <- "control"
cond[is_treat & !is_ctrl] <- "treatment"
if (anyNA(cond)) {
  # Fallback: first half = control, second half = treatment (pipeline order)
  n <- ncol(counts)
  cond <- c(rep("control", n %/% 2), rep("treatment", n - n %/% 2))
  message("WARN: condition inferred from column ORDER, not names: ",
          paste(colnames(counts), collapse = ", "))
}
coldata <- data.frame(condition = factor(cond, levels = c("control", "treatment")),
                      row.names = colnames(counts))
# The treatment-vs-control contrast needs at least one sample in each group.
# Stop here with a readable message instead of failing deep inside DESeq().
group_sizes <- table(coldata$condition)
if (any(group_sizes == 0L)) {
  stop("Both conditions need at least one sample; got ",
       paste(names(group_sizes), group_sizes, sep = "=", collapse = ", "),
       call. = FALSE)
}
cat("Samples:\n"); print(coldata)

# ---- DESeq2 (median-of-ratios + Wald + BH) ----------------------------------
# Build the dataset and drop genes with fewer than 10 reads summed over all
# samples (mirrors the app's min-count filter). Then fit size factors,
# dispersions and the Wald test in a single DESeq() call.
dds <- DESeqDataSetFromMatrix(counts, coldata, design = ~condition)
keep_gene <- rowSums(counts(dds)) >= 10
dds <- DESeq(dds[keep_gene, ])
# treatment vs control => log2FC = CPT / DMSO. alpha tunes independent
# filtering to the same 0.05 FDR cutoff used for the DEG call below.
res <- results(dds, contrast = c("condition", "treatment", "control"),
               alpha = 0.05)
out_cols <- c("gene", "baseMean", "log2FoldChange", "lfcSE",
              "stat", "pvalue", "padj")
resdf <- as.data.frame(res)
resdf[["gene"]] <- rownames(resdf)
resdf <- resdf[, out_cols]
write.csv(resdf, file.path(outdir, "r_deseq2_results.csv"), row.names = FALSE)

# DEG call: defined padj below 0.05 AND |log2FC| above 1 (over two-fold change).
lfc <- resdf$log2FoldChange
sig <- !is.na(resdf$padj) & resdf$padj < 0.05 & abs(lfc) > 1
n_up   <- sum(sig & lfc > 0, na.rm = TRUE)
n_down <- sum(sig & lfc < 0, na.rm = TRUE)
cat(sprintf("R DESeq2 DEGs (padj<0.05 & |log2FC|>1): %d total | %d up | %d down\n",
            n_up + n_down, n_up, n_down))

# ---- volcano ----------------------------------------------------------------
# Points are coloured by the same significance call (`sig`) used for the DEG
# counts. Genes with NA padj are not plotted.
vd <- resdf
vd$nlp <- -log10(pmax(vd$padj, .Machine$double.xmin))  # floor keeps padj == 0 finite
up   <- sig & vd$log2FoldChange > 0
down <- sig & vd$log2FoldChange < 0
vd$cls <- ifelse(up, "Up", ifelse(down, "Down", "NS"))

volcano <- ggplot(vd[!is.na(vd$padj), ], aes(log2FoldChange, nlp, color = cls)) +
  geom_point(alpha = .5, size = .8) +
  scale_color_manual(values = c(Up = "#c0392b", Down = "#2471a3", NS = "grey70")) +
  geom_vline(xintercept = c(-1, 1), linetype = 2, color = "grey40") +
  geom_hline(yintercept = -log10(0.05), linetype = 2, color = "grey40") +
  labs(title = "R DESeq2 — Volcano (CPT vs DMSO)", x = "log2 fold change",
       y = "-log10 adjusted p", color = NULL) +
  theme_bw()
png(file.path(outdir, "ref_volcano.png"), 1100, 900, res = 150)
print(volcano)
dev.off()

# ---- MA ----------------------------------------------------------------------
# Symmetric y-range: at least ±4, or the 99.9th percentile of |log2FC| if that
# is larger.
ma_lim <- max(4, quantile(abs(resdf$log2FoldChange), .999, na.rm = TRUE))
png(file.path(outdir, "ref_ma.png"), 1100, 900, res = 150)
plotMA(res, alpha = 0.05, main = "R DESeq2 — MA plot (CPT vs DMSO)",
       ylim = c(-ma_lim, ma_lim))
dev.off()

# ---- dispersion (the DESeq2 plotDispEsts panel the app's Dispersion tab mirrors)
# Draws DESeq2's per-gene dispersion estimates for the fitted `dds` object.
disp_png <- file.path(outdir, "ref_dispersion.png")
png(disp_png, 1100, 900, res = 150)
plotDispEsts(dds, main = "R DESeq2 — Dispersion estimates")
dev.off()

# ---- PCA + sample-distance heatmap ------------------------------------------
# vst() can error on small inputs (it fits its trend on a subset of `nsub`
# genes). In that case fall back to rlog(), log the reason, and record which
# transform was actually used so the plot titles are not mislabelled.
vsd <- tryCatch(vst(dds, blind = TRUE), error = function(e) {
  message("WARN: vst() failed (", conditionMessage(e), "); falling back to rlog()")
  NULL
})
tf_label <- "VST"
if (is.null(vsd)) {
  vsd <- rlog(dds, blind = TRUE)
  tf_label <- "rlog"
}
png(file.path(outdir, "ref_pca.png"), 1100, 900, res = 150)
print(plotPCA(vsd, intgroup = "condition") + theme_bw() +
        ggtitle(sprintf("R DESeq2 — PCA (%s)", tf_label)))
dev.off()

sampleDist <- dist(t(assay(vsd)))
dm <- as.matrix(sampleDist)
png(file.path(outdir, "ref_heatmap_samples.png"), 1000, 900, res = 150)
pheatmap(dm, clustering_distance_rows = sampleDist,
         clustering_distance_cols = sampleDist,
         main = "R DESeq2 — Sample-to-sample distance")
dev.off()

# ---- top-DEG expression heatmap (mirrors paper Fig 1A) ----------------------
# Rank only genes with a defined padj. An NA padj means the gene was not tested
# (e.g. removed by independent filtering), so such genes must never pad the
# list.
ranked <- resdf[!is.na(resdf$padj), , drop = FALSE]
topg <- head(ranked$gene[order(ranked$padj)], 50)
if (length(topg) >= 2L) {
  mat <- assay(vsd)[topg, , drop = FALSE]
  mat <- mat - rowMeans(mat)  # row-centre: show relative expression per gene
  png(file.path(outdir, "ref_heatmap_genes.png"), 900, 1100, res = 150)
  pheatmap(mat, annotation_col = coldata, show_rownames = FALSE,
           main = sprintf("R DESeq2 — Top-%d genes by padj (%s, row-centered)",
                          length(topg), tf_label))
  dev.off()
} else {
  # hclust (used by pheatmap's row clustering) needs at least two rows.
  message("WARN: fewer than 2 genes with a defined padj; ",
          "skipping ref_heatmap_genes.png")
}

# Final status line. sprintf avoids cat()'s default sep = " ", which inserted a
# doubled space before the path and a trailing space before the newline.
cat(sprintf("DONE: wrote r_deseq2_results.csv + ref_*.png to %s\n", outdir))
