# patient_create.py
import json
import random
import time
import uuid
from dataclasses import dataclass
from datetime import date, timedelta
from sqlalchemy.orm import Session
from db import SessionLocal          # adjust path if different
from epic_token import EpicPatient   # adjust path if different
from datetime import datetime, timezone
import requests

import jwt
import requests
from cryptography.hazmat.primitives import serialization

# Identifier system URL that save_patient_to_db maps to the payer_member_id column.
PAYER_MEMBER_ID_SYSTEM = "https://open.epic.com/FHIR/StructureDefinition/PayerMemberId"

# =========================
# CONFIG (edit these)
# =========================
# OAuth client ID of the registered app; used as both "iss" and "sub" of the JWT client assertion.
CLIENT_ID = "1fed7724-fddf-469e-8b02-d11f338e264c"
# OAuth 2.0 token endpoint; also used as the JWT "aud" claim.
TOKEN_URL = "https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token"
# FHIR R4 base URL that all resource requests are built from.
FHIR_BASE = "https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4"

# This should match what your server expects for JWT header "kid".
# In many setups, "kid" is a short key id (e.g. "key1"), not a filename.
# NOTE(review): the value below looks like a JWKS filename; confirm it equals the
# "kid" published in the JWKS registered for CLIENT_ID.
KID = "jwks_live_patients.json"

# PEM private key used to sign the client assertion. It is loaded with password=None,
# so it must be unencrypted.
PRIVATE_KEY_PATH = "keys/privatekey.pem"

# Identifier systems used when generating Patient identifiers and when mapping them
# back onto DB columns (MRN -> epic_id, SSN -> ssn).
EPIC_MRN_SYSTEM_OID = "urn:oid:1.2.840.114350.1.13.861.1.7.5.737384.27000"
SSN_SYSTEM_OID = "urn:oid:2.16.840.1.113883.4.1"

# Optional generalPractitioner settings (only used when include_gp=True).
# NOTE(review): the host part ("https://hostname/instance") is a placeholder, not a real
# server. The Practitioner id at the end matches the one build_random_patient_payload
# references.
DEFAULT_GP_REFERENCE = "https://hostname/instance/api/FHIR/R4/Practitioner/e4ZKPXZ0ux4Z80UHywzHyHg3"
DEFAULT_GP_DISPLAY = "Sample Provider, MD"


# =========================
# AUTH (OAuth 2.0 client_credentials grant with a signed JWT client assertion)
# =========================
# Signing key for client_assertion(). It is read once at import time, so importing this
# module fails if PRIVATE_KEY_PATH is missing or the key is password-protected
# (password=None).
with open(PRIVATE_KEY_PATH, "rb") as fp:
    KEY = serialization.load_pem_private_key(fp.read(), password=None)


def client_assertion() -> str:
    """Return a freshly signed JWT for use as the OAuth ``client_assertion``.

    The token:

    - is signed with ``KEY`` using RS384 and carries ``KID`` in its header;
    - names ``CLIENT_ID`` as issuer and subject, and ``TOKEN_URL`` as audience;
    - is valid from now for 300 seconds;
    - gets a new random ``jti`` on every call.
    """
    issued_at = int(time.time())
    expires_at = issued_at + 300
    claims = {
        "iss": CLIENT_ID,
        "sub": CLIENT_ID,
        "aud": TOKEN_URL,
        "jti": str(uuid.uuid4()),
        "iat": issued_at,
        "nbf": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, KEY, algorithm="RS384", headers={"kid": KID})


def get_token() -> dict:
    """Exchange a signed client assertion for an access token.

    Posts a ``client_credentials`` grant with a JWT-bearer ``client_assertion``
    (built by ``client_assertion()``) to ``TOKEN_URL``.

    Returns:
        The token endpoint's parsed JSON response (callers read ``access_token``).

    Raises:
        requests.HTTPError: on a 4xx/5xx response. The message includes the
            status code and response body.
    """
    assertion = client_assertion()
    resp = requests.post(
        TOKEN_URL,
        data={
            "grant_type": "client_credentials",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": assertion,
        },
        timeout=15,
    )
    if not resp.ok:
        # OAuth token endpoints report why a request was rejected in the JSON body
        # (RFC 6749 §5.2 "error"/"error_description"); raise_for_status() would drop it.
        raise requests.HTTPError(
            f"Token request failed ({resp.status_code}): {resp.text}",
            response=resp,
        )
    return resp.json()


# =========================
# RANDOM PATIENT GENERATION
# =========================
# Synthetic name pools. The first-name pool is chosen by the gender picked in
# build_random_patient_payload.
FIRST_NAMES_F = ["Test", "Ava", "Mia", "Sophia", "Olivia", "Emma", "Noor", "Aisha", "Yuki", "Hana"]
FIRST_NAMES_M = ["Test", "Liam", "Noah", "Ethan", "Lucas", "Arjun", "Kabir", "Kenji", "Hiro", "Ravi"]
LAST_NAMES = ["Patient", "Smith", "Johnson", "Patel", "Khan", "Garcia", "Brown", "Williams", "Tanaka", "Singh"]

# (city, state, postal code, country) tuples for generated home addresses.
CITIES = [
    ("Verona", "WI", "53593", "US"),
    ("Madison", "WI", "53703", "US"),
    ("Austin", "TX", "78701", "US"),
    ("Seattle", "WA", "98101", "US"),
    ("Raleigh", "NC", "27601", "US"),
]

# Candidate address-line sets; one entry becomes a Patient address "line" array.
STREET_LINES = [
    ["100 Milky Way", "Learning Campus"],
    ["42 Galaxy Blvd", "Suite 12"],
    ["7 Maple Street"],
    ["18 River Road", "Apt 4B"],
    ["901 Sunset Ave"],
]

# (language code, display name) pairs.
# NOTE(review): not referenced by any function visible in this file.
LANG_CODES = [
    ("ja", "Japanese"),
    ("en", "English"),
    ("es", "Spanish"),
    ("hi", "Hindi"),
]


def random_birthdate(min_age: int = 1, max_age: int = 90) -> str:
    """Return a random ISO-8601 birth date (``YYYY-MM-DD``) for someone whose age
    today is between ``min_age`` and ``max_age`` inclusive.

    An age is picked uniformly first. A day is then picked uniformly from that
    age's one-year birth window.

    The window is bounded by calendar anniversaries rather than 365-day blocks.
    Leap days therefore cannot push the result outside the requested age range.
    The previous arithmetic could make a "90-year-old" actually 89.
    """

    def years_before(d: date, years: int) -> date:
        # Same month/day `years` earlier; Feb 29 falls back to Feb 28 in non-leap years.
        try:
            return d.replace(year=d.year - years)
        except ValueError:
            return d.replace(year=d.year - years, day=28)

    today = date.today()
    age = random.randint(min_age, max_age)
    # Born on `latest`: turns `age` exactly today.
    latest = years_before(today, age)
    # Born on `earliest`: turns `age + 1` tomorrow, so is still `age` today.
    earliest = years_before(today, age + 1) + timedelta(days=1)
    offset = random.randint(0, (latest - earliest).days)
    return (latest - timedelta(days=offset)).isoformat()


def random_phone_us() -> str:
    """Return a synthetic ``AAA-EEE-LLLL`` US-style phone number.

    The area and exchange parts are drawn from 200-999 and the line part from
    1000-9999. The number is not guaranteed to be unassigned.
    """
    area = random.randint(200, 999)
    exchange = random.randint(200, 999)
    line = random.randint(1000, 9999)
    return "-".join(str(part) for part in (area, exchange, line))


def random_email(first: str, last: str) -> str:
    """Return ``first.last<N>@<domain>`` in lower case, where N is 1-9999.

    The domain is one of the reserved example/test/invalid domains.
    """
    number = random.randint(1, 9999)
    domain = random.choice(("example.com", "example.test", "synthetic.invalid"))
    local_part = f"{first.lower()}.{last.lower()}{number}"
    return f"{local_part}@{domain}"


def random_suffix() -> str | None:
    return random.choice([None, "JR.", "SR.", "III"])


def fake_ssn_like() -> str:
    """Return an SSN-shaped string ``999-GG-SSSS`` for synthetic patients.

    The fixed "999" area prefix is deliberately non-realistic so the value does
    not match a real SSN.
    """
    group = random.randint(10, 99)
    serial = random.randint(1000, 9999)
    return f"999-{group:02d}-{serial:04d}"


def build_random_patient_payload(*, include_gp: bool = False) -> dict:
    """Build a synthetic FHIR R4 ``Patient`` resource for a create request.

    Values are random picks from the module-level pools (names, CITIES,
    STREET_LINES) and the ``random_*`` / ``fake_ssn_like`` helpers. The result
    contains:

    - an MRN identifier and an SSN-like identifier;
    - an official name, a home phone and an email;
    - gender ("female"/"male" only) and a birth date;
    - a home address and marital status;
    - a preferred-provider-sex extension.

    Args:
        include_gp: when True, add a ``generalPractitioner`` entry. It uses the
            Practitioner named at the end of ``DEFAULT_GP_REFERENCE`` and
            ``DEFAULT_GP_DISPLAY``. Only enable this if that Practitioner
            exists in the target tenant.

    Returns:
        A JSON-serialisable dict.
    """
    gender = random.choice(["female", "male"])  # only these two values are generated
    first = random.choice(FIRST_NAMES_F if gender == "female" else FIRST_NAMES_M)
    last = random.choice(LAST_NAMES)

    city, state, postal, country = random.choice(CITIES)
    # Copy: the payload must not alias (and let callers mutate) the shared STREET_LINES entry.
    lines = list(random.choice(STREET_LINES))

    patient = {
        "resourceType": "Patient",
        "identifier": [
            {"use": "usual", "system": EPIC_MRN_SYSTEM_OID, "value": str(random.randint(1000, 999999))},
            {"use": "usual", "system": SSN_SYSTEM_OID, "value": fake_ssn_like()},
        ],
        "name": [{"use": "official", "family": last, "given": [first]}],
        "telecom": [
            {"system": "phone", "value": random_phone_us(), "use": "home"},
            {"system": "email", "value": random_email(first, last)},
        ],
        "gender": gender,
        "birthDate": random_birthdate(min_age=1, max_age=90),
        "address": [
            {
                "use": "home",
                "line": lines,
                "city": city,
                "state": state,
                "postalCode": postal,
                "country": country,  # two-letter code from CITIES ("US")
            }
        ],
        "maritalStatus": {"text": random.choice(["Single", "Married", "Divorced", "Widowed"])},
        "extension": [
            {
                "url": "http://open.epic.com/FHIR/R4/StructureDefinition/patient-preferred-provider-sex",
                "valueCode": gender,  # male/female only
            }
        ],
    }

    if include_gp:
        # Relative "Practitioner/<id>" reference: the last two path segments of the
        # configured URL. The placeholder host in DEFAULT_GP_REFERENCE is never sent.
        gp_reference = "/".join(DEFAULT_GP_REFERENCE.rstrip("/").split("/")[-2:])
        patient["generalPractitioner"] = [
            {"display": DEFAULT_GP_DISPLAY, "reference": gp_reference}
        ]

    return patient

def parse_birth_date(birth_date_str: str | None) -> date | None:
    if not birth_date_str:
        return None
    try:
        return date.fromisoformat(birth_date_str)  # YYYY-MM-DD
    except ValueError:
        return None


def _extract_name_parts(names: list[dict]) -> tuple[str | None, str | None, str | None]:
    """Return ``(first_name, last_name, full_name)`` from a Patient ``name`` list.

    Uses the first ``use == "official"`` HumanName, else the first name, else
    nothing. ``full_name`` is ``text`` when present. Otherwise it is built as
    "<first> <last>" from whichever parts exist, or None if neither does.
    """
    official = next((n for n in names if n.get("use") == "official"), names[0] if names else {})
    given = official.get("given") or []
    first_name = given[0] if given else None
    last_name = official.get("family")
    full_name = official.get("text")
    if not full_name:
        parts = [p for p in (first_name, last_name) if p]
        full_name = " ".join(parts) if parts else None
    return first_name, last_name, full_name


def _extract_phones(telecoms: list[dict]) -> tuple[str | None, str | None]:
    """Return ``(home_phone, work_phone)`` from the Patient's phone ContactPoints.

    A later entry with the same ``use`` only replaces an earlier one whose
    value was empty.
    """
    home_phone = None
    work_phone = None
    for t in telecoms:
        if t.get("system") != "phone":
            continue
        use = t.get("use")
        if use == "home" and not home_phone:
            home_phone = t.get("value")
        elif use == "work" and not work_phone:
            work_phone = t.get("value")
    return home_phone, work_phone


def _extract_home_address(addresses: list[dict]) -> dict:
    """Return the EpicPatient address columns from the home address (else the first address)."""
    home = next((a for a in addresses if a.get("use") == "home"), addresses[0] if addresses else {})
    lines = home.get("line") or []
    return {
        "address_line": lines[0] if lines else None,  # only the first line is stored
        "city": home.get("city"),
        "state": home.get("state"),
        "postal_code": home.get("postalCode"),
        "country": home.get("country"),
    }


def _extract_identifiers(identifiers: list[dict]) -> dict:
    """Map Patient identifiers onto EpicPatient identifier columns.

    For each identifier with a non-empty value, ``type.text`` is tried first:
    EPIC / EPI / CEID / MYCHARTLOGIN. If that does not assign a column, the
    identifier's ``system`` is tried: the MRN OID maps to epic_id, the SSN OID
    to ssn, and the PayerMemberId system to payer_member_id. The first value
    seen for each column wins.
    """
    by_type_text = {"EPIC": "epic_id", "EPI": "epi_id", "CEID": "ceid", "MYCHARTLOGIN": "mychart_login"}
    by_system = {
        EPIC_MRN_SYSTEM_OID: "epic_id",  # MRN from the generator is stored as epic_id
        SSN_SYSTEM_OID: "ssn",
        PAYER_MEMBER_ID_SYSTEM: "payer_member_id",
    }
    found = dict.fromkeys(("epic_id", "epi_id", "ssn", "payer_member_id", "ceid", "mychart_login"))

    for ident in identifiers:
        value = ident.get("value")
        if not value:
            continue
        type_field = by_type_text.get((ident.get("type") or {}).get("text"))
        if type_field and not found[type_field]:
            found[type_field] = value
            continue
        system_field = by_system.get(ident.get("system"))
        if system_field and not found[system_field]:
            found[system_field] = value
    return found


def save_patient_to_db(session: Session, patient_resource: dict) -> "EpicPatient":
    """Upsert a FHIR R4 Patient resource into the EpicPatient table, keyed on ``patient_id``.

    Flattens names, phones, the home address and known identifiers into columns.
    The full resource is stored as JSON in ``raw_json``. The session is
    committed and the refreshed row is returned.

    Raises:
        ValueError: if the resource has no ``id``.
        Any exception from ``session.commit()`` is re-raised after the session
        is rolled back, so the session stays usable.
    """
    patient_id = patient_resource.get("id")
    if not patient_id:
        raise ValueError("Patient resource has no 'id'")

    first_name, last_name, full_name = _extract_name_parts(patient_resource.get("name") or [])
    home_phone, work_phone = _extract_phones(patient_resource.get("telecom") or [])

    columns = {
        **_extract_identifiers(patient_resource.get("identifier") or []),
        "first_name": first_name,
        "last_name": last_name,
        "full_name": full_name,
        "gender": patient_resource.get("gender"),
        "birth_date": parse_birth_date(patient_resource.get("birthDate")),
        "home_phone": home_phone,
        "work_phone": work_phone,
        **_extract_home_address(patient_resource.get("address") or []),
        "raw_json": json.dumps(patient_resource),
    }

    # ---- Upsert on patient_id ----
    patient = session.query(EpicPatient).filter_by(patient_id=patient_id).one_or_none()
    if patient is None:
        patient = EpicPatient(patient_id=patient_id)

    for column, value in columns.items():
        setattr(patient, column, value)

    session.add(patient)
    try:
        session.commit()
    except Exception:
        # Without this the session is left needing a rollback and is unusable for further work.
        session.rollback()
        raise
    session.refresh(patient)
    return patient

# =========================
# FHIR POST
# =========================
@dataclass
class CreateResult:
    status_code: int
    location: str | None
    resource: dict | None
    text: str


def create_patient(access_token: str, patient_payload: dict) -> CreateResult:
    """POST ``patient_payload`` to ``{FHIR_BASE}/Patient`` and capture the response.

    Sends ``Prefer: return=representation`` so the server may echo the created
    resource. HTTP error statuses are not raised; callers inspect
    ``CreateResult.status_code``.

    Returns:
        A CreateResult with:

        - the status code and Location header;
        - the parsed JSON body, or None if the body is empty or not JSON;
        - the body as text (pretty-printed JSON when it parsed).
    """
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
        "Content-Type": "application/fhir+json",
        "Prefer": "return=representation",
    }

    resp = requests.post(f"{FHIR_BASE}/Patient", headers=headers, json=patient_payload, timeout=30)

    # requests' response headers are case-insensitive, so this also matches "location".
    location = resp.headers.get("Location")

    resource = None
    text = resp.text
    try:
        resource = resp.json()
    except ValueError:
        # Empty or non-JSON body (e.g. 201 with no content): keep the raw text.
        # requests' JSONDecodeError subclasses ValueError.
        pass
    else:
        text = json.dumps(resource, indent=2)

    return CreateResult(
        status_code=resp.status_code,
        location=location,
        resource=resource,
        text=text,
    )

def get_patient_by_url(access_token: str, url: str) -> dict:
    """GET the FHIR resource at ``url`` with a bearer token and return its JSON body.

    Raises:
        requests.HTTPError: on a 4xx/5xx response.
    """
    response = requests.get(
        url,
        headers={"Authorization": f"Bearer {access_token}", "Accept": "application/fhir+json"},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()


def get_patient_by_id(access_token: str, patient_id: str) -> dict:
    """Read ``Patient/{patient_id}`` under ``FHIR_BASE`` (delegates to get_patient_by_url)."""
    patient_url = f"{FHIR_BASE}/Patient/{patient_id}"
    return get_patient_by_url(access_token, patient_url)

def file_many_flowsheet_readings(access_token: str, *, patient_id: str, encounter_ref: str, readings: dict) -> dict:
    """File several flowsheet readings for one patient/encounter.

    ``readings`` maps a label to a config dict with keys:

    - "system": flowsheet coding system;
    - "row": flowsheet row id;
    - "value": converted with float();
    - "unit" (optional).

    Every reading in the batch gets the same UTC ``effectiveDateTime``. Returns
    ``{label: created Observation JSON}``. The first failure propagates and
    readings already filed are not rolled back.
    """
    shared_dt = datetime.now(timezone.utc).isoformat()
    return {
        label: file_flowsheet_reading(
            access_token,
            patient_id=patient_id,
            encounter_ref=encounter_ref,
            flowsheet_system=spec["system"],
            flowsheet_row_id=spec["row"],
            value=float(spec["value"]),
            unit=spec.get("unit"),
            effective_dt=shared_dt,
        )
        for label, spec in readings.items()
    }

def file_flowsheet_reading(
    access_token: str,
    *,
    patient_id: str,
    encounter_ref: str,      # "Encounter/<id>" or a bare "<id>" (REQUIRED in your tenant)
    flowsheet_system: str,
    flowsheet_row_id: str,
    value: float,
    unit: str | None = None,
    effective_dt: str | None = None,
) -> dict:
    """Create one vital-signs Observation (a flowsheet reading) via ``POST {FHIR_BASE}/Observation``.

    Args:
        patient_id: logical id of the subject Patient.
        encounter_ref: ``"Encounter/<id>"``, an absolute URL, or a bare encounter
            id. A bare id is prefixed with ``"Encounter/"`` because FHIR literal
            references must be relative ``Type/id`` or absolute URLs.
        flowsheet_system / flowsheet_row_id: tenant-specific ``code.coding``
            system and code.
        value: numeric value for ``valueQuantity.value``.
        unit: optional ``valueQuantity.unit``.
        effective_dt: ``effectiveDateTime``; defaults to the current UTC time.

    Returns:
        The server's JSON response body.

    Raises:
        ValueError: if ``encounter_ref`` is empty.
        RuntimeError: if the server returns a 4xx/5xx status. The message
            includes the status and error body.
    """
    if not encounter_ref:
        raise ValueError("encounter_ref is required")
    if "/" not in encounter_ref:
        encounter_ref = f"Encounter/{encounter_ref}"

    if effective_dt is None:
        effective_dt = datetime.now(timezone.utc).isoformat()

    value_quantity: dict = {"value": value}
    if unit is not None:
        value_quantity["unit"] = unit

    obs = {
        "resourceType": "Observation",
        "status": "final",
        "category": [{
            "coding": [{
                "system": "http://terminology.hl7.org/CodeSystem/observation-category",
                "code": "vital-signs",
            }]
        }],
        "code": {"coding": [{"system": flowsheet_system, "code": flowsheet_row_id}]},
        "subject": {"reference": f"Patient/{patient_id}"},
        "encounter": {"reference": encounter_ref},  # required by this tenant (error 59108 without it)
        "effectiveDateTime": effective_dt,
        "valueQuantity": value_quantity,
    }

    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
        "Content-Type": "application/fhir+json",
        "Prefer": "return=representation",
    }

    r = requests.post(f"{FHIR_BASE}/Observation", headers=headers, json=obs, timeout=30)
    if not r.ok:
        try:
            err = r.json()
        except ValueError:  # body is not JSON; report it verbatim
            err = {"raw": r.text}
        raise RuntimeError(f"Observation.Create failed ({r.status_code}): {err}")

    return r.json()

def _resolve_created_patient(access_token: str, result: CreateResult) -> dict | None:
    """Return the Patient produced by a successful create, or None if it can't be resolved.

    Tried in order:

    1. The response body, if it is a dict with an ``id``.
    2. Otherwise a GET of the Location header, which may be absolute or
       relative to FHIR_BASE.
    3. If the result carries an ``id`` but is not a Patient, a GET of
       ``Patient/{id}``.

    Network/HTTP errors during these follow-up GETs are printed and yield None.
    """
    from urllib.parse import urljoin

    created = result.resource if isinstance(result.resource, dict) else None
    try:
        if (not created or not created.get("id")) and result.location:
            # urljoin keeps an absolute Location as-is and resolves "Patient/<id>" against the base.
            created = get_patient_by_url(access_token, urljoin(f"{FHIR_BASE}/", result.location))
        if isinstance(created, dict) and created.get("id") and created.get("resourceType") != "Patient":
            created = get_patient_by_id(access_token, created["id"])
    except requests.RequestException as exc:
        print("\nFailed to fetch created Patient:", exc)
        return None

    if not isinstance(created, dict) or created.get("resourceType") != "Patient" or not created.get("id"):
        return None
    return created


def main():
    """Create one random synthetic Patient in Epic and upsert it into epic_patients.

    Steps:

    1. Obtain a token and POST a random Patient.
    2. Print the response.
    3. If the create succeeded (2xx), resolve the full Patient resource.
    4. Save it to the DB.

    Flowsheet filing is scaffolded but left commented out until
    tenant-specific row ids are known.
    """
    tok = get_token()
    access_token = tok["access_token"]

    patient_payload = build_random_patient_payload(include_gp=True)

    print("Posting Patient payload:\n", json.dumps(patient_payload, indent=2))
    result = create_patient(access_token, patient_payload)

    print("\n=== Create Patient Result ===")
    print("HTTP:", result.status_code)
    print("Location:", result.location)
    print("Body:\n", result.text)

    # An error response (e.g. an OperationOutcome) must not be treated as a created Patient.
    if not 200 <= result.status_code < 300:
        print(f"\nPatient create failed (HTTP {result.status_code}); skipping DB save.")
        return

    created_resource = _resolve_created_patient(access_token, result)
    if created_resource is None:
        print("\nCould not resolve created Patient resource; skipping DB save.")
        return

    created_id = created_resource["id"]

    # You MUST get these from the Epic tenant/customer:
    # - flowsheet_system (customer-specific code.coding.system)
    # - flowsheet_row_id (code.coding.code)
    # Epic documents these as required inputs.
    # NOTE: only used by the commented-out filing call below.
    readings = {
        "MAP":  {"system": "CUSTOM_FLOWSHEET_SYSTEM", "row": "ROW_ID_FOR_MAP",  "value": 62,  "unit": "mm[Hg]"},
        "HR":   {"system": "CUSTOM_FLOWSHEET_SYSTEM", "row": "ROW_ID_FOR_HR",   "value": 118, "unit": "/min"},
        "SpO2": {"system": "CUSTOM_FLOWSHEET_SYSTEM", "row": "ROW_ID_FOR_SPO2", "value": 91,  "unit": "%"},
        "CVP":  {"system": "CUSTOM_FLOWSHEET_SYSTEM", "row": "ROW_ID_FOR_CVP",  "value": 14,  "unit": "mm[Hg]"},
    }

    # responses = file_many_flowsheet_readings(access_token, patient_id=created_id, encounter_ref='Encounter/ewYPWmestytGzRSBFlutk7w3', readings=readings)
    # print("Filed flowsheet readings:", list(responses.keys()))

    print("\nCreated Patient id:", created_id)
    print("GET URL:", f"{FHIR_BASE}/Patient/{created_id}")

    # Save to DB
    with SessionLocal() as session:
        saved = save_patient_to_db(session, created_resource)
        print("\nSaved to epic_patients:")
        print("  db row id:", saved.id)
        print("  patient_id:", saved.patient_id)
        print("  epic_id (MRN):", saved.epic_id)
        print("  epic_id (MRN):", saved.epic_id)

# Run the create-and-save flow only when executed as a script. Note that importing this
# module still loads the private key at module level.
if __name__ == "__main__":
    main()
