"""
SmartStop Sentinel Connect – Module 2
Connector SDK & sample EHR adapter (Epic FHIR R4)
-------------------------------------------------
This module defines:
  • BaseConnector abstract class
  • ConnectorRegistry for hot-plug discovery
  • EpicFHIRLabAdapter, a sample Epic FHIR R4 adapter that polls lab
    Observations and DiagnosticReports
A connector class derived from BaseConnector is added to the registry when it is
decorated with @connector, and can then be launched by the Connector Manager
side-car.
"""
from __future__ import annotations
from sqlalchemy.exc import IntegrityError
import abc
import importlib
from epic_token import EpicAppConfig, LabObservation, EpicPatient, DiagnosticReport
from db import SessionLocal
import time
import logging
import pkgutil
import threading
from datetime import datetime, timedelta
from typing import Any, Dict
from sqlalchemy.orm import declarative_base
from token_service import get_or_create_epic_token

import requests

# SQLAlchemy declarative base; not referenced by any other code visible in
# this module.
Base = declarative_base()

# Logger shared by the connector classes and helper functions below.
LOGGER = logging.getLogger("sentinel.connectors")

class BaseConnector(abc.ABC):
    """Abstract base for vendor connectors that do their work on a daemon thread.

    Subclasses set ``vendor`` / ``version`` and implement :meth:`_run`, which
    is expected to watch ``stop_flag`` and return once it is set.
    """

    # Vendor label; ConnectorRegistry keys on its lower-cased form.
    vendor: str = "unknown"
    # Optional label for the vendor protocol/version.
    version: str | None = None

    def __init__(self, config: Dict[str, Any]):
        """Keep *config*, resolve ``org_id`` (default ``"demo_org"``) and create the stop flag."""
        self.stop_flag = threading.Event()
        self.config = config
        self.org_id: str = config.get("org_id", "demo_org")

    def start(self):
        """Clear the stop flag and run :meth:`_run` on a new daemon thread."""
        LOGGER.info("[%s] starting connector", self.vendor)
        self.stop_flag.clear()
        worker = threading.Thread(target=self._run, daemon=True)
        worker.start()

    def stop(self):
        """Set the stop flag; the worker thread is not joined here."""
        LOGGER.info("[%s] stopping connector", self.vendor)
        self.stop_flag.set()

    @abc.abstractmethod
    def _run(self):
        """Blocking loop. Derived classes implement vendor-specific logic."""

class ConnectorRegistry:
    """Lookup table from vendor name to connector class, shared at class level."""

    # Keys are lower-cased vendor names; values are the connector classes.
    _registry: Dict[str, type[BaseConnector]] = {}

    @classmethod
    def register(cls, connector_cls: type[BaseConnector]):
        """Store *connector_cls* under its lower-cased ``vendor`` and return it.

        A later registration with the same vendor replaces the earlier one.
        """
        key = connector_cls.vendor.lower()
        cls._registry[key] = connector_cls
        return connector_cls

    @classmethod
    def get(cls, vendor_name: str):
        """Return the class registered for *vendor_name* (case-insensitive), or ``None``."""
        key = vendor_name.lower()
        return cls._registry.get(key)

def connector(cls):
    """Class decorator: register *cls* with ``ConnectorRegistry`` and hand it back."""
    registered = ConnectorRegistry.register(cls)
    return registered

@connector
class EpicFHIRLabAdapter(BaseConnector):
    """Polling connector that pulls lab resources from an Epic FHIR R4 endpoint.

    When the worker starts it builds one ``Observation`` and one
    ``DiagnosticReport`` search per ``EpicPatient`` row (falling back to the
    ``bulk_types`` config entry when there are no patients) and re-runs those
    searches every ``poll_minutes`` minutes.  Resources not yet stored are
    persisted through ``SessionLocal`` and then handed to ``self.emit``.
    """

    vendor = "EpicLab"
    version = "FHIR R4"

    def __init__(self, config: Dict[str, Any]):
        """Read connection settings from *config*.

        Raises:
            KeyError: if ``fhir_base``, ``client_id`` or ``token_url`` is absent.
        """
        super().__init__(config)
        self.fhir_base = config["fhir_base"].rstrip("/")
        self.client_id = config["client_id"]
        self.token_url = config["token_url"]
        self.org_id = config.get("org_id", self.org_id)

        self.access_token: str | None = None
        self.session = requests.Session()

    def _ensure_token(self):
        """Fetch a bearer token if none is cached and install it on the HTTP session."""
        if self.access_token is None:
            # NOTE(review): the argument 3 is hard-coded; presumably the
            # EpicAppConfig row id -- confirm against token_service.
            self.access_token = get_or_create_epic_token(3)
            self.session.headers.update(
                {
                    "Authorization": f"Bearer {self.access_token}",
                    "Accept": "application/fhir+json",
                }
            )

    def fetch_since(self, resource_type: str) -> list[dict]:
        """Run one FHIR search and return the resources from every result page.

        Args:
            resource_type: path and query appended to ``fhir_base``, e.g.
                ``"Observation?category=laboratory&patient=abc"``.  Despite
                the method name, no time filter is added here.

        Returns:
            The ``resource`` of every bundle entry, following ``next`` links
            until there are none or the stop flag is set.

        Raises:
            requests.HTTPError: on a non-success response (after one token
                refresh and retry for 401/403).
        """
        self._ensure_token()

        results: list[dict] = []
        url: str | None = f"{self.fhir_base}/{resource_type}"

        while url and not self.stop_flag.is_set():
            resp = self.session.get(url, timeout=60)

            if resp.status_code in (401, 403):
                LOGGER.warning(
                    "[%s] token rejected (%s). refreshing and retrying once. body=%s",
                    self.vendor,
                    resp.status_code,
                    resp.text,
                )
                # Drop the cached token so _ensure_token fetches a new one.
                self.access_token = None
                self._ensure_token()
                resp = self.session.get(url, timeout=60)

            resp.raise_for_status()

            bundle = resp.json()
            for entry in bundle.get("entry", []):
                resource = entry.get("resource")
                if resource:
                    results.append(resource)

            # Follow the first "next" link, if the bundle is paginated.
            url = next(
                (
                    link.get("url")
                    for link in bundle.get("link", [])
                    if link.get("relation") == "next"
                ),
                None,
            )

        return results

    def _run(self):
        """Poll the configured searches until the stop flag is set."""
        resource_types = self._build_queries()

        polling_interval = int(self.config.get("poll_minutes", 5))
        # NOTE: last_since is only reported in the log line below; the searches
        # carry no time filter, so each cycle re-reads every result and relies
        # on the duplicate check in _store_and_emit.
        last_since = datetime.utcnow() - timedelta(minutes=polling_interval)

        while not self.stop_flag.is_set():
            try:
                LOGGER.info(
                    "[%s] polling types=%s since=%s",
                    self.vendor, ",".join(resource_types),
                    last_since.isoformat() + "Z",
                )

                for rtype in resource_types:
                    for resource in self.fetch_since(rtype):
                        # Isolate each resource: one malformed payload must not
                        # abort the cycle and starve everything queued after it.
                        try:
                            self._handle_resource(resource)
                        except Exception as exc:
                            LOGGER.exception(
                                "[%s] error processing resource; skipping it: %s",
                                self.vendor, exc,
                            )

                last_since = datetime.utcnow()

            except Exception as exc:
                LOGGER.exception("[%s] error in incremental poll loop: %s", self.vendor, exc)

            self.stop_flag.wait(polling_interval * 60)

    def _build_queries(self) -> list[str]:
        """Return the search strings to poll: two per known patient, else ``bulk_types``."""
        with SessionLocal() as db:
            patient_ids = [row.patient_id for row in db.query(EpicPatient).all()]

        queries: list[str] = []
        for pid in patient_ids:
            queries.append(f"Observation?category=laboratory&patient={pid}")
            queries.append(f"DiagnosticReport?patient={pid}")

        if not queries:
            bulk_types_cfg = self.config.get("bulk_types", "Observation,DiagnosticReport")
            queries = [t.strip() for t in bulk_types_cfg.split(",") if t.strip()]

        return queries

    @staticmethod
    def _patient_id_from(subject) -> str | None:
        """Return the last path segment of ``subject.reference``, or ``None``.

        Accepts a dict or a list of dicts (first element is used), matching
        what ``diagnostic_report_from_fhir`` tolerates.  Taking the last
        segment works for ``"Patient/xyz"`` and for absolute URLs.
        """
        if isinstance(subject, list):
            subject = subject[0] if subject else None
        if not isinstance(subject, dict):
            return None
        ref = subject.get("reference")
        if not isinstance(ref, str) or not ref:
            return None
        return ref.split("/")[-1] or None

    def _handle_resource(self, resource: dict) -> None:
        """Dispatch one fetched resource according to its ``resourceType``."""
        rsrc_type = resource.get("resourceType")

        if rsrc_type == "OperationOutcome":
            LOGGER.info(
                "[%s] skipping OperationOutcome: %s",
                self.vendor, resource.get("id"),
            )
        elif rsrc_type == "Observation":
            self._store_and_emit(
                resource,
                "Observation",
                LabObservation,
                lab_observation_from_fhir,
                resource.get("subject"),
            )
        elif rsrc_type == "DiagnosticReport":
            self._store_and_emit(
                resource,
                "DiagnosticReport",
                DiagnosticReport,
                diagnostic_report_from_fhir,
                resource.get("subject") or resource.get("patient"),
            )
        else:
            LOGGER.debug(
                "[%s] unhandled FHIR resourceType=%s; skipping",
                self.vendor, rsrc_type,
            )

    def _store_and_emit(self, resource: dict, label: str, model, build_row, subject) -> None:
        """Insert *resource* unless already stored, then emit an event for new rows.

        Args:
            resource: the FHIR resource.
            label: resource type name, used in log lines and as ``event_type``.
            model: ORM class queried for an existing row.
            build_row: callable ``(org_id, resource) -> row`` creating the ORM row.
            subject: the element carrying the patient reference.
        """
        org_id = self.org_id
        fhir_id = resource.get("id")
        patient_id = self._patient_id_from(subject)

        if not fhir_id or not patient_id:
            LOGGER.debug(
                "[%s] skipping %s with missing ids: id=%s patient=%s",
                self.vendor, label, fhir_id, patient_id,
            )
            return

        stored = False
        with SessionLocal() as db:
            existing = (
                db.query(model)
                .filter(
                    model.org_id == org_id,
                    model.fhir_id == fhir_id,
                    model.patient_id == patient_id,
                )
                .first()
            )

            if existing:
                LOGGER.debug(
                    "[%s] duplicate %s, skipping (fhir_id=%s, patient_id=%s)",
                    self.vendor, label, fhir_id, patient_id,
                )
            else:
                db.add(build_row(org_id, resource))
                try:
                    db.commit()
                    stored = True
                except IntegrityError:
                    # Lost a race with another writer; the row already exists.
                    db.rollback()
                    LOGGER.debug(
                        "[%s] IntegrityError on %s insert; treated as duplicate",
                        self.vendor, label,
                    )

        if not stored:
            return

        # NOTE(review): no `emit` method is defined on this class or on
        # BaseConnector in this module; unless it is supplied elsewhere this
        # call raises AttributeError, which is caught and logged just below.
        try:
            self.emit(
                {
                    "source": "epic_fhir",
                    "event_type": label,
                    "raw": resource,
                    "org_id": org_id,
                }
            )
        except Exception as exc:
            LOGGER.error(
                "[%s] emit failed for %s: %s",
                self.vendor, label, exc,
            )

def parse_fhir_datetime(value: str | None):
    """Convert a FHIR ``dateTime``/``instant`` string to a ``datetime``.

    Handles the reduced-precision forms FHIR allows (``YYYY`` and ``YYYY-MM``,
    mapped to the first day of that year/month), plain dates, and full
    timestamps with a ``Z`` or numeric offset.

    Args:
        value: the FHIR string, or a falsy value when the field is absent.

    Returns:
        A ``datetime`` (timezone-aware only when the input carried an offset),
        or ``None`` when *value* is empty or cannot be parsed; unparseable
        input is logged as a warning.
    """
    if not value:
        return None

    if not isinstance(value, str):
        LOGGER.warning("Invalid FHIR datetime: %s", value)
        return None

    text = value.strip()
    try:
        # datetime.fromisoformat rejects year-only and year-month values, so
        # those two shapes are handled explicitly.
        if len(text) == 4 and text.isdigit():
            return datetime(int(text), 1, 1)
        if (
            len(text) == 7
            and text[4] == "-"
            and text[:4].isdigit()
            and text[5:].isdigit()
        ):
            return datetime(int(text[:4]), int(text[5:]), 1)

        # Older Python versions do not accept the "Z" suffix; spell it as an offset.
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"

        return datetime.fromisoformat(text)
    except ValueError:
        LOGGER.warning("Invalid FHIR datetime: %s", value)
        return None

def _fhir_dict(value) -> dict:
    """Return *value* if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _fhir_first_dict(value) -> dict:
    """Return the first element of list *value* if it is a dict, otherwise an empty dict."""
    if isinstance(value, list) and value:
        return _fhir_dict(value[0])
    return {}


def _fhir_reference_id(element: dict) -> str | None:
    """Return the last path segment of ``element["reference"]``, or ``None``."""
    ref = element.get("reference")
    if not isinstance(ref, str) or not ref:
        return None
    return ref.split("/")[-1] or None


def lab_observation_from_fhir(org_id: str, obs: dict) -> LabObservation:
    """Build a ``LabObservation`` row from a FHIR Observation resource.

    Only the first category coding, first code coding, first reference range
    and first interpretation are mapped.  Missing, null or unexpectedly shaped
    elements yield ``None`` columns instead of raising.

    Referenced ids (patient, encounter, specimen) are the last path segment of
    the reference -- the same rule the poller's duplicate check and
    ``diagnostic_report_from_fhir`` use -- so ``"Patient/xyz"`` and an
    absolute URL ending in ``/Patient/xyz`` both give ``"xyz"``.

    Args:
        org_id: organisation the row belongs to.
        obs: the Observation resource as a dict; also stored as ``raw_fhir``.

    Returns:
        An unsaved ``LabObservation`` instance.
    """
    subject = _fhir_dict(obs.get("subject"))
    encounter = _fhir_dict(obs.get("encounter"))
    specimen = _fhir_dict(obs.get("specimen"))

    # --- Primary category coding (category[0].coding[0]) ---
    cat_coding = _fhir_first_dict(_fhir_first_dict(obs.get("category")).get("coding"))

    # --- Primary test coding (Observation.code.coding[0]) ---
    code_coding = _fhir_first_dict(_fhir_dict(obs.get("code")).get("coding"))

    # Reference range (first element)
    ref_range = _fhir_first_dict(obs.get("referenceRange"))
    low = _fhir_dict(ref_range.get("low"))
    high = _fhir_dict(ref_range.get("high"))

    val_qty = _fhir_dict(obs.get("valueQuantity"))

    interp = _fhir_first_dict(obs.get("interpretation"))
    interp_coding = _fhir_first_dict(interp.get("coding"))

    return LabObservation(
        org_id=org_id,
        fhir_id=obs.get("id"),

        patient_id=_fhir_reference_id(subject),
        patient_display=subject.get("display"),

        encounter_id=_fhir_reference_id(encounter),
        encounter_display=encounter.get("display"),

        status=obs.get("status"),

        category_code=cat_coding.get("code"),
        category_display=cat_coding.get("display"),

        test_code_system=code_coding.get("system"),
        test_code=code_coding.get("code"),
        test_display=code_coding.get("display"),

        effective_datetime=parse_fhir_datetime(obs.get("effectiveDateTime")),
        issued_datetime=parse_fhir_datetime(obs.get("issued")),

        value_number=val_qty.get("value"),
        value_unit=val_qty.get("unit"),
        value_code=val_qty.get("code"),
        value_system=val_qty.get("system"),

        interpretation_code=interp_coding.get("code"),
        interpretation_text=interp.get("text"),

        ref_low=low.get("value"),
        ref_low_unit=low.get("unit"),
        ref_high=high.get("value"),
        ref_high_unit=high.get("unit"),

        specimen_id=_fhir_reference_id(specimen),
        specimen_display=specimen.get("display"),

        raw_fhir=obs,
    )

def diagnostic_report_from_fhir(org_id: str, dr: dict) -> DiagnosticReport:
    """Build a ``DiagnosticReport`` row from a FHIR DiagnosticReport resource.

    ``subject`` (or ``patient``), ``category``, ``code`` and ``performer`` may
    each arrive as a single object or a list; the first object is used and
    anything else is treated as empty.  The raw resource is kept in ``raw_fhir``.
    """

    def single_object(node):
        # Collapse "object or list of objects" to one dict (possibly empty).
        if isinstance(node, list):
            node = node[0] if node else {}
        return node if isinstance(node, dict) else {}

    def primary_coding(concept):
        # First entry of concept["coding"] when it is a dict, else empty.
        codings = concept.get("coding") or []
        usable = isinstance(codings, list) and codings and isinstance(codings[0], dict)
        return codings[0] if usable else {}

    # ---- Patient ----
    subject = single_object(dr.get("subject") or dr.get("patient"))
    patient_ref = subject.get("reference")
    patient_id = None
    if patient_ref:
        patient_id = patient_ref.split("/")[-1]

    # ---- Category and code ----
    category_coding = primary_coding(single_object(dr.get("category")))

    code_concept = single_object(dr.get("code"))
    code_text = code_concept.get("text")
    report_coding = primary_coding(code_concept)

    # ---- Performer ----
    performer = single_object(dr.get("performer"))

    # ---- Result references (only entries that carry a reference) ----
    result_refs: list[str] = [
        item["reference"]
        for item in (dr.get("result") or [])
        if isinstance(item, dict) and item.get("reference")
    ]

    return DiagnosticReport(
        org_id=org_id,
        fhir_id=dr.get("id"),
        patient_id=patient_id,
        patient_display=subject.get("display"),
        status=dr.get("status"),
        category_code=category_coding.get("code"),
        category_display=category_coding.get("display"),
        code_system=report_coding.get("system"),
        code_code=report_coding.get("code"),
        # Fall back to the free-text label when the coding has no display.
        code_display=report_coding.get("display") or code_text,
        code_text=code_text,
        effective_datetime=parse_fhir_datetime(dr.get("effectiveDateTime")),
        issued_datetime=parse_fhir_datetime(dr.get("issued")),
        performer_display=performer.get("display"),
        performer_reference=performer.get("reference"),
        result_references=result_refs,
        raw_fhir=dr,
    )

def autodiscover(pkg_name: str):
    """Import every non-package module found under the package *pkg_name*.

    If *pkg_name* imports but has no ``__path__`` (it is a plain module), a
    warning is logged and nothing else is imported.
    """
    root = importlib.import_module(pkg_name)
    search_path = getattr(root, "__path__", None)

    if search_path:
        prefix = root.__name__ + "."
        leaf_modules = (
            name
            for _finder, name, is_pkg in pkgutil.walk_packages(search_path, prefix)
            if not is_pkg
        )
        for name in leaf_modules:
            importlib.import_module(name)
        return

    LOGGER.warning("autodiscover: %s is not a package, skipping", pkg_name)

def get_db():
    """Yield a new ``SessionLocal`` session and close it when the generator finishes."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal exhaustion, on error in the consumer, and on generator close.
        session.close()

if __name__ == "__main__":
    # Script entry point: load the Epic app configuration, start the EpicLab
    # connector and keep the main thread alive until Ctrl-C.

    # Row id of the EpicAppConfig record to run with.  Kept in one place so
    # the query and the error message cannot drift apart.
    app_config_id = 3

    # NOTE: when run as a script __name__ is "__main__", which has no __path__,
    # so this call only logs autodiscover's "not a package" warning.
    autodiscover(__name__)

    with SessionLocal() as db:
        cfg_row: EpicAppConfig | None = (
            db.query(EpicAppConfig)
              .filter(EpicAppConfig.id == app_config_id)
              .first()
        )

    if cfg_row is None:
        raise RuntimeError(f"EpicAppConfig with id={app_config_id} not found")

    config = {
        "client_id": cfg_row.client_id,
        "fhir_base": cfg_row.fhir_base,
        "token_url": cfg_row.token_url,
        # Fall back to defaults when the row has no such attributes.
        "org_id": getattr(cfg_row, "org_id", "demo_org"),
        "poll_minutes": getattr(cfg_row, "poll_minutes", 5),
    }

    connector_cls = ConnectorRegistry.get("EpicLab")
    if connector_cls is None:
        raise RuntimeError("No connector registered for vendor 'EpicLab'")

    epic = connector_cls(config)
    epic.start()
    try:
        # The connector runs on a daemon thread; keep the process alive.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        epic.stop()
