#!/usr/bin/env python3
"""
Concordance: Haritica pyDESeq2 (de_results.csv) vs canonical R DESeq2
(r_deseq2_results.csv) on the IDENTICAL count matrix, plus the published anchor
(GSE142440 wild-type arm: 671 up / 2,223 down at FC>=2 & FDR<=0.05).

Emits a JSON summary to stdout (also written to out_json when given) and a
scatter PNG of ours vs R log2FC.

Usage: concordance.py <ours_de_results.csv> <r_deseq2_results.csv> <out_png> [out_json]
"""
import sys, json
import numpy as np
import pandas as pd

PAD = 0.05      # padj cutoff (FDR <= 0.05)
LFC = 1.0       # |log2FC| cutoff (FC >= 2)
PUB_UP, PUB_DOWN = 671, 2223   # published WT-arm DESeq2 DEG counts

# Standard DESeq2 result columns; none of these can be the gene-ID column.
_DESEQ2_COLS = {"baseMean", "log2FoldChange", "lfcSE", "stat", "pvalue", "padj"}

if len(sys.argv) < 4:
    sys.exit("Usage: concordance.py <ours_de_results.csv> "
             "<r_deseq2_results.csv> <out_png> [out_json]")
ours_csv, r_csv, out_png = sys.argv[1], sys.argv[2], sys.argv[3]
out_json = sys.argv[4] if len(sys.argv) > 4 else None


def _load_results(path):
    """Read a DE results CSV and guarantee a ``gene`` column.

    The app writes the gene column as ``index_label='gene'``.  Other exports
    (e.g. R ``write.csv`` with row names) may instead put gene IDs in an
    unnamed first column, which pandas reads as ``Unnamed: 0``.  When no
    ``gene`` column exists, the first column is renamed to ``gene`` -- unless
    it is a known DESeq2 statistic, in which case the file has no gene IDs
    at all and we stop rather than silently mislabel data.

    Exits with a message if the gene column cannot be identified or the
    ``log2FoldChange`` / ``padj`` columns used downstream are missing.
    """
    df = pd.read_csv(path)
    if "gene" not in df.columns:
        first = df.columns[0]
        if first in _DESEQ2_COLS:
            sys.exit(f"{path}: no gene-ID column found (first column is {first!r})")
        df = df.rename(columns={first: "gene"})
    missing = {"log2FoldChange", "padj"} - set(df.columns)
    if missing:
        sys.exit(f"{path}: missing required column(s) {sorted(missing)}")
    return df


# Both tables get the same normalisation so the gene-keyed join is symmetric.
ours = _load_results(ours_csv)
r = _load_results(r_csv)

def sig_counts(df, lfc_col, padj_col, padj_max=None, lfc_min=None, gene_col="gene"):
    """Count significant up/down calls and collect the significant genes.

    A row is significant when ``padj <= padj_max`` and ``|log2FC| >= lfc_min``.
    The comparisons are inclusive, so the module defaults match the documented
    FDR<=0.05 & FC>=2 criteria.  Rows with a NaN padj or NaN log2FC never pass,
    because comparisons against NaN are False.

    Args:
        df: results table containing ``lfc_col``, ``padj_col`` and ``gene_col``.
        lfc_col: name of the log2 fold-change column.
        padj_col: name of the adjusted p-value column.
        padj_max: padj cutoff; defaults to the module-level ``PAD``.
        lfc_min: |log2FC| cutoff; defaults to the module-level ``LFC``.
        gene_col: column holding gene identifiers.

    Returns:
        ``(n_up, n_down, genes)``: the number of significant rows with positive
        and negative log2FC, and the set of significant gene identifiers.
    """
    if padj_max is None:
        padj_max = PAD
    if lfc_min is None:
        lfc_min = LFC
    hits = df[(df[padj_col] <= padj_max) & (df[lfc_col].abs() >= lfc_min)]
    up = int((hits[lfc_col] > 0).sum())
    dn = int((hits[lfc_col] < 0).sum())
    return up, dn, set(hits[gene_col])

def _pearson(a, b):
    """Pearson r of two aligned series; NaN when there are fewer than 3 pairs."""
    if len(a) <= 2:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


def _json_num(x, ndigits=4):
    """Round *x* for the JSON summary, mapping NaN (undefined statistic) to None.

    ``json.dumps`` would otherwise emit a bare ``NaN`` token, which is not
    valid JSON and breaks strict parsers reading the summary.
    """
    x = float(x)
    return None if np.isnan(x) else round(x, ndigits)


o_up, o_dn, o_set = sig_counts(ours, "log2FoldChange", "padj")
r_up, r_dn, r_set = sig_counts(r, "log2FoldChange", "padj")

# Inner join on gene: only genes present in both tables are correlated.
m = ours[["gene", "log2FoldChange", "padj"]].merge(
    r[["gene", "log2FoldChange", "padj"]], on="gene", suffixes=("_ours", "_r"))
m = m.dropna(subset=["log2FoldChange_ours", "log2FoldChange_r"])

lfc_r = _pearson(m["log2FoldChange_ours"], m["log2FoldChange_r"])
# rank corr too (robust to outliers)
lfc_rho = float(m["log2FoldChange_ours"].corr(m["log2FoldChange_r"], method="spearman"))

# padj agreement on the -log10 scale; the clip keeps -log10(0) finite.
mp = m.dropna(subset=["padj_ours", "padj_r"]).copy()
mp["nlp_ours"] = -np.log10(mp["padj_ours"].clip(lower=1e-300))
mp["nlp_r"] = -np.log10(mp["padj_r"].clip(lower=1e-300))
padj_r = _pearson(mp["nlp_ours"], mp["nlp_r"])

inter = o_set & r_set
union = o_set | r_set
jacc = len(inter) / len(union) if union else float("nan")

summary = {
    "n_genes_joined": int(len(m)),
    "ours": {"up": o_up, "down": o_dn, "total": o_up + o_dn},
    "r_deseq2": {"up": r_up, "down": r_dn, "total": r_up + r_dn},
    "published": {"up": PUB_UP, "down": PUB_DOWN, "total": PUB_UP + PUB_DOWN},
    "log2fc_pearson_r": _json_num(lfc_r),
    "log2fc_spearman_rho": _json_num(lfc_rho),
    "neglog10padj_pearson_r": _json_num(padj_r),
    "de_call_overlap_genes": len(inter),
    "de_call_jaccard": _json_num(jacc),
    "ours_only": len(o_set - r_set),
    "r_only": len(r_set - o_set),
    # fraction of DE calls that are down-regulated (max(..., 1) avoids 0/0)
    "down_skew_ours": round(o_dn / max(o_up + o_dn, 1), 3),
    "down_skew_r": round(r_dn / max(r_up + r_dn, 1), 3),
    "down_skew_published": round(PUB_DOWN / (PUB_UP + PUB_DOWN), 3),
}
print(json.dumps(summary, indent=2))
if out_json:
    with open(out_json, "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)

# ---- scatter: ours vs R log2FC ----------------------------------------------
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# "agree" = gene called DE by both pipelines; drawn red, everything else grey.
m = m.copy()
m["agree"] = m["gene"].isin(o_set) & m["gene"].isin(r_set)
fig, ax = plt.subplots(figsize=(6.4, 6.0), dpi=150)
# Grey layer first so the shared DE calls are drawn on top and stay visible.
for flag, colour, label in ((False, "#9aa7b1", "other genes"),
                            (True, "#c0392b", "DE in both")):
    sub = m[m["agree"] == flag]
    ax.scatter(sub["log2FoldChange_r"], sub["log2FoldChange_ours"],
               s=4, alpha=0.25, color=colour, label=label)

# Symmetric axis range at the 99.5th percentile of |log2FC| (clips outliers).
vals = np.abs(np.r_[m["log2FoldChange_r"], m["log2FoldChange_ours"]])
lim = float(np.nanpercentile(vals, 99.5)) if vals.size else float("nan")
if not np.isfinite(lim) or lim <= 0:
    lim = 1.0  # empty join or all-zero LFCs: set_xlim rejects NaN / zero span
ax.plot([-lim, lim], [-lim, lim], "k--", lw=1, alpha=.6)
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_xlabel("R DESeq2  log2 fold change")
ax.set_ylabel("Haritica pyDESeq2  log2 fold change")
ax.set_title(f"log2FC concordance  (Pearson r = {lfc_r:.3f}, n = {len(m):,})")
ax.legend(loc="upper left", markerscale=3, frameon=False)
fig.tight_layout()
fig.savefig(out_png)
plt.close(fig)
print(f"[wrote] {out_png}", file=sys.stderr)
