# fhir_replay_server.py
"""Minimal in-memory FHIR R4 replay server (Patient, Encounter, Observation)
plus a **Mirror Bridge** utility that streams data from the Epic public sandbox
into this local server in near-real-time.

▶ Launch server only
    $ pip install fastapi uvicorn python-dateutil httpx requests pyjwt cryptography
    $ uvicorn fhir_replay_server:app --reload --port 8000

▶ Launch mirror-bridge (runs alongside the server)
    $ export EPIC_SANDBOX_TOKEN="<your bearer token>"
    $ python fhir_replay_server.py mirror --patients 607923,649012 --interval 60

Optional env:
    EPIC_CLIENT_ID / EPIC_CLIENT_SECRET  (if you want client_credentials token fetch)
    LOCAL_FHIR_BASE                      (default http://localhost:8000/fhir/R4/)

The mirror task:
    • Polls the Epic sandbox for each patient's laboratory Observations
    • POSTs them into this local FastAPI server (stored by resource id)
    • Creates missing Patient + Encounter resources on-the-fly

Result: TraceLoop can swap between
    http://localhost:8000/fhir/R4  ← mirror feed
    https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4  ← real Epic
with no code changes.
"""
from __future__ import annotations

import os
import sys
import jwt, uuid, requests, time
import json
import asyncio
import uuid
from datetime import datetime, timezone
from cryptography.hazmat.primitives import serialization
from typing import Dict, Any, List, Optional, Tuple

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from dateutil.parser import parse as dtparse


# ASGI application; uvicorn serves it as `fhir_replay_server:app`.
app = FastAPI(
    title="FHIR Replay Server + Mirror Bridge",
)

# ---------------------------------------------------------------------------
# In-memory stores keyed by resource id. They exist only inside the uvicorn
# server process; the mirror CLI reaches them over HTTP.
patients: Dict[str, Dict[str, Any]] = {}
encounters: Dict[str, Dict[str, Any]] = {}
observations: Dict[str, Dict[str, Any]] = {}

# Media type attached to every response this server emits.
FHIR_CT = "application/fhir+json; charset=utf-8"


def now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()


# ---------------------------------------------------------------------------
# Helper: FHIR response wrappers

def fhir_json(payload: Any, status: int = 200) -> JSONResponse:
    """Serialise *payload* as a JSON response tagged with the FHIR media type."""
    response = JSONResponse(content=payload, status_code=status, media_type=FHIR_CT)
    return response


def op_outcome(msg: str, status: int = 400, code: str = "invalid") -> JSONResponse:
    """Build a single-issue, error-severity OperationOutcome response.

    msg: human-readable text placed in ``issue[0].diagnostics``.
    status: HTTP status code of the response.
    code: FHIR issue-type code for ``issue[0].code``.
    """
    issue = {"severity": "error", "code": code, "diagnostics": msg}
    outcome = {"resourceType": "OperationOutcome", "issue": [issue]}
    return fhir_json(outcome, status)


# ---------------------------------------------------------------------------
# CRUD helpers (minimal validation)

def ensure_id(res: Dict[str, Any]) -> str:
    """Return ``res["id"]``, first assigning a random UUID4 string when it is missing or falsy."""
    if not res.get("id"):
        res["id"] = str(uuid.uuid4())
    return res["id"]


def touch_last_updated(res: Dict[str, Any]) -> None:
    """Ensure ``res["meta"]["lastUpdated"]`` is set.

    An upstream ``meta.lastUpdated`` (e.g. from Epic) is preserved so mirrored
    resources keep their original timestamps; a missing, null or empty value is
    replaced with the current UTC time. A ``meta`` that is absent, ``null`` or
    not a JSON object is replaced with a fresh dict (``setdefault`` alone would
    return ``None`` and crash on ``meta: null``).
    """
    meta = res.get("meta")
    if not isinstance(meta, dict):
        meta = res["meta"] = {}
    if not meta.get("lastUpdated"):
        meta["lastUpdated"] = now_iso()


def save_resource(store: Dict[str, Any], res: Dict[str, Any]) -> None:
    """Normalise *res* (id + meta.lastUpdated) and upsert it into *store* under its id."""
    key = ensure_id(res)
    touch_last_updated(res)
    store[key] = res


# ---------------------------------------------------------------------------
# FastAPI endpoints

@app.post("/fhir/R4/Patient", status_code=201)
async def post_patient(r: Request):
    """Create or overwrite a Patient in the in-memory store (keyed by ``id``).

    A missing ``id`` is assigned a UUID. Returns the stored resource with 201,
    or a 400 OperationOutcome (consistent with the other endpoints) when the
    body is not valid JSON or is not a Patient resource.
    """
    try:
        res = await r.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError on a bad body
        return op_outcome("Request body must be valid JSON")
    if not isinstance(res, dict) or res.get("resourceType") != "Patient":
        return op_outcome("resourceType must be Patient")
    save_resource(patients, res)
    return fhir_json(res, 201)


@app.get("/fhir/R4/Patient/{pid}")
async def get_patient(pid: str):
    """Read a single Patient by id; 404 OperationOutcome when it is not stored."""
    found = patients.get(pid)
    if found:
        return fhir_json(found)
    return op_outcome(f"Patient/{pid} not found", 404, "not-found")


@app.get("/fhir/R4/Patient")
async def search_patient(identifier: Optional[str] = None, _count: int = 50):
    """Search Patients and return a FHIR ``searchset`` Bundle.

    identifier: naive substring match against the JSON-serialised
        ``Patient.identifier`` array (sufficient for replay; not full FHIR
        token semantics).
    _count: page size, clamped to 1..200.

    ``Bundle.total`` is the number of matches, not just the size of the
    returned page, as FHIR defines it.
    """
    page_size = max(1, min(int(_count), 200))
    matches = list(patients.values())

    if identifier:
        matches = [
            p
            for p in matches
            if identifier in json.dumps(p.get("identifier", []), ensure_ascii=False)
        ]

    bundle = {
        "resourceType": "Bundle",
        "type": "searchset",
        "total": len(matches),
        "entry": [{"resource": p} for p in matches[:page_size]],
    }
    return fhir_json(bundle)


@app.post("/fhir/R4/Encounter", status_code=201)
async def post_encounter(r: Request):
    """Create or overwrite an Encounter in the in-memory store (keyed by ``id``).

    ``subject.reference`` must be a non-empty string. Returns the stored
    resource with 201, or a 400 OperationOutcome for invalid JSON, a wrong
    ``resourceType`` or a missing subject reference.
    """
    try:
        res = await r.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError on a bad body
        return op_outcome("Request body must be valid JSON")
    if not isinstance(res, dict) or res.get("resourceType") != "Encounter":
        return op_outcome("resourceType must be Encounter")

    subject = res.get("subject")
    ref = subject.get("reference") if isinstance(subject, dict) else None
    if not isinstance(ref, str) or not ref:
        return op_outcome("Encounter.subject.reference required")

    save_resource(encounters, res)
    return fhir_json(res, 201)


@app.get("/fhir/R4/Encounter/{eid}")
async def get_encounter(eid: str):
    """Read a single Encounter by id; 404 OperationOutcome when it is not stored."""
    found = encounters.get(eid)
    if found:
        return fhir_json(found)
    return op_outcome(f"Encounter/{eid} not found", 404, "not-found")


@app.get("/fhir/R4/Encounter")
async def search_encounter(patient: Optional[str] = None, _count: int = 50):
    """Search Encounters and return a FHIR ``searchset`` Bundle.

    patient: ``123``, ``Patient/123`` or an absolute ``.../Patient/123`` URL.
        Matches Encounters whose ``subject.reference`` is ``Patient/123``
        (relative or absolute). The whole id is compared, so ``patient=12``
        does not match ``Patient/123``.
    _count: page size, clamped to 1..200.

    ``Bundle.total`` is the number of matches, not just the returned page.
    """
    page_size = max(1, min(int(_count), 200))
    matches = list(encounters.values())

    if patient:
        target = "Patient/" + patient.rstrip("/").rsplit("/", 1)[-1]

        def refers_to_patient(res: Dict[str, Any]) -> bool:
            subject = res.get("subject")
            ref = subject.get("reference") if isinstance(subject, dict) else None
            return isinstance(ref, str) and (ref == target or ref.endswith("/" + target))

        matches = [e for e in matches if refers_to_patient(e)]

    bundle = {
        "resourceType": "Bundle",
        "type": "searchset",
        "total": len(matches),
        "entry": [{"resource": e} for e in matches[:page_size]],
    }
    return fhir_json(bundle)


@app.post("/fhir/R4/Observation", status_code=201)
async def post_observation(r: Request):
    """Create or overwrite an Observation in the in-memory store (keyed by ``id``).

    ``subject`` must be present and non-empty. Returns the stored resource with
    201, or a 400 OperationOutcome for invalid JSON, a wrong ``resourceType``
    or a missing subject.
    """
    try:
        res = await r.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError on a bad body
        return op_outcome("Request body must be valid JSON")
    if not isinstance(res, dict) or res.get("resourceType") != "Observation":
        return op_outcome("resourceType must be Observation")
    if not res.get("subject"):
        return op_outcome("Observation.subject required")
    save_resource(observations, res)
    return fhir_json(res, 201)


@app.get("/fhir/R4/Observation/{oid}")
async def get_observation(oid: str):
    """Read a single Observation by id; 404 OperationOutcome when it is not stored."""
    found = observations.get(oid)
    if found:
        return fhir_json(found)
    return op_outcome(f"Observation/{oid} not found", 404, "not-found")


# ---------------------------------------------------------------------------
# Tiny JSONPath helper (no dependency)

def json_path(obj: Dict[str, Any], path: str, default=None):
    """Walk *obj* along a dot-separated *path* and return the value found.

    Segments address dict keys, or list indices when the current node is a
    list (e.g. ``"category.0.coding.0.code"``). Returns *default* as soon as a
    key is missing, an index is non-numeric, negative or out of range, or a
    scalar is reached before the path ends.
    """
    node: Any = obj
    for key in path.split("."):
        if isinstance(node, list):
            try:
                pos = int(key)
            except ValueError:
                return default
            if not 0 <= pos < len(node):
                return default
            node = node[pos]
        elif isinstance(node, dict):
            if key not in node:
                return default
            node = node[key]
        else:
            return default
    return node


def _parse_dt(s: str) -> Optional[datetime]:
    if not s:
        return None
    try:
        return dtparse(s)
    except Exception:
        return None


def _match_prefix_dt(target_dt_str: str, query_val: str) -> bool:
    """
    Support ge|gt|le|lt|eq prefixes (FHIR style) for dateTime comparisons.
    query_val examples: ge2020-01-01T00:00:00Z
    """
    if not query_val:
        return True
    if not target_dt_str:
        return False

    op = "eq"
    for p in ("ge", "gt", "le", "lt", "eq"):
        if query_val.startswith(p):
            op = p
            query_val = query_val[len(p):]
            break

    t = _parse_dt(target_dt_str)
    q = _parse_dt(query_val)
    if not t or not q:
        return False

    if op == "ge":
        return t >= q
    if op == "gt":
        return t > q
    if op == "le":
        return t <= q
    if op == "lt":
        return t < q
    return t == q


def _observation_has_category(obs: Dict[str, Any], tokens: List[str]) -> bool:
    """Return True if any ``category[*].coding[*]`` of *obs* matches any token.

    Tokens use FHIR token-search syntax: ``code``, ``system|code``, ``|code``
    (a coding with no system) or ``system|`` (any code within the system).
    Codes are compared exactly.
    """
    for cat in obs.get("category") or []:
        if not isinstance(cat, dict):
            continue
        for coding in cat.get("coding") or []:
            if not isinstance(coding, dict):
                continue
            for tok in tokens:
                if "|" not in tok:
                    if coding.get("code") == tok:
                        return True
                    continue
                system, _, code = tok.partition("|")
                if system and coding.get("system") != system:
                    continue
                if not system and coding.get("system"):
                    continue
                if code and coding.get("code") != code:
                    continue
                return True
    return False


@app.get("/fhir/R4/Observation")
async def search_observation(
    patient: Optional[str] = None,
    category: Optional[str] = None,
    _lastUpdated: Optional[str] = None,
    _count: int = 50,
):
    """Search Observations and return a searchset Bundle, newest ``meta.lastUpdated`` first.

    patient: ``123``, ``Patient/123`` or an absolute ``.../Patient/123`` URL,
        matched by whole id against ``subject.reference`` (so ``patient=12``
        does not match ``Patient/123``).
    category: comma-separated FHIR tokens (e.g. ``laboratory`` or
        ``http://terminology.hl7.org/CodeSystem/observation-category|vital-signs``),
        matched exactly against every category coding.
    _lastUpdated: one FHIR-prefixed dateTime, e.g. ``ge2024-01-01T00:00:00Z``.
    _count: page size, clamped to 1..200.

    ``Bundle.total`` is the number of matches, not just the returned page.
    """
    page_size = max(1, min(int(_count), 200))
    matches: List[Dict[str, Any]] = list(observations.values())

    if patient:
        target = "Patient/" + patient.rstrip("/").rsplit("/", 1)[-1]

        def refers_to_patient(res: Dict[str, Any]) -> bool:
            subject = res.get("subject")
            ref = subject.get("reference") if isinstance(subject, dict) else None
            return isinstance(ref, str) and (ref == target or ref.endswith("/" + target))

        matches = [o for o in matches if refers_to_patient(o)]

    if category:
        tokens = [tok.strip() for tok in category.split(",") if tok.strip()]
        matches = [o for o in matches if _observation_has_category(o, tokens)]

    if _lastUpdated:
        matches = [
            o
            for o in matches
            if _match_prefix_dt((o.get("meta") or {}).get("lastUpdated", ""), _lastUpdated)
        ]

    matches.sort(key=lambda x: (x.get("meta") or {}).get("lastUpdated", ""), reverse=True)
    bundle = {
        "resourceType": "Bundle",
        "type": "searchset",
        "total": len(matches),
        "entry": [{"resource": o} for o in matches[:page_size]],
    }
    return fhir_json(bundle)


# ---------------------------------------------------------------------------
# MIRROR BRIDGE (Epic → local)

# Epic public sandbox (FHIR R4 API + OAuth2 token endpoint) and the local replay server.
_EPIC_ROOT = "https://fhir.epic.com/interconnect-fhir-oauth"
EPIC_BASE = f"{_EPIC_ROOT}/api/FHIR/R4/"
TOKEN_URL = f"{_EPIC_ROOT}/oauth2/token"
LOCAL_BASE_DEFAULT = "http://localhost:8000/fhir/R4/"


async def epic_token(client_id: str, client_secret: str) -> str:
    """Request an access token via the OAuth2 client_credentials grant.

    Sends *client_id* and *client_secret* in the form body to TOKEN_URL and
    returns the ``access_token`` field of the JSON reply.

    Raises:
        httpx.HTTPStatusError: when the token endpoint answers 4xx/5xx.
    """
    form = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
    }
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(TOKEN_URL, data=form)
    resp.raise_for_status()
    payload = resp.json()
    return payload["access_token"]


async def _local_exists(cli: httpx.AsyncClient, local_base: str, resource_type: str, rid: str) -> bool:
    """Return True when ``GET {local_base}{resource_type}/{rid}`` answers HTTP 200."""
    url = f"{local_base}{resource_type}/{rid}"
    resp = await cli.get(url, headers={"Accept": FHIR_CT})
    return resp.status_code == 200

CLIENT_ID  = "1fed7724-fddf-469e-8b02-d11f338e264c"
KID        = "jwks_live_patients.json"

with open("keys/privatekey.pem", "rb") as fp:
    KEY = serialization.load_pem_private_key(fp.read(), password=None)

def client_assertion():
    now = int(time.time())
    payload = {
        "iss": CLIENT_ID,
        "sub": CLIENT_ID,
        "aud": TOKEN_URL,
        "jti": str(uuid.uuid4()),
        "iat": now,
        "nbf": now,
        "exp": now + 300,
    }

    headers = {"kid": KID}

    token = jwt.encode(
        payload,
        KEY,
        algorithm="RS384",
        headers=headers,
    )
    return token

def get_token():
    """Obtain an Epic access token using a signed JWT client assertion.

    POSTs an OAuth2 ``client_credentials`` request to TOKEN_URL,
    authenticating with the assertion from client_assertion(). Returns the
    decoded JSON reply; the caller reads ``access_token`` from it.

    Raises:
        requests.HTTPError: on a 4xx/5xx reply. The message includes the
            status code and the response body to help diagnose OAuth setup
            problems.
    """
    # The assertion is a usable credential until it expires (see
    # client_assertion), so it is deliberately not printed or logged.
    assertion = client_assertion()

    r = requests.post(
        TOKEN_URL,
        data={
            "grant_type": "client_credentials",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": assertion,
        },
        timeout=15,
    )
    if not r.ok:
        raise requests.HTTPError(
            f"Epic token request failed: HTTP {r.status_code}: {r.text}",
            response=r,
        )
    return r.json()

async def _acquire_epic_token() -> str:
    """Return an Epic bearer token, trying each configured source in order.

    1. ``EPIC_SANDBOX_TOKEN`` env var, used verbatim (it cannot be refreshed).
    2. JWT client-assertion flow via get_token().
    3. client_credentials with ``EPIC_CLIENT_ID`` / ``EPIC_CLIENT_SECRET``.

    Raises:
        RuntimeError: if no source yields a token. The JWT-flow error, if
            any, is chained as the cause.
    """
    static_token = os.getenv("EPIC_SANDBOX_TOKEN")
    if static_token:
        return static_token

    jwt_error: Optional[BaseException] = None
    try:
        # get_token() blocks (requests); tolerable in this single-task mirror CLI.
        access_token = get_token().get("access_token")
        if access_token:
            return access_token
    except Exception as err:  # any JWT-flow failure falls through to client_secret
        jwt_error = err

    client_id = os.getenv("EPIC_CLIENT_ID")
    client_secret = os.getenv("EPIC_CLIENT_SECRET")
    if client_id and client_secret:
        return await epic_token(client_id, client_secret)
    raise RuntimeError(
        "Could not obtain an Epic access token: EPIC_SANDBOX_TOKEN is unset, the JWT "
        "client-assertion flow failed, and EPIC_CLIENT_ID/EPIC_CLIENT_SECRET are not set."
    ) from jwt_error


async def _epic_get(
    cli: httpx.AsyncClient,
    url: str,
    hdr: Dict[str, str],
    params: Optional[Dict[str, str]] = None,
) -> httpx.Response:
    """GET *url* from Epic. On HTTP 401, refresh the bearer token in *hdr* (in place) and retry once."""
    r = await cli.get(url, params=params, headers=hdr)
    if r.status_code == 401:
        hdr["Authorization"] = f"Bearer {await _acquire_epic_token()}"
        r = await cli.get(url, params=params, headers=hdr)
    return r


async def _fetch_epic_observations(
    cli: httpx.AsyncClient, hdr: Dict[str, str], pid: str
) -> List[Dict[str, Any]]:
    """Return every laboratory Observation Epic lists for *pid*, following Bundle ``next`` links.

    Raises:
        httpx.HTTPStatusError: if any search page answers 4xx/5xx.
    """
    url: Optional[str] = f"{EPIC_BASE}Observation"
    params: Optional[Dict[str, str]] = {"patient": pid, "category": "laboratory"}
    found: List[Dict[str, Any]] = []
    visited: set = set()  # guards against a server that links a page to itself
    while url and url not in visited:
        visited.add(url)
        r = await _epic_get(cli, url, hdr, params)
        r.raise_for_status()
        bundle = r.json()
        for entry in bundle.get("entry") or []:
            res = entry.get("resource") or {}
            if res.get("resourceType") == "Observation":
                found.append(res)
        # `next` URLs already embed the full query string.
        params = None
        url = next(
            (link.get("url") for link in bundle.get("link") or [] if link.get("relation") == "next"),
            None,
        )
    return found


async def _mirror_patient(
    epic_cli: httpx.AsyncClient,
    local_cli: httpx.AsyncClient,
    epic_hdr: Dict[str, str],
    local_hdr: Dict[str, str],
    local_base: str,
    pid: str,
    since: Optional[datetime],
) -> Optional[datetime]:
    """Sync one patient from Epic into the local server and return the new watermark.

    Observations whose ``meta.lastUpdated`` is older than *since* are skipped.
    Ties are re-sent, matching ``ge`` semantics. Observations without a
    parseable timestamp are always sent; this module's server stores by id,
    so repeats overwrite rather than duplicate. If the Patient is missing
    locally (e.g. the in-memory server restarted), the watermark is dropped
    and everything is re-sent. The returned watermark is the newest
    ``meta.lastUpdated`` seen, or *since* when nothing newer arrived.
    """
    upstream = await _fetch_epic_observations(epic_cli, epic_hdr, pid)

    # Ensure the Patient exists locally; if it is absent, local state was lost.
    if not await _local_exists(local_cli, local_base, "Patient", pid):
        since = None
        pr = await _epic_get(epic_cli, f"{EPIC_BASE}Patient/{pid}", epic_hdr)
        if pr.status_code == 200:
            await local_cli.post(f"{local_base}Patient", json=pr.json(), headers=local_hdr)

    new_obs: List[Dict[str, Any]] = []
    watermark = since
    for obs in upstream:
        lu = _parse_dt((obs.get("meta") or {}).get("lastUpdated") or "")
        if lu is not None and lu.tzinfo is None:
            lu = lu.replace(tzinfo=timezone.utc)  # keep comparisons aware-vs-aware
        if since is not None and lu is not None and lu < since:
            continue
        new_obs.append(obs)
        if lu is not None and (watermark is None or lu > watermark):
            watermark = lu
    if not new_obs:
        return watermark

    # Ensure every referenced Encounter exists locally.
    needed_enc_ids = {
        ref.split("/")[1]
        for ref in ((o.get("encounter") or {}).get("reference") or "" for o in new_obs)
        if ref.startswith("Encounter/")
    }
    for eid in needed_enc_ids:
        if not await _local_exists(local_cli, local_base, "Encounter", eid):
            er = await _epic_get(epic_cli, f"{EPIC_BASE}Encounter/{eid}", epic_hdr)
            if er.status_code == 200:
                await local_cli.post(f"{local_base}Encounter", json=er.json(), headers=local_hdr)

    posted = 0
    for obs in new_obs:
        resp = await local_cli.post(f"{local_base}Observation", json=obs, headers=local_hdr)
        if resp.status_code in (200, 201):
            posted += 1

    print(f"Mirrored {posted}/{len(new_obs)} obs for Patient {pid} at {now_iso()}")
    return watermark


async def mirror_task(patient_ids: List[str], interval: int = 60):
    """Continuously mirror Epic sandbox lab Observations for *patient_ids* into the local server.

    Every *interval* seconds (minimum 1) each patient is synced incrementally.
    A failure for one patient (HTTP/network error, non-JSON reply) is
    reported on stderr and retried on the next cycle instead of ending the
    loop. The local base URL comes from LOCAL_FHIR_BASE (default
    LOCAL_BASE_DEFAULT). Runs until cancelled.

    Raises:
        RuntimeError: if an Epic token cannot be obtained, at start-up or
            when refreshing after a 401.
    """
    local_base = os.getenv("LOCAL_FHIR_BASE", LOCAL_BASE_DEFAULT).rstrip("/") + "/"

    token = await _acquire_epic_token()
    # Mutated in place by _epic_get when the token is refreshed after a 401.
    epic_hdr = {"Accept": FHIR_CT, "Authorization": f"Bearer {token}"}
    local_hdr = {"Accept": FHIR_CT, "Content-Type": FHIR_CT}

    # Per-patient watermark: newest meta.lastUpdated mirrored so far (None = everything).
    last_ts: Dict[str, Optional[datetime]] = {pid: None for pid in patient_ids}

    async with httpx.AsyncClient(timeout=30) as epic_cli, httpx.AsyncClient(timeout=30) as local_cli:
        while True:
            for pid in patient_ids:
                try:
                    last_ts[pid] = await _mirror_patient(
                        epic_cli, local_cli, epic_hdr, local_hdr, local_base, pid, last_ts[pid]
                    )
                except (httpx.HTTPError, ValueError) as err:
                    print(f"Mirror error for Patient {pid} at {now_iso()}: {err!r}", file=sys.stderr)

            await asyncio.sleep(max(1, int(interval)))


# ---------------------------------------------------------------------------
# CLI entry – `python fhir_replay_server.py mirror --patients P1,P2 --interval 30`

if __name__ == "__main__":
    if len(sys.argv) >= 2 and sys.argv[1] == "mirror":
        import argparse

        ap = argparse.ArgumentParser("mirror")
        ap.add_argument("--patients", required=True, help="Comma-separated patient IDs from Epic sandbox")
        ap.add_argument("--interval", type=int, default=60, help="Polling interval seconds")
        args = ap.parse_args(sys.argv[2:])
        pids = [x.strip() for x in args.patients.split(",") if x.strip()]
        if not pids:
            # With no IDs the polling loop would just sleep forever; fail fast instead.
            ap.error("--patients must contain at least one patient ID")
        try:
            asyncio.run(mirror_task(pids, args.interval))
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the endless polling loop; skip the traceback.
            print("Mirror stopped.", file=sys.stderr)
            sys.exit(130)
    else:
        # If launched via `uvicorn`, FastAPI app is already exposed as `app`
        print("Run via uvicorn or use: python fhir_replay_server.py mirror --patients P1,P2 --interval 60")
