# FBA_ldopa.py
# Complete L-DOPA FBA with Phase-1 (MazF OFF) and Phase-2 (MazF ON) + analyses.

import warnings
from pathlib import Path

import cobra
from cobra.io import read_sbml_model

# FVA import (Windows safe design below will set processes=1)
try:
    from cobra.flux_analysis import flux_variability_analysis as FVA
except Exception:
    from cobra.flux_analysis.variability import flux_variability_analysis as FVA


# --------------------------- User-configurable paths ---------------------------
# Location of the SBML model file. It is a relative path, so it is resolved
# against the current working directory; change it if your model file is elsewhere.
SBML_PATH = Path("model") / "iJN1463.xml"


# --------------------------- Helper functions ---------------------------------
def find_biomass_rxn(m: cobra.Model) -> cobra.Reaction:
    """
    Find the biomass reaction by id/name heuristics.

    Every reaction whose id or name contains "biomass" (case-insensitive) is a
    candidate. A model may contain several such reactions (e.g. alternative
    biomass compositions), so a candidate that currently has a non-zero
    objective coefficient is preferred. Otherwise the first candidate in model
    order is returned.

    Raises
    ------
    RuntimeError
        If no reaction matches.
    """
    candidates = [
        r for r in m.reactions
        if "biomass" in r.id.lower() or "biomass" in (r.name or "").lower()
    ]
    if not candidates:
        raise RuntimeError("Biomass reaction not found; check your model.")
    # First-match alone could return a biomass variant that is not the model's
    # configured objective; prefer the one the objective actually uses.
    for r in candidates:
        if r.objective_coefficient != 0:
            return r
    return candidates[0]


def find_atpm(m: cobra.Model) -> cobra.Reaction | None:
    """
    Find the non-growth ATP maintenance reaction.

    Strategy:
      1. Return the first of the common ids ATPM, ATPM_c, NGAM, ATPhyd that exists.
      2. Otherwise look for a reaction whose stoichiometry is exactly the five
         species atp_c, h2o_c, adp_c, pi_c, h_c, consuming ATP and H2O and
         producing ADP, Pi and H+, with lower bound >= 0.

    Returns None when neither strategy finds a reaction, or when any of the
    five cytosolic metabolites is missing from the model.
    """
    for rid in ("ATPM", "ATPM_c", "NGAM", "ATPhyd"):
        try:
            return m.reactions.get_by_id(rid)
        except KeyError:
            pass

    # Stoichiometric detection
    try:
        atp = m.metabolites.get_by_id("atp_c")
        adp = m.metabolites.get_by_id("adp_c")
        pi_ = m.metabolites.get_by_id("pi_c")
        h2o = m.metabolites.get_by_id("h2o_c")
        h   = m.metabolites.get_by_id("h_c")
    except KeyError:
        return None

    core = (atp, adp, pi_, h2o, h)
    for r in m.reactions:
        mets = r.metabolites
        # Require *exactly* the hydrolysis species. ATP-driven transporters and
        # ligases share this core but carry extra substrates, and must not be
        # mistaken for maintenance (their LB would later be forced up).
        if len(mets) != len(core) or not all(x in mets for x in core):
            continue
        if (mets[atp] < 0 and mets[h2o] < 0 and mets[adp] > 0
                and mets[pi_] > 0 and mets[h] > 0 and r.lower_bound >= 0):
            return r
    return None


def get_first_existing_met(m: cobra.Model, ids: list[str]) -> cobra.Metabolite:
    """
    Resolve a metabolite from a priority-ordered list of candidate IDs.

    Returns the metabolite for the earliest entry of `ids` present in `m`.

    Raises
    ------
    KeyError
        If none of the candidate IDs is present; the message lists them all.
    """
    present = (mid for mid in ids if m.metabolites.has_id(mid))
    chosen = next(present, None)
    if chosen is None:
        raise KeyError(f"None of these metabolite IDs exist in the model: {ids}")
    return m.metabolites.get_by_id(chosen)


def ensure_exchange(m: cobra.Model, met_e: cobra.Metabolite, ex_id: str,
                    lb: float, ub: float) -> cobra.Reaction:
    """
    Add or update the exchange reaction `ex_id` for `met_e` with bounds (lb, ub).

    An existing reaction with that id is rewritten so that its stoichiometry is
    exactly {met_e: -1}; any other metabolites it carried are removed. A new
    reaction is named "<met_e.name> exchange" and added to `m`.

    Returns the exchange reaction.

    Raises
    ------
    ValueError
        Raised by cobra when lb > ub.
    """
    try:
        ex = m.reactions.get_by_id(ex_id)
    except KeyError:
        ex = cobra.Reaction(ex_id)
        ex.name = f"{met_e.name} exchange"
        ex.bounds = (lb, ub)
        ex.add_metabolites({met_e: -1.0})
        m.add_reactions([ex])
        return ex

    # add_metabolites(..., combine=False) only overwrites the coefficients it is
    # given and leaves every other metabolite in place. To truly enforce
    # "-1 on the extracellular metabolite only", strip the old stoichiometry.
    if ex.metabolites != {met_e: -1.0}:
        ex.subtract_metabolites(ex.metabolites)
        ex.add_metabolites({met_e: -1.0})
    # Set both bounds at once so an intermediate (lb > old ub) state can't raise.
    ex.bounds = (lb, ub)
    return ex


def add_tpl_pathway(m: cobra.Model) -> None:
    """
    Add L-DOPA synthesis and secretion to `m`. Existing ids are reused, not duplicated.

    Adds:
      - ldopa_c (created if absent)
      - TPL: catechol_c + pyr_c + nh4_c -> ldopa_c + h2o_c, irreversible
        (0..1000), gene rule "TPL_gene"
      - ldopa_e (created if absent)
      - LDOPAtex_PMF: proton-antiport exporter ldopa_c + h_e -> ldopa_e + h_c
        (0..1000). The proton moves into the cytosol, so each exported
        L-DOPA consumes proton-motive force.
      - EX_ldopa_e: secretion-only exchange (0..1000)
    Any pre-existing free transporter LDOPA_tx is blocked (bounds 0, 0).

    Raises
    ------
    KeyError
        If catechol_c, pyr_c, nh4_c or h2o_c is missing, or if none of the
        cytosolic/extracellular proton candidates exists.
    """
    # 1) L-DOPA (cytosolic)
    try:
        ldopa_c = m.metabolites.get_by_id("ldopa_c")
    except KeyError:
        ldopa_c = cobra.Metabolite("ldopa_c", formula="C9H11NO4", name="L-DOPA", compartment="c")
        m.add_metabolites([ldopa_c])

    # 2) Substrates (in cytosol) expected in iJN1463
    catechol_c = m.metabolites.get_by_id("catechol_c")
    pyruvate_c = m.metabolites.get_by_id("pyr_c")
    ammonium_c = m.metabolites.get_by_id("nh4_c")
    water_c    = m.metabolites.get_by_id("h2o_c")

    # 3) TPL reaction
    if not m.reactions.has_id("TPL"):
        tpl = cobra.Reaction("TPL")
        tpl.name = "Tyrosine phenol-lyase (L-DOPA synthesis)"
        tpl.lower_bound, tpl.upper_bound = 0.0, 1000.0
        tpl.add_metabolites({
            catechol_c: -1.0,
            pyruvate_c: -1.0,
            ammonium_c: -1.0,
            ldopa_c:    1.0,
            water_c:    1.0
        })
        tpl.gene_reaction_rule = "TPL_gene"  # dummy gene for knockouts if needed
        m.add_reactions([tpl])

    # 4) L-DOPA extracellular + exchange + PMF-coupled exporter
    try:
        ldopa_e = m.metabolites.get_by_id("ldopa_e")
    except KeyError:
        ldopa_e = cobra.Metabolite("ldopa_e", formula="C9H11NO4", name="L-DOPA (e)", compartment="e")
        m.add_metabolites([ldopa_e])

    # H+ in cytosol/extracellular (ids vary by model)
    h_c = get_first_existing_met(m, ["h_c", "h_p", "h_i"])
    h_e = get_first_existing_met(m, ["h_e", "h_p_e", "h_ext"])

    # PMF-driven antiporter (ldopa_c + h_e -> ldopa_e + h_c). The earlier
    # orientation (ldopa_c + h_c -> ldopa_e + h_e) extruded a cytosolic proton
    # with every L-DOPA. That is a free proton pump which *generates* PMF (and
    # hence ATP via ATP synthase) instead of charging an export cost.
    if not m.reactions.has_id("LDOPAtex_PMF"):
        eff = cobra.Reaction("LDOPAtex_PMF")
        eff.name = "L-DOPA export (proton-coupled)"
        eff.lower_bound, eff.upper_bound = 0.0, 1000.0
        eff.add_metabolites({ldopa_c: -1.0, h_e: -1.0, ldopa_e: 1.0, h_c: 1.0})
        m.add_reactions([eff])

    # Disable any old free transporter
    if m.reactions.has_id("LDOPA_tx"):
        m.reactions.get_by_id("LDOPA_tx").bounds = (0.0, 0.0)

    # Exchange: secretion only (no re-uptake)
    ensure_exchange(m, ldopa_e, "EX_ldopa_e", lb=0.0, ub=1000.0)


def set_medium(m: cobra.Model,
               catechol_lb=-0.821, pyruvate_lb=-4.0, nh4_lb=-10.0,
               allow_glucose=False, glucose_lb=-10.0) -> None:
    """
    Set uptake bounds for catechol, pyruvate and NH4+, and gate glucose uptake.

    Parameters
    ----------
    catechol_lb, pyruvate_lb, nh4_lb :
        Lower bounds (negative = uptake) for EX_catechol_e, EX_pyr_e and
        EX_nh4_e. Each exchange's upper bound is set to 0 (no secretion).
        catechol_e and EX_catechol_e are created if missing.
    allow_glucose :
        True: EX_glc__D_e bounds become (glucose_lb, 0).
        False: its lower bound becomes 0, so no glucose uptake. This step is
        skipped when the model has no EX_glc__D_e.
    glucose_lb :
        Glucose uptake bound used only when allow_glucose is True.

    Raises
    ------
    KeyError
        If EX_pyr_e or EX_nh4_e is missing, or if allow_glucose is True and
        EX_glc__D_e is missing.
    """
    # Catechol external + exchange
    try:
        catechol_e = m.metabolites.get_by_id("catechol_e")
    except KeyError:
        catechol_e = cobra.Metabolite("catechol_e", formula="C6H6O2", name="Catechol (e)", compartment="e")
        m.add_metabolites([catechol_e])
    ensure_exchange(m, catechol_e, "EX_catechol_e", lb=catechol_lb, ub=0.0)

    # Pyruvate & NH4 exchanges exist in iJN1463
    m.reactions.get_by_id("EX_pyr_e").bounds = (pyruvate_lb, 0.0)
    m.reactions.get_by_id("EX_nh4_e").bounds = (nh4_lb, 0.0)

    # Glucose gate. Previously this was commented out, so allow_glucose and
    # glucose_lb were silently ignored and the model's default glucose uptake
    # leaked into the "glucose OFF" scenario.
    if allow_glucose:
        m.reactions.get_by_id("EX_glc__D_e").bounds = (glucose_lb, 0.0)
    else:
        try:
            ex_glc = m.reactions.get_by_id("EX_glc__D_e")
        except KeyError:
            pass  # no glucose exchange in this model: nothing to switch off
        else:
            # Close uptake only; keep existing secretion capacity (UB >= 0).
            ex_glc.bounds = (0.0, max(0.0, ex_glc.upper_bound))


def increase_atpm_for_burden(m: cobra.Model,
                             tpl_burden=2.0,
                             mazf_off_burden=0.20) -> float:
    """
    Raise the ATP-maintenance lower bound to model expression burden.

    The new lower bound is max(0, current LB) + tpl_burden + mazf_off_burden.
    These cover pathway (TPL) expression and the baseline MazF-OFF circuit.
    The upper bound is lifted to match when it would otherwise sit below the
    new lower bound.

    Returns the resulting ATPM lower bound.

    Raises
    ------
    RuntimeError
        If no ATPM-like reaction is found.
    """
    maintenance = find_atpm(m)
    if maintenance is None:
        raise RuntimeError("ATPM-like maintenance reaction not found; cannot set burden.")
    floor = max(0.0, float(maintenance.lower_bound)) + tpl_burden + mazf_off_burden
    # Ceiling never drops; it only grows if the new floor would exceed it.
    maintenance.bounds = (floor, max(maintenance.upper_bound, floor))
    return maintenance.lower_bound


def pareto_tradeoff(m: cobra.Model, biomass_rxn: cobra.Reaction,
                    ex_ldopa_id="EX_ldopa_e",
                    fractions: list[float] | None = None) -> list[tuple[float, float]]:
    """
    Build a growth/L-DOPA trade-off curve.

    Growth is first maximized to obtain mu_max. Then, for each fraction f, the
    biomass flux is fixed to f * mu_max and L-DOPA secretion (`ex_ldopa_id`)
    is maximized.

    Parameters
    ----------
    fractions :
        Fractions of mu_max to scan. Defaults to 1.0, 0.9, ..., 0.1, 0.0.

    Returns
    -------
    list of (fraction_of_mu_max, ldopa_flux). ldopa_flux is 0.0 when the
    constrained solve is not optimal.

    Notes
    -----
    On return (including when an exception is raised), the biomass bounds are
    restored and the objective is left set to `biomass_rxn`.

    Raises
    ------
    RuntimeError
        If the initial growth maximization is not optimal.
    KeyError
        If `ex_ldopa_id` is not in the model.
    """
    if fractions is None:
        fractions = [1.0, 0.90, 0.80, 0.70, 0.60, 0.50, 0.40, 0.30, 0.20, 0.10, 0.0]

    # Maximize growth
    m.objective = biomass_rxn
    sol = m.optimize()
    if sol.status != "optimal":
        # Otherwise mu_max would be NaN/None and every target meaningless.
        raise RuntimeError(
            f"Growth maximization is not optimal (status: {sol.status}); "
            "cannot build the trade-off curve."
        )
    mu_max = float(sol.objective_value)

    ldopa_ex = m.reactions.get_by_id(ex_ldopa_id)
    saved_bounds = biomass_rxn.bounds
    results = []
    try:
        m.objective = ldopa_ex  # loop-invariant: only biomass bounds change per step
        for frac in fractions:
            target = frac * mu_max
            biomass_rxn.bounds = (target, target)
            s = m.optimize()
            val = float(s.objective_value) if s.status == "optimal" else 0.0
            results.append((frac, val))
    finally:
        # Restore the caller's model even if a solve raises mid-scan.
        biomass_rxn.bounds = saved_bounds
        m.objective = biomass_rxn
    return results


def knockout_catechol_consumers(m: cobra.Model,
                                exclude_ids=("TPL", "EX_catechol_e")) -> list[str]:
    """
    Block every reaction that uses catechol_c as a substrate.

    A reaction counts as a consumer when its coefficient for catechol_c is
    negative. Reversibility is not considered. Reactions whose id is in
    `exclude_ids` are left untouched. Consumers are disabled with
    `Reaction.knock_out()`.

    Returns the ids of the reactions that were knocked out.
    """
    catechol = m.metabolites.get_by_id("catechol_c")
    consumers = [
        reaction
        for reaction in catechol.reactions
        if reaction.metabolites.get(catechol, 0.0) < 0
        and reaction.id not in exclude_ids
    ]
    for reaction in consumers:
        reaction.knock_out()
    return [reaction.id for reaction in consumers]


def run_phase2_with_residual_growth(
    model: cobra.Model,
    biomass_rxn: cobra.Reaction,
    mu_star: float,
    translation_residual=0.10,   # 10% translation left; use 0.25 for “4× decrease”
    atpm_scale=1.50,             # +50% maintenance vs Phase‑1
    residual_mu_frac=0.05,       # Phase‑2 growth = 5% of μ*
    fix_growth=True,             # fix μ exactly or cap μ
    tpl_rxn_id="TPL",
    ldopa_ex_id="EX_ldopa_e"
):
    """
    Build a Phase-2 (MazF ON) model that still grows a bit, and compare it to Phase 1.

    Phase-2 constraints are applied to a copy of `model`; `model` itself is not modified:
      - Biomass flux is fixed to residual_mu_frac * mu_star (fix_growth=True) or
        capped at that value (fix_growth=False). If the fixed value is not
        optimal, it is relaxed to a cap and metrics["params"]["growth_relaxed"]
        is True.
      - The ATPM lower bound is multiplied by atpm_scale (e.g. 1.50 == +50 %).
        This only raises maintenance when the current lower bound is positive.
      - The TPL upper bound is capped to translation_residual times the TPL flux
        obtained by maximizing L-DOPA secretion on an unmodified copy. If that
        solve is not optimal or gives < 1e-9, TPL's own upper bound is used as
        the reference instead.
      - The objective is L-DOPA secretion (`ldopa_ex_id`).
    Phase 1 is evaluated as growth maximization on another copy of `model`.

    Returns
    -------
    (phase2_model, metrics), where metrics contains:
      "params": the settings actually applied, including "tpl_cap",
          "phase2_ATPM_lb" and "growth_relaxed".
      "phase1" / "phase2": {"growth", "ldopa", "tpl", "status"}. Flux values
          are NaN when the corresponding solve is not optimal.

    Raises
    ------
    RuntimeError
        If no ATPM-like reaction is found in the Phase-2 model.
    KeyError
        If the biomass, TPL or L-DOPA exchange id is not in the model.
    """
    nan = float("nan")

    def _flux(sol, rxn_id):
        # A non-optimal solve carries no meaningful fluxes; report NaN rather
        # than whatever the solver left behind.
        if sol.status != "optimal":
            return nan
        return float(sol.fluxes.get(rxn_id, 0.0))

    # Pre-MazF achievable TPL under the current medium
    pre = model.copy()
    pre.objective = pre.reactions.get_by_id(ldopa_ex_id)
    pre_sol = pre.optimize()
    pre_tpl_flux = _flux(pre_sol, tpl_rxn_id)
    # `not (x >= tol)` is also True for NaN, so a failed solve falls back to the
    # reaction's upper bound exactly as a zero flux does. (Previously NaN slipped
    # through and max(0.0, nan) silently produced a cap of 0.)
    if not (pre_tpl_flux >= 1e-9):
        pre_tpl_flux = float(pre.reactions.get_by_id(tpl_rxn_id).upper_bound)

    # Phase-2 copy
    m2 = model.copy()

    # Residual growth
    residual_mu = residual_mu_frac * mu_star
    b2 = m2.reactions.get_by_id(biomass_rxn.id)
    if fix_growth:
        b2.bounds = (residual_mu, residual_mu)
    else:
        b2.bounds = (0.0, residual_mu)

    # Maintenance ATP: scale LB by relative factor
    atpm2 = find_atpm(m2)
    if atpm2 is None:
        raise RuntimeError("ATPM-like maintenance reaction not found in Phase‑2 model.")
    new_lb = float(atpm2.lower_bound) * atpm_scale
    if atpm2.upper_bound < new_lb:
        atpm2.upper_bound = new_lb
    atpm2.lower_bound = new_lb

    # Translation suppression: cap TPL to a fraction of the achievable pre-MazF flux
    tpl2 = m2.reactions.get_by_id(tpl_rxn_id)
    tpl_cap = max(0.0, translation_residual * pre_tpl_flux)
    tpl2.upper_bound = min(tpl2.upper_bound, tpl_cap)

    # Phase-1 baseline (growth) for comparison
    m1 = model.copy()
    m1.objective = m1.reactions.get_by_id(biomass_rxn.id)
    sol1 = m1.optimize()

    # Phase-2 objective: maximize L-DOPA
    m2.objective = m2.reactions.get_by_id(ldopa_ex_id)
    sol2 = m2.optimize()
    growth_relaxed = False
    if fix_growth and sol2.status != "optimal":
        # Relax to an upper cap if the equality is infeasible
        b2.bounds = (0.0, residual_mu)
        sol2 = m2.optimize()
        growth_relaxed = True

    metrics = {
        "params": {
            "translation_residual": translation_residual,
            "atpm_scale": atpm_scale,
            "pre_tpl_flux": pre_tpl_flux,
            "tpl_cap": tpl2.upper_bound,
            "phase2_ATPM_lb": atpm2.lower_bound,
            "residual_mu_frac": residual_mu_frac,
            "residual_mu": residual_mu,
            "fix_growth": fix_growth,
            "growth_relaxed": growth_relaxed
        },
        "phase1": {
            "growth": float(sol1.objective_value) if sol1.status == "optimal" else nan,
            "ldopa": _flux(sol1, ldopa_ex_id),
            "tpl":   _flux(sol1, tpl_rxn_id),
            "status": sol1.status
        },
        "phase2": {
            "growth": _flux(sol2, biomass_rxn.id),
            "ldopa": _flux(sol2, ldopa_ex_id),
            "tpl":   _flux(sol2, tpl_rxn_id),
            "status": sol2.status
        }
    }
    return m2, metrics


# --------------------------- Main pipeline -------------------------------------
def main():
    """
    Run the full L-DOPA FBA pipeline on the SBML model at SBML_PATH.

    Steps:
      1. Load the model.
      2. Add the TPL pathway and L-DOPA export.
      3. Set the medium.
      4. Add the Phase-1 ATP-maintenance burden.
      5. Maximize growth to obtain μ*.
      6. Scan the biomass/L-DOPA trade-off.
      7. Run FVA at μ*.
      8. Analyse catechol-consumer knockouts.
      9. Compare Phase-1 (MazF OFF) with Phase-2 (MazF ON).

    When matplotlib is installed, plots are written to pareto_ldopa.png and
    phase1_vs_phase2_with_growth.png.

    Raises
    ------
    FileNotFoundError
        If SBML_PATH does not exist.
    """
    # NOTE(review): this silences *all* UserWarnings; cobra may also report
    # non-optimal solver statuses this way — confirm that is acceptable.
    warnings.filterwarnings("ignore", category=UserWarning)

    # 1) Load model
    if not SBML_PATH.exists():
        raise FileNotFoundError(f"Cannot find SBML at {SBML_PATH.resolve()}")
    model = read_sbml_model(str(SBML_PATH))
    print(f"Loaded model '{model.name}' with {len(model.reactions)} reactions, "
          f"{len(model.metabolites)} metabolites, and {len(model.genes)} genes.")

    # 2) Add L-DOPA pathway (+ exporter + exchange)
    add_tpl_pathway(model)

    # 3) Medium: catechol, pyruvate, NH4 uptake ON; glucose requested OFF
    set_medium(model, catechol_lb=-0.821, pyruvate_lb=-4.0, nh4_lb=-10.0, allow_glucose=False)

    # 4) Increase ATP maintenance for burdens present in Phase-1 (MazF OFF)
    atpm_lb = increase_atpm_for_burden(model, tpl_burden=2.0, mazf_off_burden=0.20)
    print(f"ATPM (Phase‑1) lower bound set to: {atpm_lb:.2f} mmol ATP/gDW/h")

    # 5) Find biomass reaction
    biomass_rxn = find_biomass_rxn(model)

    # 6) Phase-1 (MazF OFF): maximize growth; show μ*
    model.objective = biomass_rxn
    sol_phase1 = model.optimize()
    mu_star = float(sol_phase1.objective_value)
    print(f"Max biomass growth rate (Phase‑1) = {mu_star:.4f} 1/h")

    # 7) Trade-off curve: L-DOPA vs biomass
    trade = pareto_tradeoff(model, biomass_rxn, ex_ldopa_id="EX_ldopa_e")
    print("Biomass (% of max)\tL-DOPA flux (mmol/gDW/h)")
    for frac, flux in trade:
        print(f"{int(frac*100):>3d}%\t\t\t{flux:.4f}")
    _plot_tradeoff(trade)

    # 8) FVA at μ*: is L-DOPA optional at max growth?
    #    (pareto_tradeoff leaves the objective set to biomass.)
    fva = FVA(model, reaction_list=["EX_ldopa_e", "TPL"], fraction_of_optimum=1.0, processes=1)
    ld_min = float(fva.loc["EX_ldopa_e", "minimum"])
    ld_max = float(fva.loc["EX_ldopa_e", "maximum"])
    tpl_min = float(fva.loc["TPL", "minimum"])
    tpl_max = float(fva.loc["TPL", "maximum"])
    print(f"\nFVA at optimal biomass (growth = {mu_star:.3f} 1/h):")
    print(f"  L-DOPA exchange flux range: {ld_min:.4f} .. {ld_max:.4f} mmol/gDW/h")
    print(f"  TPL reaction flux range:    {tpl_min:.4f} .. {tpl_max:.4f} mmol/gDW/h")

    # 9) Catechol consumer knockouts (optional analysis, on a copy)
    _catechol_knockout_analysis(model)

    # 10) Phase-2 (MazF ON) with residual growth; compare to Phase-1
    _phase2_model, summary = run_phase2_with_residual_growth(
        model=model,
        biomass_rxn=biomass_rxn,
        mu_star=mu_star,
        translation_residual=0.10,   # 10% translation left; set 0.25 for 25% left
        atpm_scale=1.50,             # +50% ATPM vs Phase-1
        residual_mu_frac=0.05,       # Phase-2 μ = 5% of μ*
        fix_growth=True,             # fix to that μ (use False to cap instead)
        tpl_rxn_id="TPL",
        ldopa_ex_id="EX_ldopa_e"
    )

    print("\n=== Phase‑2 (MazF ON) summary ===")
    print("Params:", summary["params"])
    print(f"Phase‑1: μ={summary['phase1']['growth']:.3f}, "
          f"EX_ldopa={summary['phase1']['ldopa']:.3f}, TPL={summary['phase1']['tpl']:.3f}")
    print(f"Phase‑2: μ={summary['phase2']['growth']:.3f}, "
          f"EX_ldopa={summary['phase2']['ldopa']:.3f}, TPL={summary['phase2']['tpl']:.3f}")

    # 11) Bar plot Phase-1 vs Phase-2
    _plot_phase_comparison(summary)


def _plot_tradeoff(trade):
    """Plot (fraction, L-DOPA flux) pairs to pareto_ldopa.png; skip if matplotlib is missing."""
    try:
        import matplotlib.pyplot as plt
    except ModuleNotFoundError:
        print("matplotlib not installed; skipping Pareto plot.")
        return

    xs = [f * 100 for f, _ in trade]
    ys = [v for _, v in trade]
    fig = plt.figure()
    plt.plot(xs, ys, marker="o")
    plt.gca().invert_xaxis()
    plt.xlabel("Biomass (% of maximum)")
    plt.ylabel("L-DOPA production (mmol/gDW/h)")
    plt.title("Biomass vs L-DOPA production")
    plt.grid(True, alpha=0.3)
    plt.savefig("pareto_ldopa.png", dpi=200, bbox_inches="tight")
    print("Saved plot to pareto_ldopa.png")
    try:
        plt.show()
    except Exception:
        pass  # best effort: headless backends may not display
    plt.close(fig)  # release the figure on non-interactive backends


def _catechol_knockout_analysis(model: cobra.Model) -> None:
    """Knock out catechol consumers on a copy of `model` and report growth and FVA ranges."""
    k_model = model.copy()
    knocked = knockout_catechol_consumers(k_model, exclude_ids=("TPL", "EX_catechol_e"))
    print("Knocked out catechol-consuming reactions (excluding TPL):", knocked)

    # Re-compute μ* after knockouts (catechol uptake still at its medium bounds).
    # Previously computed but never reported.
    k_model.objective = find_biomass_rxn(k_model)
    mu_knock = float(k_model.optimize().objective_value)
    print(f"Growth after knockouts (medium catechol uptake): {mu_knock:.3f} 1/h")

    # Force a small catechol uptake to see if L-DOPA becomes essential
    ex_cat_k = k_model.reactions.get_by_id("EX_catechol_e")
    ex_cat_k.bounds = (-0.10, -0.10)  # force some catechol in
    k_model.objective = find_biomass_rxn(k_model)
    mu_forced = float(k_model.optimize().objective_value)

    fva_kn = FVA(k_model, reaction_list=["EX_ldopa_e", "TPL"], fraction_of_optimum=1.0, processes=1)
    print(f"\nKnockout + forced catechol uptake: growth = {mu_forced:.3f} 1/h")
    print("  L-DOPA range:", fva_kn.loc["EX_ldopa_e", ["minimum", "maximum"]].to_dict())
    print("  TPL range:   ", fva_kn.loc["TPL",        ["minimum", "maximum"]].to_dict())


def _plot_phase_comparison(summary):
    """Bar-plot Phase-1 vs Phase-2 metrics to phase1_vs_phase2_with_growth.png; skip if libs are missing."""
    try:
        import matplotlib.pyplot as plt
        import numpy as np
    except ModuleNotFoundError:
        print("matplotlib not installed; skipping Phase‑1 vs Phase‑2 plot.")
        return

    labels = ["Biomass (1/h)", "L‑DOPA (mmol/gDW/h)", "TPL (mmol/gDW/h)"]
    p1 = [summary["phase1"]["growth"], summary["phase1"]["ldopa"], summary["phase1"]["tpl"]]
    p2 = [summary["phase2"]["growth"], summary["phase2"]["ldopa"], summary["phase2"]["tpl"]]

    x = np.arange(len(labels))
    w = 0.38
    fig, ax = plt.subplots()
    ax.bar(x - w/2, p1, w, label="Phase‑1 (MazF OFF)")
    ax.bar(x + w/2, p2, w, label="Phase‑2 (MazF ON)")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=15, ha="right")
    ax.set_ylabel("Flux")
    ax.set_title("Phase‑1 vs Phase‑2 (MazF)")
    ax.legend()
    plt.tight_layout()
    plt.savefig("phase1_vs_phase2_with_growth.png", dpi=200)
    print("Saved plot to phase1_vs_phase2_with_growth.png")
    try:
        plt.show()
    except Exception:
        pass  # best effort: headless backends may not display
    plt.close(fig)  # release the figure on non-interactive backends


# Entry-point guard: run the pipeline only when this file is executed as a
# script, not when it is imported. This also prevents main() from re-running if
# the module is re-imported by spawned worker processes. Note that main() calls
# FVA with processes=1, so presumably no workers are started there.
if __name__ == "__main__":
    main()
