#!/usr/bin/env python3
from __future__ import annotations

import os
import random
import time
from datetime import datetime, timezone
from typing import Optional, Tuple
from urllib.parse import urljoin
from uuid import uuid4

import requests

from token_service import get_or_create_epic_token
from db import SessionLocal
from observation_store import upsert_observation

FHIR_BASE = os.getenv(
    "FHIR_BASE_URL",
    "https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4",
)
# ID handed to get_or_create_epic_token() when building request headers.
EPIC_APP_CONFIG_ID = int(os.getenv("EPIC_APP_CONFIG_ID", "3"))

# References written into Observation.subject / Observation.encounter.
PATIENT_REF = os.getenv("PATIENT_REF", "Patient/e63wRTbPfr1p8UW81d8Seiw3")
ENCOUNTER_REF = os.getenv("ENCOUNTER_REF", "Encounter/ewYPWmestytGzRSBFlutk7w3")

# Pause between post cycles, and the per-request HTTP timeout (both seconds).
INTERVAL_SECONDS = float(os.getenv("INTERVAL_SECONDS", "1"))
REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "15"))

ORG_ID = os.getenv("ORG_ID", "default_org")


def env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* (or *default* if unset) is truthy.

    Accepted truthy spellings are "1", "true", "yes" and "y", compared
    case-insensitively. Surrounding whitespace is ignored, so values read
    from ``.env`` files with stray spaces or CRLF endings (``"1\\r"``)
    still parse as intended.
    """
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "y")


# Passed as debug= to upsert_observation(); enabled unless explicitly turned off.
DEBUG_STORE = env_flag("DEBUG_STORE", "1")

# used for correlation if Prefer/Location are not available
# NOTE(review): FHIR guidance for UUID identifiers is system
# "urn:ietf:rfc:3986" with value "urn:uuid:<uuid>"; the default below is kept
# for compatibility with identifiers already posted by this script.
IDENTIFIER_SYSTEM = os.getenv("OBS_IDENTIFIER_SYSTEM", "urn:uuid")


def normalize_resource_url(fhir_base: str, resource: str) -> str:
    """Return the endpoint URL for a FHIR resource type.

    ``fhir_base`` may be the server base (e.g. ``.../api/FHIR/R4``) or may
    already end in the resource path. In the latter case it is returned as-is,
    apart from surrounding whitespace and trailing slashes. The "already ends
    in the resource" check is case-insensitive.

    Args:
        fhir_base: FHIR server base URL, or a full resource endpoint URL.
        resource: Resource type such as ``"Observation"``. Surrounding
            whitespace and leading/trailing slashes are ignored.

    Returns:
        The base joined to the resource type with exactly one ``/``.
    """
    base = fhir_base.strip().rstrip("/")
    name = resource.strip().strip("/")
    if base.lower().endswith(f"/{name.lower()}"):
        return base
    return f"{base}/{name}"


def now_fhir_dt() -> str:
    """Return the current UTC instant as ISO-8601 text ending in ``Z``.

    Fractional seconds are included whenever the current microsecond is
    non-zero, which is what ``datetime.isoformat()`` emits.
    """
    utc_suffix = "+00:00"
    text = datetime.now(tz=timezone.utc).isoformat()
    # An aware UTC datetime always renders its offset as "+00:00" at the end;
    # swap it for the shorter "Z" designator.
    if text.endswith(utc_suffix):
        text = text[: -len(utc_suffix)] + "Z"
    return text


def build_spo2_observation(value: int) -> Tuple[dict, str]:
    """
    Build a FHIR R4 Observation payload for one pulse-oximetry SpO2 reading.

    IMPORTANT:
    - Don't send "id" on create. Epic assigns Observation.id.
    - Use identifier for correlation so you can search after POST if needed.

    Args:
        value: Oxygen saturation in percent; coerced with ``int()``.

    Returns:
        ``(payload, client_uuid)``. ``payload`` is the JSON-ready Observation
        dict; ``client_uuid`` is the random 32-char hex string stored in
        ``identifier[0].value`` under ``IDENTIFIER_SYSTEM``, so the created
        resource can be found again by an identifier search.
    """
    client_uuid = uuid4().hex
    payload = {
        "resourceType": "Observation",
        # US Core profile for pulse-oximetry SpO2 readings.
        "meta": {
            "profile": [
                "http://hl7.org/fhir/us/core/StructureDefinition/us-core-pulse-oximetry"
            ]
        },
        "identifier": [{"system": IDENTIFIER_SYSTEM, "value": client_uuid}],
        "status": "final",
        "category": [
            {
                "coding": [
                    {
                        # Category codes come from the observation-category
                        # CodeSystem (required by the vital-signs profiles),
                        # not from a StructureDefinition URL.
                        "system": "http://terminology.hl7.org/CodeSystem/observation-category",
                        "code": "vital-signs",
                        "display": "Vital Signs",
                    }
                ]
            }
        ],
        "code": {
            "coding": [
                # 2708-6 is the generic O2 saturation code used by the base
                # FHIR vital-signs profile; 59408-5 is the pulse-oximetry
                # specific code US Core also expects. 2708-6 stays first so
                # any existing server-side mapping keyed on it is unaffected.
                {
                    "system": "http://loinc.org",
                    "code": "2708-6",
                    "display": "Oxygen saturation in Arterial blood",
                },
                {
                    "system": "http://loinc.org",
                    "code": "59408-5",
                    "display": "Oxygen saturation in Arterial blood by Pulse oximetry",
                },
            ],
            "text": "SpO2",
        },
        "subject": {"reference": PATIENT_REF},
        "encounter": {"reference": ENCOUNTER_REF},
        "effectiveDateTime": now_fhir_dt(),
        "valueQuantity": {
            "value": int(value),
            "unit": "%",
            "system": "http://unitsofmeasure.org",
            "code": "%",
        },
    }
    return payload, client_uuid


def get_headers() -> dict:
    """Assemble the HTTP headers used for every Epic FHIR request.

    A bearer token is obtained from get_or_create_epic_token() for
    EPIC_APP_CONFIG_ID on each call. FHIR JSON is used for both request and
    response bodies. ``Prefer: return=representation`` asks the server to put
    the created resource in the body of a create (POST) response.
    """
    bearer = get_or_create_epic_token(EPIC_APP_CONFIG_ID)
    header_pairs = [
        ("Authorization", f"Bearer {bearer}"),
        ("Accept", "application/fhir+json"),
        ("Content-Type", "application/fhir+json"),
        ("Prefer", "return=representation"),
    ]
    return dict(header_pairs)


def _post_with_refresh(url: str, headers: dict, payload: dict) -> requests.Response:
    """POST ``payload`` as JSON to ``url``, retrying once after a 401.

    On a 401 the headers are rebuilt with get_headers() and written back into
    ``headers`` in place. The caller's follow-up requests in the same cycle
    then carry the refreshed token instead of hitting another 401.

    NOTE(review): the retry only helps if get_or_create_epic_token() can hand
    out a different token after a 401; that logic is not visible here.

    Returns:
        The final response, which may still be a 401. Network errors and
        timeouts from requests.post propagate to the caller.
    """
    r = requests.post(url, headers=headers, json=payload, timeout=REQUEST_TIMEOUT)
    if r.status_code == 401:
        # Safe to repeat: a rejected (401) POST did not create anything.
        headers.update(get_headers())
        r = requests.post(url, headers=headers, json=payload, timeout=REQUEST_TIMEOUT)
    return r


def _get_with_refresh(url: str, headers: dict) -> requests.Response:
    """GET ``url``, retrying once after a 401.

    On a 401 the headers are rebuilt with get_headers() and written back into
    ``headers`` in place, so later requests that share this dict carry the
    refreshed token.

    Returns:
        The final response, which may still be a 401. Network errors and
        timeouts from requests.get propagate to the caller.
    """
    r = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT)
    if r.status_code == 401:
        headers.update(get_headers())
        r = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT)
    return r


def fetch_created_from_location(resp: requests.Response, headers: dict) -> Optional[dict]:
    """Best-effort GET of the resource named by a create response's Location.

    Reads ``Location`` (falling back to ``Content-Location``) from ``resp``.
    A relative value is resolved against ``resp.url``; an absolute value
    passes through unchanged.

    Returns:
        The parsed JSON body. Returns None when no header is present, the GET
        fails at the network level, the status is not 200/201, or the body is
        not JSON. Failures are swallowed on purpose so the caller can fall
        back to an identifier search.
    """
    loc = resp.headers.get("Location") or resp.headers.get("Content-Location")
    if not loc:
        return None

    # HTTP permits a relative Location; requests.get would reject it.
    target = urljoin(resp.url, loc)

    try:
        r2 = _get_with_refresh(target, headers=headers)
    except requests.RequestException:
        return None
    if r2.status_code not in (200, 201):
        return None

    try:
        return r2.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        return None


def fetch_by_identifier(observation_url: str, headers: dict, client_uuid: str) -> Optional[dict]:
    """
    GET Observation?identifier=<system>|<value>
    Returns a Bundle; we pick first Observation.

    Args:
        observation_url: The Observation endpoint URL.
        headers: Request headers. On a 401 they are rebuilt once via
            get_headers() for the retry; the caller's dict is not modified.
        client_uuid: The identifier value assigned by build_spo2_observation().

    Returns:
        The first ``Observation`` resource in the result Bundle. Returns None
        if the status is not 200, the body is not JSON, the body is not a
        Bundle, or no entry holds an Observation.

    Raises:
        requests.RequestException: network errors/timeouts are not caught.

    NOTE(review): some FHIR servers only allow identifier search together with
    a patient/subject parameter. If this always comes back empty, check the
    server's Observation search documentation.
    """
    params = {"identifier": f"{IDENTIFIER_SYSTEM}|{client_uuid}"}
    r = requests.get(observation_url, headers=headers, params=params, timeout=REQUEST_TIMEOUT)
    if r.status_code == 401:
        headers = get_headers()
        r = requests.get(observation_url, headers=headers, params=params, timeout=REQUEST_TIMEOUT)

    if r.status_code != 200:
        return None

    try:
        bundle = r.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        return None

    if not isinstance(bundle, dict) or bundle.get("resourceType") != "Bundle":
        return None

    for entry in bundle.get("entry") or []:
        # Skip malformed entries rather than raising AttributeError.
        if not isinstance(entry, dict):
            continue
        res = entry.get("resource")
        if isinstance(res, dict) and res.get("resourceType") == "Observation":
            return res

    return None


def _is_observation(resource: object) -> bool:
    """Return True when *resource* is a dict whose resourceType is "Observation"."""
    return isinstance(resource, dict) and resource.get("resourceType") == "Observation"


def post_and_fetch_observation(observation_url: str, payload: dict, client_uuid: str) -> dict:
    """
    Create an Observation and return the server's stored copy of it.

    1) POST with Prefer: return=representation
    2) If response not Observation, try Location
    3) If still missing, search by identifier

    Args:
        observation_url: The Observation endpoint URL.
        payload: Observation resource to create (without ``id``).
        client_uuid: The identifier value inside ``payload``, used by step 3.

    Returns:
        The created Observation as returned by the server (it includes the
        server-assigned ``id`` when the server provides one).

    Raises:
        RuntimeError: the POST returned a status other than 200/201, or none
            of the three strategies yielded an Observation.
        requests.RequestException: network errors from the POST or the
            identifier search.
    """
    headers = get_headers()
    resp = _post_with_refresh(observation_url, headers=headers, payload=payload)

    if resp.status_code not in (200, 201):
        try:
            err = resp.json()
        except ValueError:
            err = resp.text
        raise RuntimeError(f"POST failed status={resp.status_code} body={err}")

    # 1) Prefer: return=representation. Servers that ignore Prefer may send
    #    an empty or non-JSON body, which simply falls through.
    try:
        body = resp.json()
    except ValueError:
        body = None
    if _is_observation(body):
        return body

    # 2) Location header fetch
    obs = fetch_created_from_location(resp, headers=headers)
    if _is_observation(obs):
        return obs

    # 3) Identifier search fallback
    obs = fetch_by_identifier(observation_url, headers=headers, client_uuid=client_uuid)
    if _is_observation(obs):
        return obs

    raise RuntimeError("Created Observation could not be fetched (no body, no Location, identifier search failed).")


def main() -> int:
    """
    Post a random SpO2 Observation every INTERVAL_SECONDS until interrupted.

    Each cycle does three things:
    - builds an Observation with a value drawn from 90..100 inclusive;
    - creates it via post_and_fetch_observation();
    - upserts the returned resource through upsert_observation().

    An exception anywhere in a cycle is printed and the loop continues.

    Returns:
        0 when stopped with Ctrl-C (KeyboardInterrupt). Returns 2, before
        anything is posted, if INTERVAL_SECONDS is negative.
    """
    if INTERVAL_SECONDS < 0:
        # time.sleep() raises ValueError for negative durations; fail up
        # front instead of crashing after the first Observation is posted.
        print("INTERVAL_SECONDS must be >= 0, got", INTERVAL_SECONDS)
        return 2

    observation_url = normalize_resource_url(FHIR_BASE, "Observation")

    print("Posting SpO2 every", INTERVAL_SECONDS, "seconds")
    print("Observation endpoint:", observation_url)
    print("Patient:", PATIENT_REF)
    print("Encounter:", ENCOUNTER_REF)
    print("ORG_ID:", ORG_ID)
    print("DEBUG_STORE:", DEBUG_STORE)

    try:
        while True:
            spo2 = random.randint(90, 100)
            payload, client_uuid = build_spo2_observation(spo2)

            try:
                created_obs = post_and_fetch_observation(observation_url, payload, client_uuid)
                epic_obs_id = created_obs.get("id")

                print(
                    f"[{datetime.now().isoformat(timespec='seconds')}] "
                    f"ok spo2={spo2} epic_observation_id={epic_obs_id} client_uuid={client_uuid}"
                )

                # Store the server's copy locally. If this step fails, the
                # Observation already exists on the FHIR server but not in
                # the local DB. NOTE(review): confirm upsert_observation()
                # commits; whether leaving this with-block commits depends on
                # SessionLocal, which is not visible here.
                with SessionLocal() as session:
                    row = upsert_observation(session, org_id=ORG_ID, obs=created_obs, debug=DEBUG_STORE)
                    if row is not None:
                        print(f"DB upsert ok observation_id={row.observation_id}")
                    else:
                        print("DB upsert skipped (missing Observation.id)")

            except Exception as e:
                # Per-cycle boundary: report the failure and keep looping.
                # KeyboardInterrupt is not an Exception, so Ctrl-C still exits.
                print("Cycle failed:", repr(e))

            time.sleep(INTERVAL_SECONDS)
    except KeyboardInterrupt:
        print("Interrupted; stopping.")
        return 0


if __name__ == "__main__":
    # Run the posting loop and hand main()'s return value to the interpreter
    # as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)