import json
import time
import uuid
from datetime import datetime

import jwt
import requests
from cryptography.hazmat.primitives import serialization
from sqlalchemy.orm import Session

from blood_db import init_db
from db import SessionLocal
from epic_token import EpicPatient
from observation_store import store_observation_bundle

CLIENT_ID  = "1fed7724-fddf-469e-8b02-d11f338e264c"
TOKEN_URL  = "https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token"
FHIR_BASE  = "https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4"
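# Key ID sent in the JWT header; it should match the "kid" of the public key registered for this client (e.g. in the uploaded JWKS).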
KID        = "jwks_live_patients.json"

with open("keys/privatekey.pem", "rb") as fp:
    KEY = serialization.load_pem_private_key(fp.read(), password=None)

def client_assertion():
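    """Build the signed JWT client assertion (RS384, 5-minute lifetime) used for backend OAuth 2.0."""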
    now = int(time.time())
    payload = {
        "iss": CLIENT_ID,
        "sub": CLIENT_ID,
        "aud": TOKEN_URL,
        "jti": str(uuid.uuid4()),
        "iat": now,
        "nbf": now,
        "exp": now + 300,
    }

    headers = {"kid": KID}

    token = jwt.encode(
        payload,
        KEY,
        algorithm="RS384",
        headers=headers,
    )
    return token

def get_token():
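    """Exchange the JWT client assertion for an access token via the client_credentials grant."""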
    assertion = client_assertion()
    print("Client assertion:\n", assertion)

    r = requests.post(
        TOKEN_URL,
        data={
            "grant_type": "client_credentials",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": assertion,
        },
        timeout=15,
    )
    r.raise_for_status()
    return r.json()

def check_metadata(access_token: str) -> str | None:
    """Fetch the server CapabilityStatement and return its implementation description (sanity check)."""
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
    }
    resp = requests.get(f"{FHIR_BASE}/metadata", headers=headers, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    return data.get("implementation", {}).get("description")

def parse_birth_date(birth_date_str):
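    """Parse a FHIR birthDate (YYYY-MM-DD) into a date, returning None if missing or malformed."""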
    if not birth_date_str:
        return None
    try:
        return datetime.strptime(birth_date_str, "%Y-%m-%d").date()
    except ValueError:
        return None

def save_patient_to_db(session: Session, patient_resource: dict) -> EpicPatient:
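    """Map a FHIR Patient resource onto EpicPatient columns and upsert it by patient_id."""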
    patient_id = patient_resource.get("id")
    if not patient_id:
        raise ValueError("Patient resource has no 'id'")

    identifiers = patient_resource.get("identifier", []) or []
    telecoms = patient_resource.get("telecom", []) or []
    addresses = patient_resource.get("address", []) or []
    names = patient_resource.get("name", []) or []

    # pick an official name if present
    official_name = next((n for n in names if n.get("use") == "official"), names[0] if names else {})
    full_name = official_name.get("text")
    given_list = official_name.get("given") or []
    first_name = given_list[0] if given_list else None
    last_name = official_name.get("family")

    # phones
    home_phone = None
    work_phone = None
    for t in telecoms:
        if t.get("system") == "phone":
            if t.get("use") == "home" and not home_phone:
                home_phone = t.get("value")
            elif t.get("use") == "work" and not work_phone:
                work_phone = t.get("value")

    # home address
    home_addr = next((a for a in addresses if a.get("use") == "home"), addresses[0] if addresses else {})
    address_line = None
    if home_addr:
        lines = home_addr.get("line") or []
        address_line = lines[0] if lines else None
    city = home_addr.get("city") if home_addr else None
    state = home_addr.get("state") if home_addr else None
    postal_code = home_addr.get("postalCode") if home_addr else None
    country = home_addr.get("country") if home_addr else None

    # identifiers
    epic_id = None
    epi_id = None
    ssn = None
    payer_member_id = None
    ceid = None
    mychart_login = None

    for ident in identifiers:
        system = ident.get("system")
        id_type = (ident.get("type") or {}).get("text")
        value = ident.get("value")

        if id_type == "EPIC":
            epic_id = value
        elif id_type == "EPI":
            epi_id = value
        elif id_type == "CEID":
            ceid = value
        elif id_type == "MYCHARTLOGIN":
            mychart_login = value
        elif system == "urn:oid:2.16.840.1.113883.4.1":  # SSN
            ssn = value
        elif system == "https://open.epic.com/FHIR/StructureDefinition/PayerMemberId" and not payer_member_id:
            payer_member_id = value

    gender = patient_resource.get("gender")
    birth_date = parse_birth_date(patient_resource.get("birthDate"))

    raw_json = json.dumps(patient_resource)

    # upsert on patient_id
    patient = session.query(EpicPatient).filter_by(patient_id=patient_id).one_or_none()
    if not patient:
        patient = EpicPatient(patient_id=patient_id)

    patient.epic_id = epic_id
    patient.epi_id = epi_id
    patient.ssn = ssn
    patient.payer_member_id = payer_member_id
    patient.ceid = ceid
    patient.mychart_login = mychart_login

    patient.first_name = first_name
    patient.last_name = last_name
    patient.full_name = full_name
    patient.gender = gender
    patient.birth_date = birth_date

    patient.home_phone = home_phone
    patient.work_phone = work_phone

    patient.address_line = address_line
    patient.city = city
    patient.state = state
    patient.postal_code = postal_code
    patient.country = country

    patient.raw_json = raw_json

    session.add(patient)
    session.commit()
    session.refresh(patient)
    return patient

def get_patient_resource(access_token: str, patient_ref: str) -> dict:
    """Fetch a Patient resource given a relative reference such as "Patient/{id}"."""
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
    }
    url = f"{FHIR_BASE}/{patient_ref}"
    resp = requests.get(url, headers=headers, timeout=30)
    resp.raise_for_status()
    return resp.json()

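# FHIR Observation category codes (terminology.hl7.org/CodeSystem/observation-category),
# queried one at a time because servers commonly require a category (or code) filter on Observation searches.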
OBS_CATEGORIES = [
    "laboratory",
    "vital-signs",
    "imaging",
    "social-history",
    "survey",
    "exam",
    "therapy",
    "activity",
]

def get_observations_bundle(access_token: str, patient_id: str, *, category: str | None = None, count: int = 100) -> dict:
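    """Fetch one page of Observations for a patient, optionally filtered by category code."""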
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
    }
    url = f"{FHIR_BASE}/Observation"

    params = {
        "patient": patient_id,
        "_count": count,
    }
    if category is not None:
        params["category"] = category  # e.g. "vital-signs"

    resp = requests.get(url, headers=headers, params=params, timeout=30)
    resp.raise_for_status()  # check the HTTP status before trying to parse the body
    return resp.json()

def get_all_observations_by_categories(access_token: str, patient_id: str) -> list[dict]:
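    """Fetch the first page of Observations per category; use get_all_observations_bundles to follow paging."""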
    bundles = []
    for cat in OBS_CATEGORIES:
        try:
            b = get_observations_bundle(access_token, patient_id, category=cat)
            bundles.append(b)
        except requests.HTTPError as e:
            # some categories may not be supported in a given sandbox/tenant
            print(f"Skipping category={cat} due to error: {e}")
    return bundles

def get_all_observations_bundles(access_token: str, first_bundle: dict) -> list[dict]:
    """
    Follows Bundle.link[relation='next'] and returns a list of Bundles (all pages).
    Useful when the server paginates results.
    """
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
    }

    bundles = [first_bundle]

    def find_next_url(bundle: dict) -> str | None:
        for link in bundle.get("link") or []:
            if link.get("relation") == "next":
                return link.get("url")
        return None

    next_url = find_next_url(first_bundle)

    while next_url:
        resp = requests.get(next_url, headers=headers, timeout=30)
        resp.raise_for_status()
        b = resp.json()
        bundles.append(b)
        next_url = find_next_url(b)

    return bundles
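
# Example (sketch): fetch every page of vital-sign Observations for one patient.
#   first = get_observations_bundle(access_token, patient_id, category="vital-signs")
#   pages = get_all_observations_bundles(access_token, first)
#   entries = [e for b in pages for e in b.get("entry", [])]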

def import_list_patients(access_token: str, code: str, identifier: str):
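    """Find a List by code/identifier, save each referenced Patient, and store its Observation bundles."""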
    # 1) get the List bundle
    bundle = search_patient_by_identifier(access_token, code, identifier)
    entries = bundle.get("entry", [])
    if not entries:
        print("No List resources found for given code/identifier")
        return

    # take the first List in the Bundle
    list_resource = entries[0]["resource"]
    patient_entries = list_resource.get("entry", []) or []

    saved_patients = []

    with SessionLocal() as session:
        for idx, e in enumerate(patient_entries, start=1):
            ref = e["item"]["reference"]  # e.g. "Patient/ey.LPWLX90cwuwABvUYkrfw3"
            display = e["item"].get("display")
            print(f"[{idx}] Fetching {ref} ({display})")

            patient_resource = get_patient_resource(access_token, ref)
            saved = save_patient_to_db(session, patient_resource)
            bundles = get_all_observations_by_categories(access_token, saved.patient_id)
            for b in bundles:
                store_observation_bundle(session, "Retro Industries", b)
            saved_patients.append(saved.patient_id)

    print("Imported patients:", saved_patients)

def search_patient_by_identifier(access_token: str, code: str, identifier: str):
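    """Search for List resources matching the given code and identifier token."""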
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
    }
    url = f"{FHIR_BASE}/List?code={code}&identifier={identifier}"
    resp = requests.get(url, headers=headers, timeout=30)
    resp.raise_for_status()
    return resp.json()

def start_bulk_export(
    access_token: str,
    level: str = "system",  # "system" or "patient"
    resource_types: list[str] | None = None,
) -> str:
    """
    Kick off Epic Bulk Data export at system or Patient level.

    level == "system": GET {FHIR_BASE}/$export
    level == "patient": GET {FHIR_BASE}/Patient/$export
    """
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Prefer": "respond-async",
        "Accept": "application/fhir+json",
    }

    params: dict[str, str] = {}
    if resource_types:
        params["_type"] = ",".join(resource_types)

    if level == "system":
        path = "$export"
    elif level == "patient":
        path = "Patient/$export"
    else:
        raise ValueError("level must be 'system' or 'patient'")

    url = f"{FHIR_BASE}/{path}"
    resp = requests.get(url, headers=headers, params=params, timeout=60)

    if not resp.ok:
        print("Status:", resp.status_code)
        try:
            print("Body JSON:", resp.json())
        except Exception:
            print("Body text:", resp.text)
        resp.raise_for_status()

    job_url = resp.headers.get("Content-Location") or resp.headers.get("content-location")
    if not job_url:
        raise RuntimeError("No Content-Location header returned from $export")

    print("Job URL:", job_url)
    return job_url

def wait_for_export_manifest(access_token: str,
                             job_url: str,
                             poll_interval: int = 5,
                             max_wait_seconds: int = 900) -> dict:
    """
    Poll the job URL until we get a 200 and return the manifest JSON.
    """
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/fhir+json",
    }

    start = time.time()
    while True:
        resp = requests.get(job_url, headers=headers, timeout=60)

        if resp.status_code == 202:
            # 202 Accepted: the export job is still processing
            print("…job still processing")
        elif resp.status_code == 200:
            manifest = resp.json()
            print("Export job completed, got manifest.")
            # If you want to persist it:
            # with open("manifest.json", "w") as fp:
            #     json.dump(manifest, fp, indent=2)
            return manifest
        else:
            # any other status code is treated as a failed export job
            raise RuntimeError(f"Export job failed: {resp.status_code} {resp.text}")

        if time.time() - start > max_wait_seconds:
            raise TimeoutError("Timed out waiting for bulk export manifest")

        time.sleep(poll_interval)

def list_manifest_outputs(manifest: dict):
    """
    Print type + url like: jq '.output[] | {type, url}'
    """
    outputs = manifest.get("output", [])
    for o in outputs:
        print({"type": o.get("type"), "url": o.get("url")})
    return outputs


def fetch_first_output_lines(access_token: str,
                             manifest: dict,
                             max_lines: int = 5):
    """
    Fetch the first NDJSON URL and print the first few lines.
    """
    outputs = manifest.get("output", [])
    if not outputs:
        print("No output entries in manifest.")
        return

    first_url = outputs[0].get("url")
    if not first_url:
        print("First output entry has no URL.")
        return

    print("First output URL:", first_url)

    headers = {
        "Authorization": f"Bearer {access_token}",
    }
    # stream NDJSON
    with requests.get(first_url, headers=headers, stream=True, timeout=300) as r:
        r.raise_for_status()
        for i, line in enumerate(r.iter_lines(decode_unicode=True), start=1):
            if not line:
                continue
            print(line)
            if i >= max_lines:
                break

if __name__ == "__main__":

    # init_db()

    tok = get_token()
    access_token = tok["access_token"]

    # Example: look up a patient List by code and identifier ("system|value" token)
    code = "patients"
    identifier = "urn:oid:1.2.840.114350.1.13.0.1.7.2.806567|5332"

    import_list_patients(access_token, code, identifier)
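
    # Bulk export flow (sketch): kick off an async export, poll for the manifest,
    # then peek at the first NDJSON output file.
    # job_url = start_bulk_export(access_token, level="patient", resource_types=["Patient", "Observation"])
    # manifest = wait_for_export_manifest(access_token, job_url)
    # list_manifest_outputs(manifest)
    # fetch_first_output_lines(access_token, manifest)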

