from datetime import datetime, timedelta, timezone
import secrets
import time
import uuid
from typing import List, Optional

import requests
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from jose import jwt
from pydantic import BaseModel, EmailStr, constr
from sqlalchemy import Column, DateTime, ForeignKey, String, Table
from sqlalchemy.orm import Session, relationship
from cryptography.hazmat.primitives import serialization

from db import SessionLocal, Base, engine
from epic_token import EpicToken, EpicAppConfig

############################################################
# CONFIG (this service's JWT settings, shared signing key, Epic token-cache timing)
############################################################
JWT_ISS = "smartstop.io"
JWT_EXP_MINUTES = 30
JWT_ALG = "RS256"
PRIVATE_KEY_PATH = "keys/privatekey.pem"
# `kid` header placed on Epic client assertions.
# NOTE(review): this value looks like a JWKS file name rather than a key id;
# confirm it matches the `kid` published in the JWKS registered with Epic.
KID = "jwks_live_patients.json"

# Epic access-token cache policy.
DEFAULT_LIFETIME = 3600  # seconds; used when a token row/response has no expires_in
SAFETY_MARGIN = 300      # seconds; cached tokens are treated as expired this early

# Load the private key once, always from PRIVATE_KEY_PATH, so both derived
# forms are guaranteed to come from the same file:
#   PRIVATE_KEY - PEM text passed to jose when signing this service's JWTs
#   KEY         - parsed key object used to sign Epic client assertions
with open(PRIVATE_KEY_PATH, "rb") as fp:
    _PRIVATE_KEY_PEM = fp.read()

PRIVATE_KEY = _PRIVATE_KEY_PEM.decode("utf-8")
KEY = serialization.load_pem_private_key(_PRIVATE_KEY_PEM, password=None)

# HTTP Basic credential extractor. The /token endpoint reads client_id from
# the username and client_secret from the password.
security = HTTPBasic()
app = FastAPI(title="Sentinel Connect – Token Service")

# 👇 CORS CONFIG (for React on localhost:3000)
# Browser origins allowed to call this API cross-origin (local dev frontend).
origins = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
]

# NOTE(review): allow_credentials=True is combined with this origin list; a
# wildcard origin together with credentials would let any site make
# credentialed calls, so do not ship the ["*"] variant mentioned below.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,      # for dev you can temporarily use ["*"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

############################################################
# SQLALCHEMY MODELS (local to this service)
############################################################

# Many-to-many association table linking clients to the scopes they hold.
# The composite primary key (client_id, scope_id) forbids linking the same
# client/scope pair twice.
client_scope = Table(
    "client_scope", Base.metadata,
    Column("client_id", ForeignKey("clients.id"), primary_key=True),
    Column("scope_id", ForeignKey("scopes.id"), primary_key=True),
)


class Client(Base):
    """A client registered with this token service, plus its granted scopes."""

    __tablename__ = "clients"
    id = Column(String(64), primary_key=True, index=True)       # client_id
    # NOTE(review): stored in plaintext and compared directly by /token;
    # consider storing a hash instead.
    secret = Column(String(128), nullable=False)                # client_secret (for *this* service)
    org = Column(String(255), nullable=False)                   # org name
    contact = Column(String(255), nullable=False)               # contact email
    # Pass the callable, not datetime.now(): calling it here evaluated the
    # timestamp once at import time and stamped every row with that value.
    created_at = Column(DateTime, default=datetime.now)
    scopes = relationship("Scope", secondary=client_scope, back_populates="clients")


class Scope(Base):
    """A named permission (e.g. "ehr.epic.read") that can be granted to clients."""
    __tablename__ = "scopes"
    id = Column(String(64), primary_key=True)  # random hex id assigned by callers in this file
    name = Column(String(128), unique=True, index=True)  # unique scope name
    clients = relationship("Client", secondary=client_scope, back_populates="scopes")


# Create all tables (clients, scopes, EpicToken, EpicAppConfig, etc.)
# Runs at import time. create_all only creates tables that are missing; it
# does not alter existing ones, so schema changes need a migration.
Base.metadata.create_all(bind=engine)

############################################################
# Pydantic Schemas
############################################################

class RegisterIn(BaseModel):
    """Request body for POST /register."""
    org_name: constr(strip_whitespace=True, min_length=2)  # whitespace-stripped, min length 2
    contact_email: EmailStr
    # Each scope name: lowercase letters, digits, "_" or "." (e.g. "ehr.epic.read").
    scopes: List[constr(pattern=r"^[a-z0-9_.]+$")]


class RegisterOut(BaseModel):
    """Response for POST /register: the newly generated client credentials."""
    client_id: str
    client_secret: str


class TokenIn(BaseModel):
    """Request body for POST /token."""
    scopes: List[str]  # must be a subset of the client's granted scopes


class TokenOut(BaseModel):
    """Response for POST /token."""
    access_token: str  # signed JWT issued by this service
    token_type: str = "Bearer"
    expires_in: int  # token lifetime in seconds

############################################################
# DEPENDENCIES
############################################################

def get_db():
    """FastAPI dependency: yield one DB session per request, closing it afterwards.

    The session's context manager closes it on exit, including when the
    request handler raises.
    """
    with SessionLocal() as session:
        yield session

############################################################
# HELPERS (for your own auth service JWTs)
############################################################

def create_client(db: Session, org: str, email: str, scopes: List[str]) -> Client:
    """Create and persist a client with freshly generated credentials.

    Scope names that do not exist yet are created on the fly.

    Args:
        db: Open session; this function commits it.
        org: Organisation name.
        email: Contact email.
        scopes: Scope names to grant. Duplicates are ignored.

    Returns:
        The persisted, refreshed Client (including the plaintext secret).
    """
    # NOTE(review): /register is unauthenticated and this grants whatever
    # scopes are requested; confirm that self-service scope grants are intended.
    client_id = secrets.token_urlsafe(16)
    client_secret = secrets.token_urlsafe(32)
    c = Client(id=client_id, secret=client_secret, org=org, contact=email)
    # De-duplicate while keeping order. A repeated name would otherwise create
    # two Scope rows with the same unique name (the new Client is not in the
    # session yet, so the second lookup cannot see the first), or link one
    # existing scope twice (duplicate client_scope primary key). Both fail on commit.
    for name in dict.fromkeys(scopes):
        scope_obj = db.query(Scope).filter_by(name=name).first()
        if scope_obj is None:
            scope_obj = Scope(id=secrets.token_hex(8), name=name)
        c.scopes.append(scope_obj)
    db.add(c)
    db.commit()
    db.refresh(c)
    return c


def issue_jwt(client_id: str, scopes: List[str]) -> str:
    """Sign an access token for *client_id* with this service's private key.

    Claims: iss=JWT_ISS, sub=client_id, iat/exp as Unix seconds (exp is
    JWT_EXP_MINUTES later), and scope as a space-separated string.
    """
    # Use an aware UTC time. Naive local datetime arithmetic followed by
    # .timestamp() can shift exp by an hour across DST transitions.
    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(minutes=JWT_EXP_MINUTES)
    payload = {
        "iss": JWT_ISS,
        "sub": client_id,
        "iat": int(now.timestamp()),
        "exp": int(expires_at.timestamp()),
        "scope": " ".join(scopes),
    }
    return jwt.encode(payload, PRIVATE_KEY, algorithm=JWT_ALG)

############################################################
# API ENDPOINTS (for your auth service)
############################################################

@app.post("/register", response_model=RegisterOut, status_code=status.HTTP_201_CREATED)
def register(body: RegisterIn, db: Session = Depends(get_db)):
    """Register a new client and return its generated credentials."""
    # NOTE(review): no authentication is required here, and the caller picks
    # its own scopes; confirm this is acceptable outside development.
    created = create_client(
        db,
        org=body.org_name,
        email=body.contact_email,
        scopes=body.scopes,
    )
    return RegisterOut(client_id=created.id, client_secret=created.secret)


@app.post("/token", response_model=TokenOut)
def token(
    body: TokenIn,
    creds: HTTPBasicCredentials = Depends(security),
    db: Session = Depends(get_db),
):
    """Exchange client credentials (HTTP Basic) for a signed JWT.

    Raises:
        HTTPException(401): unknown client_id or wrong secret.
        HTTPException(403): a requested scope was not granted to the client.
    """
    client = db.query(Client).filter_by(id=creds.username).first()
    # Compare secrets in constant time so response timing does not reveal how
    # much of a guessed secret matched. Encode to bytes because
    # compare_digest rejects str arguments containing non-ASCII characters.
    if client is None or not secrets.compare_digest(
        client.secret.encode("utf-8"), creds.password.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="invalid_credentials")

    granted = {s.name for s in client.scopes}
    if not set(body.scopes) <= granted:
        raise HTTPException(status_code=403, detail="scope_denied")

    jwt_token = issue_jwt(client.id, body.scopes)
    return TokenOut(access_token=jwt_token, expires_in=JWT_EXP_MINUTES * 60)


@app.get("/scopes", response_model=List[str])
def list_scopes(db: Session = Depends(get_db)):
    """Return the names of all known scopes."""
    rows = db.query(Scope).all()
    return [row.name for row in rows]

############################################################
# EPIC TOKEN CACHE HELPERS (EpicToken table)
############################################################

def _get_stored_token(
    db: Session,
    client_id: str,
    fhir_base: str,
) -> Optional[str]:
    """Return the newest cached Epic access token if it is still usable.

    Only the most recent EpicToken row for (client_id, fhir_base) is considered.
    It counts as usable until SAFETY_MARGIN seconds before its expiry. A
    missing or zero expires_in falls back to DEFAULT_LIFETIME.

    Returns:
        The cached token string, or None when there is no usable row.
    """
    # NOTE(review): the lookup ignores org_id, so configs that share
    # client_id and fhir_base also share cached tokens; confirm that is intended.
    latest: EpicToken | None = (
        db.query(EpicToken)
        .filter(EpicToken.client_id == client_id, EpicToken.fhir_base == fhir_base)
        .order_by(EpicToken.id.desc())
        .first()
    )
    if not latest or not latest.token_timestamp:
        return None

    lifetime = latest.expires_in or DEFAULT_LIFETIME
    # Timestamps are naive local times, matching how _store_token writes them.
    usable_until = latest.token_timestamp + timedelta(seconds=lifetime - SAFETY_MARGIN)
    return latest.new_token if datetime.now() < usable_until else None


def _store_token(
    db: Session,
    client_id: str,
    fhir_base: str,
    token_url: str,
    org_id: str,
    access_token: str,
    expires_in: int | None,
    refresh_token: str | None = None,
):
    """Insert a new EpicToken row stamped with the current (naive local) time and commit.

    A missing or zero ``expires_in`` is stored as DEFAULT_LIFETIME. Older rows
    are kept; readers select the newest row by id.
    """
    fields = dict(
        client_id=client_id,
        fhir_base=fhir_base,
        token_url=token_url,
        org_id=org_id,
        new_token=access_token,
        refresh_token=refresh_token,
        expires_in=expires_in or DEFAULT_LIFETIME,
        token_timestamp=datetime.now(),
    )
    db.add(EpicToken(**fields))
    db.commit()

############################################################
# EPIC CONFIG + JWT CLIENT_ASSERTION (EpicAppConfig table)
############################################################

def _get_epic_config_from_db(
    db: Session,
    client_id: str,
    fhir_base: str,
    org_id: Optional[str],
) -> EpicAppConfig:
    """Fetch the EpicAppConfig row for a client and FHIR base URL.

    When ``org_id`` is given it is added as an extra filter. The first matching
    row is returned.

    Raises:
        RuntimeError: no row matches the filters.
    """
    criteria = [
        EpicAppConfig.client_id == client_id,
        EpicAppConfig.fhir_base == fhir_base,
    ]
    if org_id is not None:
        criteria.append(EpicAppConfig.org_id == org_id)

    cfg = db.query(EpicAppConfig).filter(*criteria).first()
    if not cfg:
        raise RuntimeError(
            "EpicAppConfig not found for client_id=%r fhir_base=%r org_id=%r"
            % (client_id, fhir_base, org_id)
        )
    return cfg


def _build_epic_client_assertion(
    cfg: EpicAppConfig,
):
    """Sign the RS384 JWT client assertion sent to Epic's token endpoint.

    iss and sub are the Epic client_id, and aud is the token URL. Each call gets
    a fresh random jti. The assertion is valid from now for 300 seconds.
    The header carries ``kid=KID``, and the token is signed with the
    module-level KEY.
    """
    issued = int(time.time())
    # 300 s lifetime; presumably chosen to stay within Epic's 5-minute
    # maximum for assertion expiry — confirm against Epic docs before changing.
    claims = {
        "iss": cfg.client_id,
        "sub": cfg.client_id,
        "aud": cfg.token_url,
        "jti": str(uuid.uuid4()),
        "iat": issued,
        "nbf": issued,
        "exp": issued + 300,
    }
    return jwt.encode(claims, KEY, algorithm="RS384", headers={"kid": KID})

def _request_epic_access_token(
    token_url: str,
    client_assertion: str,
    timeout: float = 30.0,
) -> dict:
    """POST a client_credentials grant to *token_url* and return the parsed JSON body.

    Args:
        token_url: Epic OAuth token endpoint.
        client_assertion: Signed JWT from _build_epic_client_assertion.
        timeout: Seconds to wait for the token endpoint before giving up.

    Raises:
        HTTPException(502): network error or timeout, non-2xx response, a body
            that is not a JSON object, or a body without ``access_token``.
    """
    data = {
        "grant_type": "client_credentials",
        "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        "client_assertion": client_assertion,
    }
    try:
        resp = requests.post(
            token_url,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            # Without a timeout, a stalled endpoint blocks this worker indefinitely.
            timeout=timeout,
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail="epic_token_request_failed") from exc

    try:
        body = resp.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        body = None

    if not resp.ok:
        # Surface the OAuth "error" code when the endpoint returns one.
        oauth_error = body.get("error") if isinstance(body, dict) else None
        detail = f"epic_token_error: HTTP {resp.status_code}"
        if oauth_error:
            detail += f" ({oauth_error})"
        raise HTTPException(status_code=502, detail=detail)

    if not isinstance(body, dict) or not body.get("access_token"):
        raise HTTPException(status_code=502, detail="epic_token_missing_access_token")
    return body


@app.get("/get_or_create_epic_token")
def get_or_create_epic_token(id: int) -> str:
    """Return an Epic access token for the EpicAppConfig row with primary key *id*.

    Steps:
      1. Reuse a cached token if it is still valid.
      2. Otherwise sign a client assertion and exchange it at the config's
         token_url.
      3. Store the new token and return it.

    Raises:
        HTTPException(404): no EpicAppConfig with this id.
        HTTPException(502): the Epic token request failed (see _request_epic_access_token).
    """
    # NOTE(review): this endpoint has no authentication, so any caller can
    # obtain Epic access tokens for any config id; confirm it is not exposed publicly.
    with SessionLocal() as db:
        cfg: EpicAppConfig | None = (
            db.query(EpicAppConfig)
              .filter(EpicAppConfig.id == id)
              .first()
        )
        if not cfg:
            raise HTTPException(
                status_code=404,
                detail=f"No EpicAppConfig found for id={id}",
            )

        # Capture plain values up front so later commits do not force a reload of cfg.
        client_id = cfg.client_id
        fhir_base = cfg.fhir_base
        token_endpoint = cfg.token_url
        org_id = cfg.org_id

        # 1) Try cache
        existing = _get_stored_token(db, client_id, fhir_base)
        if existing:
            return existing

        # 2) Build JWT client assertion and exchange it for an access token
        client_assertion = _build_epic_client_assertion(cfg)
        token = _request_epic_access_token(token_endpoint, client_assertion)
        access_token = token["access_token"]

        # 3) Cache for subsequent calls
        _store_token(
            db=db,
            client_id=client_id,
            fhir_base=fhir_base,
            token_url=token_endpoint,
            org_id=org_id,
            access_token=access_token,
            expires_in=token.get("expires_in"),
            refresh_token=token.get("refresh_token"),
        )

        return access_token

############################################################
# CLI helper to preseed standard scopes
############################################################

if __name__ == "__main__":
    # Seed the standard scope names; names already present are left untouched.
    seed_scope_names = (
        "ehr.epic.read", "ehr.cerner.read", "ehr.meditech.read",
        "pump.icu_med.read", "pump.bbraun.read", "pump.ivenix.read",
    )
    with SessionLocal() as session:
        for scope_name in seed_scope_names:
            already_present = session.query(Scope).filter_by(name=scope_name).first()
            if not already_present:
                session.add(Scope(id=secrets.token_hex(8), name=scope_name))
        session.commit()
    print("[INIT] Default scopes seeded. Launch with `uvicorn connectors.token_service:app`")
