#!/usr/bin/env python3
"""POST a FHIR R4 Observation (vitals) to an Epic-like FHIR endpoint.

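The bearer token is ALWAYS obtained from token_service.get_or_create_epic_token().

Before posting, the script:
- Searches Encounter (GET <FHIR base>/Encounter?...) for the given patient and
  prints the resulting search Bundle; pick_encounter_id() selects an Encounter
  id from such a Bundle.
- Flattens the Observation payload and skips the POST unless it contains at
  least one critical vital sign (see CRITICAL_VITAL_LOINC).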
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from datetime import datetime
from typing import Any, Optional, Tuple
from uuid import uuid4

import requests

from token_service import get_or_create_epic_token


DEFAULT_PATIENT_ID = "e63wRTbPfr1p8UW81d8Seiw3"

# Mapping of critical LOINC codes to internal vital types.
# Only Observations/components carrying one of these codes are flattened.
CRITICAL_VITAL_LOINC = {
    "8867-4": "heart_rate",               # Heart rate
    "9279-1": "respiratory_rate",         # Respiratory rate
    "8310-5": "body_temperature",         # Body temperature
    "2708-6": "oxygen_saturation",        # SpO2
    "8302-2": "body_height",              # Height
    "29463-7": "body_weight",             # Weight
    "8478-0": "mean_arterial_pressure",   # MAP
    "8480-6": "systolic_blood_pressure",  # SBP
    "8462-4": "diastolic_blood_pressure", # DBP
}


def _find_critical_coding(concept: Any) -> Optional[dict[str, Any]]:
    """Return the first coding of a CodeableConcept whose code is in CRITICAL_VITAL_LOINC.

    A CodeableConcept may carry several codings (for example a vendor code next
    to the LOINC code), so every entry is checked, not just the first one.
    Returns None when *concept* is missing/malformed or no coding matches.
    """
    if not isinstance(concept, dict):
        return None
    for coding in concept.get("coding") or []:
        if not isinstance(coding, dict):
            continue
        code = coding.get("code")
        if isinstance(code, str) and code in CRITICAL_VITAL_LOINC:
            return coding
    return None


def _flatten_vital_observation(obs: dict[str, Any]) -> list[dict[str, Any]]:
    """Convert a FHIR Observation (single or panel) into discrete rows (client-side view).

    Only vitals whose code appears in CRITICAL_VITAL_LOINC and whose
    valueQuantity has a non-None ``value`` produce a row; everything else is ignored.

    Args:
        obs: Observation resource as a parsed JSON dict.

    Returns:
        One dict per critical vital with keys observation_id, patient_ref,
        encounter_ref, effective_dt, status, loinc_code, vital_type, value and
        unit. Empty when the Observation has no subject reference or no
        critical vitals.
    """
    subject = (obs.get("subject") or {}).get("reference")
    if not subject:
        return []

    base = {
        # A random id is generated when the payload has none (e.g. before POST).
        "observation_id": obs.get("id") or uuid4().hex,
        "patient_ref": subject,
        "encounter_ref": (obs.get("encounter") or {}).get("reference"),
        "effective_dt": obs.get("effectiveDateTime"),
        "status": obs.get("status", "unknown"),
    }

    def _row(concept: Any, value_q: dict[str, Any]) -> Optional[dict[str, Any]]:
        coding = _find_critical_coding(concept)
        if coding is None:
            return None  # Ignore non-critical vitals for now

        val = value_q.get("value")
        if val is None:
            return None

        code = coding["code"]
        return {
            **base,
            "loinc_code": code,
            "vital_type": CRITICAL_VITAL_LOINC[code],
            "value": val,
            "unit": value_q.get("unit"),
        }

    # A panel (e.g. blood pressure) reports its vitals in `component`; otherwise
    # the Observation itself carries code + valueQuantity. An empty (or
    # malformed) component list falls back to the top-level value.
    components = [c for c in (obs.get("component") or []) if isinstance(c, dict)]
    sources = components or [obs]

    rows: list[dict[str, Any]] = []
    for src in sources:
        row = _row(src.get("code"), src.get("valueQuantity") or {})
        if row is not None:
            rows.append(row)
    return rows


def normalize_resource_url(fhir_base: str, resource: str) -> str:
    """Build the endpoint URL for *resource* from a FHIR base URL.

    Surrounding whitespace and trailing slashes are removed from *fhir_base*.
    If the base already ends with ``/<resource>`` (case-insensitive), it is
    treated as the resource endpoint and returned unchanged. Otherwise
    ``/<resource>`` is appended, e.g. ``.../api/FHIR/R4`` becomes
    ``.../api/FHIR/R4/Encounter``.
    """
    base = fhir_base.strip().rstrip("/")

    # Caller may have passed a direct /Resource endpoint instead of the base.
    if base.lower().endswith(f"/{resource.lower()}"):
        return base

    # Whether the base looks like .../FHIR/R4 or not, the resource is appended.
    return f"{base}/{resource}"


def to_patient_reference(patient: str) -> str:
    """Return *patient* as a FHIR Reference string.

    Surrounding whitespace is stripped first. A bare id becomes
    ``Patient/<id>``. A value that already contains a slash (a relative or
    absolute reference) is returned as-is. An empty or blank input yields ``""``.
    """
    cleaned = patient.strip()
    is_bare_id = bool(cleaned) and "/" not in cleaned
    return f"Patient/{cleaned}" if is_bare_id else cleaned


def _parse_fhir_datetime(dt_str: str) -> Optional[datetime]:
    """Parse a FHIR date/dateTime string (best effort).

    Accepts the reduced-precision forms FHIR allows (``YYYY``, ``YYYY-MM``),
    which are anchored to the first day of the year/month. Also accepts full
    dates and date-times with a ``Z`` suffix or a numeric UTC offset. The result
    is timezone-aware only when the input carries a zone; date-only input yields
    a naive datetime.

    Returns None for empty, non-string, or unparseable input.
    """
    if not isinstance(dt_str, str):
        return None
    s = dt_str.strip()
    if not s:
        return None

    # datetime.fromisoformat rejects FHIR's partial dates, so pad them.
    if len(s) == 4 and s.isdigit():
        s += "-01-01"
    elif len(s) == 7 and s[:4].isdigit() and s[4] == "-" and s[5:].isdigit():
        s += "-01"

    # fromisoformat only understands the "Z" suffix from Python 3.11 onward.
    if s.endswith("Z"):
        s = s[:-1] + "+00:00"
    try:
        return datetime.fromisoformat(s)
    except ValueError:
        return None


def pick_encounter_id(bundle: dict[str, Any]) -> Optional[str]:
    """
    Pick an Encounter id from a FHIR search Bundle.

    The Encounter with the latest ``period.start`` wins. Encounters without a
    parseable start are chosen only when no Encounter has one, and ties keep
    Bundle order. Starts without a zone (e.g. date-only) are treated as UTC so
    they can be compared with offset-bearing starts.

    Returns None when the Bundle holds no Encounter with an id. For example, a
    Bundle containing only an OperationOutcome yields None instead of that
    resource's id.
    """
    from datetime import timezone

    # Sort key for undated Encounters; any real start compares greater.
    floor = datetime.min.replace(tzinfo=timezone.utc)

    candidates: list[Tuple[Tuple[bool, datetime], str]] = []
    for e in bundle.get("entry") or []:
        res = (e or {}).get("resource") or {}
        enc_id = res.get("id")
        if res.get("resourceType") != "Encounter" or not enc_id:
            continue
        dt = _parse_fhir_datetime((res.get("period") or {}).get("start") or "")
        if dt is not None and dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        # (has_start, start): dated Encounters always outrank undated ones.
        key = (dt is not None, dt if dt is not None else floor)
        candidates.append((key, enc_id))

    if not candidates:
        return None
    # max() returns the first maximal element, preserving Bundle order on ties.
    return max(candidates, key=lambda c: c[0])[1]


# Used when neither --fhir-base nor the FHIR_BASE_URL env var is set; this is
# the endpoint the script previously forced unconditionally.
_DEFAULT_FHIR_BASE = "https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4"


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by main()."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--fhir-base",
        default=os.getenv("FHIR_BASE_URL") or _DEFAULT_FHIR_BASE,
        help=(
            "FHIR base URL (e.g., https://hostname/.../api/FHIR/R4). "
            f"Defaults to env FHIR_BASE_URL, else {_DEFAULT_FHIR_BASE}."
        ),
    )
    ap.add_argument(
        "--app-config-id",
        type=int,
        # argparse runs string defaults through `type`, so a malformed
        # EPIC_APP_CONFIG_ID is reported as a usage error, not a traceback.
        default=os.getenv("EPIC_APP_CONFIG_ID", "3"),
        help="EpicAppConfig id for get_or_create_epic_token(). Defaults to env EPIC_APP_CONFIG_ID or 3.",
    )
    ap.add_argument("--json", default="observation_vitals.json", help="Path to Observation JSON payload.")
    ap.add_argument("--timeout", type=int, default=60, help="HTTP timeout seconds.")

    # Encounter search parameters; each is sent only when a value is given.
    ap.add_argument("--enc-class", dest="enc_class", help="Encounter search param: class")
    ap.add_argument("--enc-date", dest="enc_date", help="Encounter search param: date")
    ap.add_argument("--enc-identifier", dest="enc_identifier", help="Encounter search param: identifier")
    ap.add_argument(
        "--enc-onlyscannable",
        dest="enc_onlyscannable",
        help="Encounter search param: onlyscannable (server-specific, if supported)",
    )
    ap.add_argument(
        "--enc-status",
        dest="enc_status",
        default="in-progress",
        help="Encounter search param: status. Default is in-progress; pass '' to omit.",
    )
    ap.add_argument(
        "--patient",
        default=DEFAULT_PATIENT_ID,
        help=f"Patient id or reference. Default is {DEFAULT_PATIENT_ID}",
    )
    ap.add_argument("--enc-subject", dest="enc_subject", help="Encounter search param: subject (reference or id)")
    return ap


def _encounter_search_params(args: argparse.Namespace) -> dict[str, str]:
    """Translate parsed CLI options into Encounter search query parameters."""
    # patient=Patient/{id} (reference form) for maximum compatibility.
    params: dict[str, str] = {"patient": to_patient_reference(args.patient)}
    if args.enc_subject:
        params["subject"] = to_patient_reference(args.enc_subject)

    optional = {
        "class": args.enc_class,
        "date": args.enc_date,
        "identifier": args.enc_identifier,
        "onlyscannable": args.enc_onlyscannable,
        "status": args.enc_status,
    }
    params.update({name: value for name, value in optional.items() if value})
    return params


def _print_response_body(resp: "requests.Response") -> None:
    """Print a response body as indented JSON, or as raw text if it is not JSON."""
    try:
        print(json.dumps(resp.json(), indent=2))
    except ValueError:
        print(resp.text)


def main() -> int:
    """Search the patient's Encounters, then POST the Observation payload.

    Returns 0 when the POST answers 200/201, or when the payload has no
    critical vitals and the POST is skipped. Returns 1 on request failures,
    non-success HTTP statuses, or an unreadable payload file.
    """
    args = _build_arg_parser().parse_args()

    # Always fetch token from token_service
    token = get_or_create_epic_token(args.app_config_id)

    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/fhir+json",
        "Content-Type": "application/fhir+json",
    }

    # ---- 1) Encounter search ----
    encounter_url = normalize_resource_url(args.fhir_base, "Encounter")
    params = _encounter_search_params(args)
    try:
        enc_resp = requests.get(encounter_url, headers=headers, params=params, timeout=args.timeout)
    except requests.RequestException as e:
        print(f"ERROR: Encounter search request failed: {e}", file=sys.stderr)
        return 1

    if enc_resp.status_code != 200:
        print(f"ERROR: Encounter search HTTP {enc_resp.status_code}", file=sys.stderr)
        _print_response_body(enc_resp)
        return 1

    try:
        bundle = enc_resp.json()
    except ValueError:
        print("ERROR: Encounter search did not return JSON", file=sys.stderr)
        print(enc_resp.text)
        return 1

    print(json.dumps(bundle, indent=2))

    encounter_id = pick_encounter_id(bundle) if isinstance(bundle, dict) else None
    if encounter_id:
        print(f"Selected Encounter id: {encounter_id}")
    else:
        print("No Encounter found in search results.")

    # ---- 2) Post Observation ----
    observation_url = normalize_resource_url(args.fhir_base, "Observation")

    try:
        with open(args.json, "r", encoding="utf-8") as f:
            payload = json.load(f)
    except (OSError, ValueError) as e:
        print(f"ERROR: could not load Observation payload {args.json!r}: {e}", file=sys.stderr)
        return 1
    if not isinstance(payload, dict):
        print(f"ERROR: Observation payload {args.json!r} is not a JSON object", file=sys.stderr)
        return 1

    # Gate the POST on the payload containing at least one critical vital.
    critical_rows = _flatten_vital_observation(payload)
    print("Flattened critical vitals rows:")
    print(json.dumps(critical_rows, indent=2))

    if not critical_rows:
        print("No critical vitals found in payload; skipping Observation POST.")
        return 0

    try:
        resp = requests.post(observation_url, headers=headers, json=payload, timeout=args.timeout)
    except requests.RequestException as e:
        print(f"ERROR: Observation POST failed: {e}", file=sys.stderr)
        return 1

    print(f"Observation POST URL: {observation_url}")
    print("Observation POST HTTP:", resp.status_code)
    _print_response_body(resp)

    return 0 if resp.status_code in (200, 201) else 1


if __name__ == "__main__":
    # main() returns the process exit code; sys.exit raises SystemExit with it.
    sys.exit(main())
