from __future__ import annotations

from datetime import datetime, timezone
from sqlalchemy.orm import Session
from epic_token import EpicObservation
import json

# Default destination for debug records written by _store_obs_debug.
# Despite the ".json" suffix the file is JSON Lines (one object appended per
# line), and the relative path resolves against the process working directory.
DEBUG_LOG_PATH = "observation_debug.json"


def parse_fhir_dt(s: str):
    """
    Parse a FHIR dateTime/instant string into a ``datetime``.

    Inputs that carry a timezone (a trailing ``Z`` or an explicit offset such
    as ``-05:00``) are normalized to UTC, so every aware result shares the same
    wall-clock basis. Inputs without an offset (including date-only values)
    are returned naive, exactly as parsed.

    Returns None for empty, non-string, or unparseable input.
    """
    if not s or not isinstance(s, str):
        return None
    # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
    text = s[:-1] + "+00:00" if s.endswith("Z") else s
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        return None
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc)
    return parsed


def ref_id(ref: str):
    """
    Return the logical id from a FHIR reference string.

    Handles relative ("Patient/abc"), absolute
    ("https://host/fhir/Patient/abc") and versioned
    ("Patient/abc/_history/2") references; all yield "abc". A value with no
    "/" (e.g. "abc", "urn:uuid:...", "#contained") is returned unchanged.
    Returns None for empty/None input.
    """
    if not ref:
        return None
    # Drop a version suffix first so the id, not the version, is the last segment.
    base = ref.split("/_history/", 1)[0]
    return base.rsplit("/", 1)[-1]


def pick_primary_coding(codeable: dict):
    """
    Return ``(system, code, display)`` from the first coding of a CodeableConcept.

    A missing/empty concept, a missing/empty ``coding`` list, or a null first
    entry all produce ``(None, None, None)``. Keys absent from the first coding
    come back as None individually.
    """
    codings = (codeable.get("coding") if codeable else None) or []
    first = (codings[0] if codings else None) or {}
    return (first.get("system"), first.get("code"), first.get("display"))


def extract_value_fields(obs: dict):
    """
    Normalize an Observation's value[x] into a flat dict of column values.

    Recognized choices, checked in this order (first match wins):
    valueString, valueCodeableConcept, valueQuantity, valueBoolean,
    valueInteger, valueDateTime, valueTime. Any other value[x] (for example
    valueRange, valueRatio, valuePeriod, valueSampledData) is not normalized
    and leaves every field None.

    Numeric and boolean values are stored as strings; a JSON null for any of
    them stays None instead of becoming the text "None".

    Returns a dict with keys: value_type, value_string, value_code,
    value_system, value_display, value_unit, value_number.
    """
    out = {
        "value_type": None,
        "value_string": None,
        "value_code": None,
        "value_system": None,
        "value_display": None,
        "value_unit": None,
        "value_number": None,
    }

    if "valueString" in obs:
        out["value_type"] = "valueString"
        out["value_string"] = obs.get("valueString")

    elif "valueCodeableConcept" in obs:
        out["value_type"] = "valueCodeableConcept"
        v = obs.get("valueCodeableConcept") or {}
        sys, code, disp = pick_primary_coding(v)
        out["value_system"] = sys
        out["value_code"] = code
        # Prefer the concept's human text over the first coding's display.
        out["value_display"] = v.get("text") or disp

    elif "valueQuantity" in obs:
        out["value_type"] = "valueQuantity"
        q = obs.get("valueQuantity") or {}
        val = q.get("value")
        out["value_number"] = None if val is None else str(val)
        out["value_unit"] = q.get("unit")
        out["value_system"] = q.get("system")
        out["value_code"] = q.get("code")

    elif "valueBoolean" in obs:
        out["value_type"] = "valueBoolean"
        flag = obs.get("valueBoolean")
        out["value_string"] = None if flag is None else str(flag)

    elif "valueInteger" in obs:
        out["value_type"] = "valueInteger"
        num = obs.get("valueInteger")
        out["value_number"] = None if num is None else str(num)

    elif "valueDateTime" in obs:
        out["value_type"] = "valueDateTime"
        out["value_string"] = obs.get("valueDateTime")

    elif "valueTime" in obs:
        # FHIR time is a plain "hh:mm:ss" string, stored like valueDateTime.
        out["value_type"] = "valueTime"
        out["value_string"] = obs.get("valueTime")

    return out


def _to_json_text(value):
    """
    Serialize *value* to JSON text for storage, or return None for None.

    Non-ASCII characters are kept as-is. Objects the json module cannot
    encode natively are written as their ``str()`` form instead of raising.
    """
    if value is not None:
        # default=str only fires for otherwise-unserializable objects, so the
        # output for plain JSON data is identical to a strict dump.
        return json.dumps(value, ensure_ascii=False, default=str)
    return None


def upsert_observation(session: Session, org_id: str, obs: dict, debug: bool = False):
    """
    Insert or update a single FHIR Observation as an ``EpicObservation`` row.

    The row is matched on ``observation_id`` (the resource ``id``). A new row is
    created when none exists. All mapped columns are then overwritten from
    *obs*, and the session is committed.

    Args:
        session: SQLAlchemy session used for the lookup and the commit.
        org_id: Organization identifier stored on the row.
        obs: Observation resource as a parsed-JSON dict.
        debug: When True, append a diagnostic record for this Observation to
            the debug log before committing.

    Returns:
        The refreshed ``EpicObservation`` row, or None if *obs* has no ``id``.

    Raises:
        Any exception from ``session.commit()``. The session is rolled back
        first so the caller can keep using it.
    """
    obs_id = obs.get("id")
    if not obs_id:
        return None

    # NOTE(review): the lookup ignores org_id. If the same Observation id can
    # occur under different orgs, the last writer wins. Confirm the table's key.
    row = session.query(EpicObservation).filter_by(observation_id=obs_id).one_or_none()
    if row is None:
        row = EpicObservation(observation_id=obs_id)

    row.org_id = org_id
    row.fhir_id = obs_id

    subject_ref = (obs.get("subject") or {}).get("reference")
    row.patient_reference = subject_ref
    row.patient_id = ref_id(subject_ref) or ""  # bare id: "Patient/abc" -> "abc"

    row.encounter_reference = (obs.get("encounter") or {}).get("reference")
    row.specimen_reference = (obs.get("specimen") or {}).get("reference")
    row.status = obs.get("status")

    code_obj = obs.get("code") or {}
    row.code_text = code_obj.get("text")
    # Same first-coding rule that extract_value_fields applies to CodeableConcept values.
    (
        row.code_primary_system,
        row.code_primary_code,
        row.code_primary_display,
    ) = pick_primary_coding(code_obj)

    # effective[x] may arrive as dateTime or instant; both are single points in time.
    row.effective_datetime = parse_fhir_dt(
        obs.get("effectiveDateTime") or obs.get("effectiveInstant")
    )
    row.issued_datetime = parse_fhir_dt(obs.get("issued"))

    values = extract_value_fields(obs)
    for column in (
        "value_type",
        "value_string",
        "value_code",
        "value_system",
        "value_display",
        "value_unit",
        "value_number",
    ):
        setattr(row, column, values.get(column))

    # JSON fields
    row.category_json = _to_json_text(obs.get("category"))
    row.interpretation_json = _to_json_text(obs.get("interpretation"))
    row.note_json = _to_json_text(obs.get("note"))
    row.has_member_json = _to_json_text(obs.get("hasMember"))
    row.based_on_json = _to_json_text(obs.get("basedOn"))
    row.reference_range_json = _to_json_text(obs.get("referenceRange"))
    row.raw_json = _to_json_text(obs)

    if debug:
        payload = _build_obs_debug_payload(obs, row, values)
        _store_obs_debug(payload)

    session.add(row)
    try:
        session.commit()
    except Exception:
        # After a failed flush/commit SQLAlchemy refuses further work until the
        # transaction is rolled back. Roll back so later calls can proceed.
        session.rollback()
        raise
    session.refresh(row)
    return row


def _fmt_value(v: dict) -> str:
    """
    Render a normalized value dict as a one-line, human-readable summary.

    Used when building the debug payload. Missing quantity parts are omitted
    rather than printed as the literal text "None".
    """
    t = v.get("value_type")
    if t == "valueQuantity":
        parts = [str(p) for p in (v.get("value_number"), v.get("value_unit")) if p is not None]
        return f"valueQuantity: {' '.join(parts)}".rstrip()
    if t == "valueCodeableConcept":
        return f"valueCodeableConcept: {v.get('value_display')} ({v.get('value_system')}|{v.get('value_code')})"
    if t in ("valueString", "valueBoolean", "valueDateTime", "valueTime"):
        return f"{t}: {v.get('value_string')}"
    if t == "valueInteger":
        return f"valueInteger: {v.get('value_number')}"
    return "value: <none>"


def _build_obs_debug_payload(obs: dict, row, v: dict) -> dict:
    """
    Assemble a flat diagnostic record for one Observation.

    Combines selected raw fields from *obs*, the code columns already set on
    *row*, and the normalized value dict *v*. It also lists which ``value*``
    keys the raw resource carried, so unmapped value[x] types are easy to spot.
    """
    subject = obs.get("subject") or {}
    payload = {
        "observation_id": obs.get("id"),
        "status": obs.get("status"),
        "patient_reference": subject.get("reference"),
        "effectiveDateTime": obs.get("effectiveDateTime"),
        "issued": obs.get("issued"),
    }

    for attr in ("code_text", "code_primary_display", "code_primary_system", "code_primary_code"):
        payload[attr] = getattr(row, attr)

    payload["value"] = _fmt_value(v)
    for key in (
        "value_type",
        "value_number",
        "value_unit",
        "value_string",
        "value_system",
        "value_code",
        "value_display",
    ):
        payload[key] = v.get(key)

    payload["raw_value_keys"] = [key for key in obs if key.startswith("value")]
    return payload


def _store_obs_debug(payload: dict, path: str = DEBUG_LOG_PATH) -> None:
    """
    Append *payload* to *path* as one JSON Lines record (best effort).

    Any failure, whether opening the file or serializing the payload, is
    reported as a printed warning and never propagated to the caller.
    """
    # NOTE(review): payloads include patient references. Confirm that writing
    # them to a local file is acceptable for this deployment.
    try:
        with open(path, mode="a", encoding="utf-8") as sink:
            record = json.dumps(payload, ensure_ascii=False, default=str)
            sink.write(f"{record}\n")
    except Exception as exc:
        print(f"[WARN] failed to write debug log: {exc}")


def store_observation_bundle(session: Session, org_id: str, bundle: dict, debug: bool = False) -> int:
    """
    Upsert every Observation in a FHIR Bundle and return how many were stored.

    Some entries are skipped and not counted:
    - entries whose resource declares a ``resourceType`` other than
      "Observation", such as the OperationOutcome that FHIR search Bundles may
      include;
    - entries with no resource ``id``.
    A resource with no ``resourceType`` at all is treated as an Observation.
    Each stored Observation is committed individually by upsert_observation.

    Args:
        session: SQLAlchemy session passed through to upsert_observation.
        org_id: Organization identifier stored on each row.
        bundle: Bundle resource as a parsed-JSON dict (None is treated as empty).
        debug: Forwarded to upsert_observation for each stored entry.

    Returns:
        Number of Observations actually upserted.
    """
    stored = 0
    for entry in (bundle or {}).get("entry") or []:
        resource = (entry or {}).get("resource") or {}
        if resource.get("resourceType", "Observation") != "Observation":
            continue
        if not resource.get("id"):
            continue
        if upsert_observation(session, org_id, resource, debug=debug) is not None:
            stored += 1
    return stored
