"""
SmartStop Sentinel Connect – Module 6
Immutable Ledger Writer (MySQL-backed)
--------------------------------------
Consumes ACK packets from Kafka topic `events.exported`, appends each to an
append-only MySQL table. Each new row stores SHA-256 hash of the serialized
payload concatenated with the previous row hash, forming a simple block chain.
At 23:59 UTC the service computes a Merkle root of that day’s blocks,
optionally POSTs it to an external notarization endpoint, and logs the digest
back into the chain.
"""
from __future__ import annotations
import hashlib
import json
import os
import signal
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional

from confluent_kafka import Consumer, KafkaError
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    Integer,
    String,
    Text,
    select,
    text,
)

# -------------------------------------------------------------------
# Kafka config
# -------------------------------------------------------------------
# Broker bootstrap list, the topic carrying export ACK packets, and the
# consumer-group id — all overridable via environment for deployment.
BOOTSTRAP = os.getenv("KAFKA_BOOTSTRAP", "localhost:9092")
EXPORTED_TOPIC = os.getenv("EXPORTED_TOPIC", "events.exported")
GROUP_ID = os.getenv("GROUP_ID", "ledger-writer")

consumer = Consumer(
    {
        "bootstrap.servers": BOOTSTRAP,
        "group.id": GROUP_ID,
        # On first run (no committed offsets) start from the earliest retained
        # message so ACKs produced before deployment are still chained.
        "auto.offset.reset": "earliest",
    }
)

# -------------------------------------------------------------------
# MySQL config (edit or set via env)
# -------------------------------------------------------------------
# Example:
#   export LEDGER_DATABASE_URL="mysql+pymysql://user:pass@localhost:3306/token_service"
DATABASE_URL = os.getenv(
    "LEDGER_DATABASE_URL",
    # NOTE(review): hardcoded fallback credentials committed to source —
    # require LEDGER_DATABASE_URL in production and rotate this password.
    "mysql+pymysql://root:zsVmczGzq9D0NRq1ZseKUrwyucWd0w@localhost:3306/token_service",  # <-- change me
)

# pool_pre_ping revalidates pooled connections before use; MySQL silently
# drops idle connections, which would otherwise surface as mid-query errors.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)

metadata = MetaData()

# Append-only chain table: each row's `hash` is SHA-256 over its canonical
# JSON payload concatenated with the previous row's `hash`.
ledger_table = Table(
    "ledger",
    metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("ts", String(32), nullable=False),         # ISO 8601 UTC string
    Column("payload", Text, nullable=False),          # canonical JSON blob
    Column("prev_hash", String(64)),                  # genesis rows use 64 zeros, not NULL
    Column("hash", String(64), nullable=False, index=True),
    mysql_engine="InnoDB",
    mysql_charset="utf8mb4",
)

# Create the table if it does not exist yet (no-op otherwise).
metadata.create_all(engine)

# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------

def sha256(s: str) -> str:
    """Return the hex-encoded SHA-256 digest of *s* (UTF-8 encoded)."""
    digest = hashlib.sha256(s.encode())
    return digest.hexdigest()


def last_hash() -> Optional[str]:
    """Return hash of last block in chain, or None if empty."""
    query = text("SELECT hash FROM ledger ORDER BY id DESC LIMIT 1")
    with engine.connect() as conn:
        # scalar() yields the single column of the first row, or None
        # when the table is empty — same contract as fetchone()[0].
        return conn.execute(query).scalar()


def append_block(payload: Dict) -> None:
    """Append *payload* as a new block to the ledger table.

    The block hash is SHA-256 over the canonical JSON serialization
    concatenated with the previous block's hash (64 zeros for the genesis
    block), chaining every row to its predecessor.
    """
    # Canonical serialization: compact separators + sorted keys make the
    # hash independent of dict insertion order.
    j = json.dumps(payload, separators=(",", ":"), sort_keys=True)
    prev = last_hash() or "0" * 64
    block_hash = sha256(j + prev)
    # Timezone-aware UTC timestamp. datetime.utcnow() is deprecated and
    # produced naive ISO strings, which string-compare below the aware
    # "+00:00" bounds used by the daily notarisation query, silently
    # dropping rows stamped exactly at UTC midnight.
    ts = datetime.now(timezone.utc).isoformat()

    with engine.begin() as conn:
        conn.execute(
            ledger_table.insert().values(
                ts=ts,
                payload=j,
                prev_hash=prev,
                hash=block_hash,
            )
        )


def compute_merkle_root(records: List[str]) -> Optional[str]:
    """Return the Merkle root (hex) of serialized payload strings.

    Leaves are SHA-256 of each record; odd-length levels pair the trailing
    node with itself. An empty list yields None.
    """
    def _hexdigest(s: str) -> str:
        return hashlib.sha256(s.encode()).hexdigest()

    level = [_hexdigest(rec) for rec in records]
    while len(level) > 1:
        pairs = [
            (level[i], level[i + 1] if i + 1 < len(level) else level[i])
            for i in range(0, len(level), 2)
        ]
        level = [_hexdigest(left + right) for left, right in pairs]
    return level[0] if level else None


def notarise_daily() -> None:
    """Compute the Merkle root of the current UTC day's blocks and append a
    marker block recording it.

    If NOTARY_ENDPOINT is set, the marker is also POSTed there best-effort:
    failures are ignored because the marker is already on-chain.
    """
    today = datetime.now(timezone.utc).date()
    # Naive ISO-8601 bounds ("YYYY-MM-DDTHH:MM:SS"): as strings these compare
    # correctly against both naive and "+00:00"-suffixed stored timestamps,
    # so rows stamped exactly at UTC midnight are not dropped (the previous
    # aware bounds sorted *after* naive midnight timestamps and excluded them).
    t0 = datetime.combine(today, datetime.min.time())
    t1 = t0 + timedelta(days=1)

    with engine.connect() as conn:
        result = conn.execute(
            text(
                "SELECT payload FROM ledger WHERE ts >= :start AND ts < :end"
            ),
            {"start": t0.isoformat(), "end": t1.isoformat()},
        )
        payloads = [row[0] for row in result.fetchall()]

    # Nothing recorded today — no marker to write.
    if not payloads:
        return

    root = compute_merkle_root(payloads)
    marker = {"merkleroot": root, "day": str(today)}
    append_block(marker)

    ep = os.getenv("NOTARY_ENDPOINT")
    if ep:
        # Local import keeps `requests` optional for deployments that
        # never configure an external notary.
        import requests

        try:
            requests.post(ep, json=marker, timeout=5)
        except Exception:
            # Best-effort; we still keep the marker on-chain
            pass


# -------------------------------------------------------------------
# graceful shutdown
# -------------------------------------------------------------------
RUN = True

def term(sig, frame):
    """Signal handler: request a clean exit of the main consume loop."""
    global RUN
    RUN = False

# Handle SIGTERM as well as SIGINT: container runtimes (docker stop,
# Kubernetes) send SIGTERM, and without a handler the process is killed
# before consumer.close() can commit offsets and leave the group cleanly.
signal.signal(signal.SIGINT, term)
signal.signal(signal.SIGTERM, term)

# -------------------------------------------------------------------
# main loop
# -------------------------------------------------------------------
consumer.subscribe([EXPORTED_TOPIC])
print("[ledger] listening …")

# First notarisation trigger: 23:59:59 UTC of the current day (timezone-aware;
# datetime.utcnow() is deprecated and naive values cannot be compared with it).
next_midnight = datetime.now(timezone.utc).replace(
    hour=23, minute=59, second=59, microsecond=0
)

while RUN:
    msg = consumer.poll(1.0)
    if msg is not None:
        if msg.error():
            # _PARTITION_EOF is informational (caught up on a partition), not
            # an error. Previously EOF messages fell through to the payload
            # branch, where json.loads(None) raised and was logged as a bogus
            # "parse error" on every EOF.
            if msg.error().code() != KafkaError._PARTITION_EOF:
                print("[ledger] Kafka error", msg.error())
        else:
            try:
                payload = json.loads(msg.value())
                append_block(payload)
            except Exception as exc:
                # Malformed payloads are logged and skipped; the chain must
                # only contain canonically serializable blocks.
                print("[ledger] parse error", exc)

    # Daily notarisation tick; advance the trigger a day at a time.
    if datetime.now(timezone.utc) >= next_midnight:
        notarise_daily()
        next_midnight += timedelta(days=1)

consumer.close()
print("ledger stopped")
