# therapy_engine.py
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


# A trigger rule is (lo, hi, therapy): a reading strictly below `lo` or
# strictly above `hi` proposes `therapy`; a None bound is never checked.
TriggerRule = Tuple[Optional[float], Optional[float], str]


class ConflictFinder:
    r"""
    Placeholder for your App\Services\ConflictFinder.

    Replace this with your real implementation (or inject it).
    """
    # NOTE: the docstring is raw (r"""...""") so the backslashes in the PHP
    # namespace path are not parsed as (invalid) string escape sequences.

    def detect_realtime(self, reading: Dict[str, Any], therapy_payload: Dict[str, Any]) -> None:
        """Inspect one reading / computed-therapy pair for conflicts.

        This placeholder performs no checks and always returns None.

        Args:
            reading: ``{"name": ..., "value": ..., "id": ...}`` describing
                the reading that triggered the therapy (``id`` may be None).
            therapy_payload: the computed therapy entry; as built by
                ``TherapyEngine.compute_doses`` in this file it carries
                ``"therapy"``, ``"rate"``, ``"reason"``, ``"param"``,
                ``"value"`` and ``"delta"``.
        """
        # Implement your conflict logic here.
        return


@dataclass
class TherapyEngine:
    """Rule-based engine that turns lab/vital readings into therapy rates.

    Each trigger rule is ``(lo, hi, therapy)``: a reading strictly below
    ``lo`` or strictly above ``hi`` proposes ``therapy``; a ``None`` bound is
    never checked. The proposed rate is ``therapy_base[therapy]`` plus a
    per-therapy gain times the distance outside the breached bound,
    rounded to 3 decimals and floored at 0.

    NOTE(review): the breach distance is an unsigned magnitude, so for
    two-sided rules a low *and* a high breach both increase the rate
    (e.g. glucose below 80 raises Insulin); only the ``reason`` field
    records the direction. Rates are also never clamped to
    ``therapy_max``. Confirm both are handled downstream.
    """

    # ------------------------------- STATIC CONFIG ---------------------- #

    # Starting rate per therapy, before the breach-based adjustment made in
    # `_dose_adjust` is added.
    therapy_base: Dict[str, float] = field(default_factory=lambda: {
        "Insulin": 1,
        "Dextrose": 0,
        "KCl": 0,
        "Heparin": 800,
        "Protamine": 1,
        "CRRT": 150,
        "Tocilizumab": 0,
        "ABX": 0.5,
        "UF_Rate": 150,
        "VasoDose": 0.001,
        "Norepi": 0.05,
        "HTS": 0,
    })

    # Per-therapy reference used by `normalise_rate` to express a rate as a
    # percentage. Not applied as a cap by `compute_doses`.
    therapy_max: Dict[str, float] = field(default_factory=lambda: {
        "Insulin": 40,
        "Dextrose": 200,
        "KCl": 40,
        "Heparin": 2000,
        "Protamine": 50,
        "CRRT": 2000,
        "VasoDose": 1.0,
        "UF_Rate": 2000,
        "Tocilizumab": 800,
        "ABX": 4,
        "Norepi": 1.0,
        "HTS": 500,
    })

    # Per-sensor {"lo", "hi"} ranges, merged into `get_display_ranges`;
    # not consulted when computing doses.
    sensor_hi_lo: Dict[str, Dict[str, float]] = field(default_factory=lambda: {
        "K": {"lo": 3.0, "hi": 7.5},
        "Na": {"lo": 115, "hi": 155},
        "IL6": {"lo": 0, "hi": 100},
        "Pct": {"lo": 0, "hi": 5},
        "Glucose": {"lo": 70, "hi": 600},
        "pH": {"lo": 7.0, "hi": 7.55},
        "antiXa": {"lo": 0, "hi": 1.5},
        "SerumLactate": {"lo": 0, "hi": 5},
        "Creatinine": {"lo": 0.4, "hi": 4.0},
        "Platelets": {"lo": 20, "hi": 600},
        "INR": {"lo": 0.8, "hi": 4.0},
        "PaO2_FiO2": {"lo": 80, "hi": 400},
        "Temp": {"lo": 34.0, "hi": 40.0},
        "MAP": {"lo": 50, "hi": 110},

        "HR": {"lo": 40, "hi": 180},
        "CVP": {"lo": 0, "hi": 20},
        "CO": {"lo": 2.0, "hi": 8.0},

        "Bicarbonate": {"lo": 10, "hi": 40},
        "Mg2": {"lo": 1.0, "hi": 4.0},

        "BUN": {"lo": 5, "hi": 80},
        "UrineOutput": {"lo": 0, "hi": 300},

        "CRP": {"lo": 0, "hi": 300},
        "Ferritin": {"lo": 10, "hi": 2000},
        "D_Dimer": {"lo": 0.0, "hi": 5.0},
        "SpO2": {"lo": 70, "hi": 100},

        "Hb": {"lo": 5.0, "hi": 18.0},
    })

    # Candidate ordering: lower numbers are processed first; therapies not
    # listed here sort after every listed one.
    hierarchy: Dict[str, int] = field(default_factory=lambda: {
        "VasoDose": 0,
        "Norepi": 0,
        "KCl": 1,
        "HTS": 1,
        "Heparin": 2,
        "Protamine": 2,
        "Insulin": 3,
        "CRRT": 3,
        "Tocilizumab": 4,
        "ABX": 4,
    })

    # Full rule set: reading name -> list of (lo, hi, therapy) rules.
    triggers: Dict[str, List[TriggerRule]] = field(default_factory=lambda: {
        "MAP": [
            (55, None, "VasoDose"),
            (65, 90, "Norepi"),
        ],
        "CVP": [
            (None, 12, "UF_Rate"),
        ],
        "PaO2_FiO2": [
            (200, None, "UF_Rate"),
        ],
        "K": [
            (3.5, 5.0, "KCl"),
        ],
        "Na": [
            (130, 150, "HTS"),
        ],
        "Glucose": [
            (80, 180, "Insulin"),
        ],
        "pH": [
            (7.20, 7.45, "CRRT"),
        ],
        "antiXa": [
            (0.2, 0.8, "Heparin"),
            (None, 0.7, "Protamine"),
        ],
        "IL6": [
            (None, 50, "Tocilizumab"),
        ],
        "Pct": [
            (None, 2, "ABX"),
        ],
        "Creatinine": [
            (None, 2.0, "CRRT"),
        ],
    })

    # Reduced rule set used when `compute_doses(..., lite=True)`.
    triggers_lite: Dict[str, List[TriggerRule]] = field(default_factory=lambda: {
        "MAP": [
            (65, 90, "Norepi"),
        ],
        "K": [
            (3.5, 4.5, "KCl"),
        ],
        "Glucose": [
            (80, 180, "Insulin"),
            # Dextrose must fire when glucose falls *below* 70; a
            # (None, 70) rule would fire for every value above 70 instead.
            (70, None, "Dextrose"),
        ],
    })

    # Optional hook notified of each computed therapy (see `compute_doses`).
    conflict_finder: Optional[ConflictFinder] = None

    # Rate change per unit of breach distance; therapies not listed use 1.
    # Plain class attribute (no annotation), so it is not a dataclass field.
    _DOSE_GAIN = {
        "VasoDose": 0.01,
        "Norepi": 0.03,
        "KCl": 5,
        "HTS": 1.5,
        "Heparin": 200,
        "Protamine": 10,
        "Insulin": 0.5,
        "CRRT": 20,
    }

    # ----------------------------- PUBLIC API ---------------------------- #

    def compute_doses(
        self,
        labs: Dict[str, Dict[str, Any]],
        vitals: Optional[Dict[str, Dict[str, Any]]] = None,
        lite: bool = False,
    ) -> Dict[str, Dict[str, Any]]:
        """Compute at most one dose entry per therapy from current readings.

        Args:
            labs: Reading name -> reading dict. A reading matched by a
                trigger must carry a float-convertible ``"value"``; an
                optional ``"id"`` is forwarded to the conflict finder.
            vitals: Extra readings merged over ``labs`` (same-name entries
                from ``vitals`` win).
            lite: Use ``triggers_lite`` instead of ``triggers``.

        Returns:
            Therapy name -> ``{"therapy", "rate", "reason", "param",
            "value", "delta"}``, inserted in priority order. ``reason`` is
            ``"low"`` or ``"high"``; ``delta`` is the positive distance
            outside the breached bound.

        Raises:
            KeyError: a reading used by a trigger has no ``"value"`` key.
            TypeError, ValueError: that value cannot be converted by ``float``.
        """
        readings: Dict[str, Dict[str, Any]] = dict(labs)
        if vitals:
            readings.update(vitals)

        rule_set = self.triggers_lite if lite else self.triggers
        candidates = self._collect_candidates(readings, rule_set)

        # Lowest hierarchy value first; within a tier, the largest breach
        # first (the sort is stable, so remaining ties keep rule order).
        candidates.sort(
            key=lambda c: (self.hierarchy.get(c["therapy"], 10**9), -c["sev"])
        )

        out: Dict[str, Dict[str, Any]] = {}
        therapy_by_param: Dict[str, List[Dict[str, Any]]] = {}

        for c in candidates:
            t = c["therapy"]
            if t in out:
                # Keep only the first (highest-priority) candidate per therapy.
                continue

            base = float(self.therapy_base.get(t, 0))
            computed = {
                "therapy": t,
                "rate": self._dose_adjust(t, c["delta"], base),
                "reason": c["direction"],
                "param": c["param"],
                "value": c["value"],
                "delta": c["delta"],
            }
            out[t] = computed
            therapy_by_param.setdefault(c["param"], []).append(computed)

        if self.conflict_finder is not None:
            self._report_conflicts(self.conflict_finder, readings, therapy_by_param)

        return out

    def compute_doses_lite_from_labs(self, labs: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Shortcut for ``compute_doses(labs, vitals={}, lite=True)``."""
        return self.compute_doses(labs=labs, vitals={}, lite=True)

    def normalise_rate(self, therapy: str, rate: float) -> float:
        """Return ``rate`` as a percentage of ``therapy_max[therapy]``.

        Unlisted therapies use a maximum of 1.0. The result is rounded to
        one decimal and is not capped at 100.

        Raises:
            ZeroDivisionError: if the configured maximum is 0.
        """
        m = float(self.therapy_max.get(therapy, 1.0))
        return round((rate / m) * 100.0, 1)

    def get_display_ranges(self) -> Dict[str, Dict[str, Optional[float]]]:
        """Return ``{name: {"lo", "hi"}}`` for display.

        Keys are the therapies named in ``triggers`` (bounds may be None;
        a therapy with several rules keeps the last one), followed by a
        fresh copy of every ``sensor_hi_lo`` entry, so mutating the result
        never alters the engine's configuration.
        """
        out: Dict[str, Dict[str, Optional[float]]] = {}

        for rules in self.triggers.values():
            for lo, hi, therapy in rules:
                out[therapy] = {"lo": lo, "hi": hi}

        # Copy each inner dict; returning the originals would let callers
        # mutate `sensor_hi_lo` through the result.
        for name, bounds in self.sensor_hi_lo.items():
            out[name] = dict(bounds)
        return out

    def get_therapy_max(self) -> Dict[str, float]:
        """Return a shallow copy of ``therapy_max``."""
        return dict(self.therapy_max)

    def get_therapy_base(self) -> Dict[str, float]:
        """Return a shallow copy of ``therapy_base``."""
        return dict(self.therapy_base)

    # -------------------------- INTERNAL -------------------------------- #

    def _collect_candidates(
        self,
        readings: Dict[str, Dict[str, Any]],
        rule_set: Dict[str, List[TriggerRule]],
    ) -> List[Dict[str, Any]]:
        """Return one candidate dict per rule whose bound the reading breaches."""
        candidates: List[Dict[str, Any]] = []
        for param, rules in rule_set.items():
            if param not in readings:
                continue

            v = float(readings[param]["value"])

            for lo, hi, therapy in rules:
                # The low bound is checked first, matching a misconfigured
                # rule with lo > hi to "low".
                if lo is not None and v < lo:
                    delta, direction = lo - v, "low"
                elif hi is not None and v > hi:
                    delta, direction = v - hi, "high"
                else:
                    continue

                delta = float(delta)
                candidates.append({
                    "therapy": therapy,
                    "delta": delta,
                    "sev": abs(delta),
                    "value": v,
                    "param": param,
                    "direction": direction,
                })
        return candidates

    @staticmethod
    def _report_conflicts(
        finder: ConflictFinder,
        readings: Dict[str, Dict[str, Any]],
        therapy_by_param: Dict[str, List[Dict[str, Any]]],
    ) -> None:
        """Pass each computed therapy, with its source reading, to ``finder``."""
        # Walk readings in their merged order so call order is deterministic.
        for name, reading in readings.items():
            for tp in therapy_by_param.get(name, []):
                finder.detect_realtime(
                    {"name": name, "value": reading["value"], "id": reading.get("id")},
                    tp,
                )

    def _dose_adjust(self, therapy: str, delta: float, base: float) -> float:
        """Return ``base + gain * delta`` rounded to 3 decimals, floored at 0."""
        gain = self._DOSE_GAIN.get(therapy, 1)
        return max(0.0, round(base + gain * float(delta), 3))