from __future__ import annotations import json from datetime import datetime, timedelta, timezone from typing import Any, Dict, Iterable, List, Optional from uuid import uuid4 from base58 import b58decode, b58encode import nacl.signing from sqlalchemy import select, delete from sqlalchemy.ext.asyncio import AsyncSession from app.core.logger import make_log from app.core._secrets import hot_pubkey, hot_seed from app.core.models import NodeEvent, NodeEventCursor LOCAL_PUBLIC_KEY = b58encode(hot_pubkey).decode() def _normalize_dt(value: Optional[datetime]) -> datetime: if value is None: return datetime.utcnow() if value.tzinfo is not None: return value.astimezone(timezone.utc).replace(tzinfo=None) return value def _parse_iso_dt(iso_value: Optional[str]) -> datetime: if not iso_value: return datetime.utcnow() try: parsed = datetime.fromisoformat(iso_value.replace('Z', '+00:00')) except Exception: return datetime.utcnow() return _normalize_dt(parsed) def _canonical_blob(data: Dict[str, Any]) -> bytes: return json.dumps(data, sort_keys=True, separators=(",", ":")).encode() def _sign_event(blob: Dict[str, Any]) -> str: signing_key = nacl.signing.SigningKey(hot_seed) signature = signing_key.sign(_canonical_blob(blob)).signature return b58encode(signature).decode() def verify_event_signature(event: Dict[str, Any]) -> bool: try: origin_key = event["origin_public_key"] signature = event["signature"] payload = { "origin_public_key": origin_key, "origin_host": event.get("origin_host"), "seq": event["seq"], "uid": event["uid"], "event_type": event["event_type"], "payload": event.get("payload") or {}, "created_at": event.get("created_at"), } verify_key = nacl.signing.VerifyKey(b58decode(origin_key)) verify_key.verify(_canonical_blob(payload), b58decode(signature)) return True except Exception as exc: make_log("Events", f"Signature validation failed: {exc}", level="warning") return False async def next_local_seq(session: AsyncSession) -> int: result = await session.execute( select(NodeEvent.seq) .where(NodeEvent.origin_public_key == LOCAL_PUBLIC_KEY) .order_by(NodeEvent.seq.desc()) .limit(1) ) row = result.scalar_one_or_none() return int(row or 0) + 1 async def record_event( session: AsyncSession, event_type: str, payload: Dict[str, Any], origin_host: Optional[str] = None, created_at: Optional[datetime] = None, ) -> NodeEvent: seq = await next_local_seq(session) created_dt = _normalize_dt(created_at) event_body = { "origin_public_key": LOCAL_PUBLIC_KEY, "origin_host": origin_host, "seq": seq, "uid": uuid4().hex, "event_type": event_type, "payload": payload, "created_at": created_dt.replace(tzinfo=timezone.utc).isoformat().replace('+00:00', 'Z'), } signature = _sign_event(event_body) node_event = NodeEvent( origin_public_key=LOCAL_PUBLIC_KEY, origin_host=origin_host, seq=seq, uid=event_body["uid"], event_type=event_type, payload=payload, signature=signature, created_at=created_dt, status='local', ) session.add(node_event) await session.flush() make_log("Events", f"Recorded local event {event_type} seq={seq}") return node_event async def upsert_cursor(session: AsyncSession, source_public_key: str, seq: int, host: Optional[str]): existing = (await session.execute( select(NodeEventCursor).where(NodeEventCursor.source_public_key == source_public_key) )).scalar_one_or_none() if existing: if seq > existing.last_seq: existing.last_seq = seq if host: existing.source_public_host = host else: cursor = NodeEventCursor( source_public_key=source_public_key, last_seq=seq, source_public_host=host, ) session.add(cursor) await session.flush() async def store_remote_events( session: AsyncSession, events: Iterable[Dict[str, Any]], allowed_public_keys: Optional[set[str]] = None, ) -> List[NodeEvent]: stored: List[NodeEvent] = [] for event in events: if not verify_event_signature(event): continue origin_pk = event["origin_public_key"] if allowed_public_keys is not None and origin_pk not in allowed_public_keys: make_log("Events", f"Ignored event from untrusted node {origin_pk}", level="warning") continue seq = int(event["seq"]) exists = (await session.execute( select(NodeEvent).where( NodeEvent.origin_public_key == origin_pk, NodeEvent.seq == seq, ) )).scalar_one_or_none() if exists: continue created_dt = _parse_iso_dt(event.get("created_at")) received_dt = datetime.utcnow() node_event = NodeEvent( origin_public_key=origin_pk, origin_host=event.get("origin_host"), seq=seq, uid=event["uid"], event_type=event["event_type"], payload=event.get("payload") or {}, signature=event["signature"], created_at=created_dt, status='recorded', received_at=received_dt, ) session.add(node_event) stored.append(node_event) await upsert_cursor(session, origin_pk, seq, event.get("origin_host")) make_log("Events", f"Ingested remote event {event['event_type']} from {origin_pk} seq={seq}", level="debug") if stored: await session.flush() return stored async def prune_events(session: AsyncSession, max_age_days: int = 90): cutoff = datetime.utcnow() - timedelta(days=max_age_days) await session.execute( delete(NodeEvent).where(NodeEvent.created_at < cutoff) )