"""State-based CRDT building blocks.

Provides a last-writer-wins register and set, grow-only and
positive/negative counters, and a HyperLogLog cardinality sketch. Every
type implements ``merge`` (which mutates and returns the receiver) plus
``to_dict`` / ``from_dict`` for plain-dict serialisation.
"""

from __future__ import annotations

import math
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Tuple

from app.core._utils.hash import blake3_hex


class CRDTMergeError(RuntimeError):
    """Raised when two CRDT instances cannot be merged."""


class CRDT:
    """Interface shared by every CRDT in this module."""

    def merge(self, other: "CRDT") -> "CRDT":
        """Fold ``other`` into this instance; subclasses must override."""
        raise NotImplementedError

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation; subclasses must override."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CRDT":
        """Rebuild an instance from ``to_dict`` output; subclasses must override."""
        raise NotImplementedError


def _resolve_timestamp(timestamp: float | None) -> float:
    """Return ``timestamp``, or the current wall-clock time when it is None.

    An explicit ``0.0`` is a real timestamp and is preserved; only ``None``
    means "not supplied".
    """
    return time.time() if timestamp is None else timestamp


@dataclass
class LWWElement:
    """A value tagged with the metadata used for last-writer-wins ordering."""

    value: Any
    logical_counter: int
    timestamp: float
    node_id: str

    def dominates(self, other: "LWWElement") -> bool:
        """Return True when this element strictly wins over ``other``.

        Elements are ordered by logical counter, then timestamp, then node
        id. Two elements that agree on all three do not dominate each other.
        """
        if self.logical_counter > other.logical_counter:
            return True
        if self.logical_counter < other.logical_counter:
            return False
        if self.timestamp > other.timestamp:
            return True
        if self.timestamp < other.timestamp:
            return False
        # Break all ties by NodeID ordering to guarantee determinism
        return self.node_id > other.node_id

    def to_dict(self) -> Dict[str, Any]:
        """Return the element as a plain dict with one key per field."""
        return {
            "value": self.value,
            "logical_counter": self.logical_counter,
            "timestamp": self.timestamp,
            "node_id": self.node_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWWElement":
        """Build an element from a dict shaped like ``to_dict`` output.

        ``value`` is optional and defaults to None; the other three keys are
        required and are coerced to ``int``, ``float`` and ``str``.

        Raises:
            KeyError: If ``logical_counter``, ``timestamp`` or ``node_id``
                is missing.
        """
        return cls(
            value=data.get("value"),
            logical_counter=int(data["logical_counter"]),
            timestamp=float(data["timestamp"]),
            node_id=str(data["node_id"]),
        )


class LWWRegister(CRDT):
    """Single-value register where the dominating write wins."""

    def __init__(self, element: LWWElement | None = None):
        self.element = element

    def assign(
        self,
        value: Any,
        logical_counter: int,
        node_id: str,
        timestamp: float | None = None,
    ) -> None:
        """Write ``value`` if it dominates the currently stored element.

        Args:
            value: Payload to store.
            logical_counter: Logical clock value of the write.
            node_id: Identifier of the writing node; final tie-breaker.
            timestamp: Time of the write. ``None`` means "now"; any
                number, including ``0.0``, is used as given.
        """
        candidate = LWWElement(
            value=value,
            logical_counter=logical_counter,
            timestamp=_resolve_timestamp(timestamp),
            node_id=node_id,
        )
        if self.element is None or candidate.dominates(self.element):
            self.element = candidate

    def merge(self, other: "LWWRegister") -> "LWWRegister":
        """Adopt ``other``'s element when it dominates ours; return self."""
        incoming = other.element
        if incoming is not None and (
            self.element is None or incoming.dominates(self.element)
        ):
            self.element = incoming
        return self

    def value(self) -> Any:
        """Return the stored value, or None when nothing has been assigned."""
        return self.element.value if self.element else None

    def to_dict(self) -> Dict[str, Any]:
        """Return the stored element as a dict, or ``{}`` when empty."""
        if self.element is None:
            return {}
        return self.element.to_dict()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWWRegister":
        """Rebuild a register; empty or missing data yields an empty register."""
        if not data:
            return cls()
        return cls(element=LWWElement.from_dict(data))


class LWWSet(CRDT):
    """Keyed last-writer-wins set with separate add and remove maps.

    An element is present when it has an add entry that dominates its
    remove entry (or has no remove entry). Because ``dominates`` is strict,
    an add and a remove with identical counter, timestamp and node id
    resolve in favour of the remove.
    """

    def __init__(
        self,
        adds: Dict[str, LWWElement] | None = None,
        removes: Dict[str, LWWElement] | None = None,
    ):
        self.adds: Dict[str, LWWElement] = adds or {}
        self.removes: Dict[str, LWWElement] = removes or {}

    @staticmethod
    def _store(
        target: Dict[str, LWWElement], element_id: str, elem: LWWElement
    ) -> None:
        """Put ``elem`` in ``target`` unless the existing entry already wins."""
        existing = target.get(element_id)
        if existing is None or elem.dominates(existing):
            target[element_id] = elem

    def _live_add(self, element_id: str) -> LWWElement | None:
        """Return the add entry for ``element_id`` if the element is present."""
        added = self.adds.get(element_id)
        if added is None:
            return None
        removed = self.removes.get(element_id)
        if removed is None or added.dominates(removed):
            return added
        return None

    def add(
        self,
        element_id: str,
        value: Any,
        logical_counter: int,
        node_id: str,
        timestamp: float | None = None,
    ) -> None:
        """Record an add of ``element_id`` carrying ``value``.

        ``timestamp=None`` means "now"; any number, including ``0.0``, is
        used as given.
        """
        elem = LWWElement(
            value=value,
            logical_counter=logical_counter,
            timestamp=_resolve_timestamp(timestamp),
            node_id=node_id,
        )
        self._store(self.adds, element_id, elem)

    def remove(
        self,
        element_id: str,
        logical_counter: int,
        node_id: str,
        timestamp: float | None = None,
    ) -> None:
        """Record a removal of ``element_id``.

        ``timestamp=None`` means "now"; any number, including ``0.0``, is
        used as given.
        """
        elem = LWWElement(
            value=None,
            logical_counter=logical_counter,
            timestamp=_resolve_timestamp(timestamp),
            node_id=node_id,
        )
        self._store(self.removes, element_id, elem)

    def lookup(self, element_id: str) -> Any | None:
        """Return the value stored for ``element_id``, or None if not present.

        A None result is ambiguous when the element was added with a None
        value; ``elements()`` reports presence unambiguously.
        """
        live = self._live_add(element_id)
        return live.value if live is not None else None

    def elements(self) -> Dict[str, Any]:
        """Return ``{element_id: value}`` for every element currently present.

        Presence is decided from the add/remove metadata rather than from
        the value, so elements whose stored value is None are included.
        """
        return {
            eid: elem.value
            for eid, elem in self.adds.items()
            if self._live_add(eid) is not None
        }

    def merge(self, other: "LWWSet") -> "LWWSet":
        """Fold ``other``'s add and remove entries into this set; return self."""
        for eid, elem in other.adds.items():
            self._store(self.adds, eid, elem)
        for eid, elem in other.removes.items():
            self._store(self.removes, eid, elem)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return ``{"adds": {...}, "removes": {...}}`` of serialised elements."""
        return {
            "adds": {eid: elem.to_dict() for eid, elem in self.adds.items()},
            "removes": {eid: elem.to_dict() for eid, elem in self.removes.items()},
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWWSet":
        """Rebuild a set from ``to_dict`` output; missing maps count as empty."""
        adds = {
            eid: LWWElement.from_dict(raw)
            for eid, raw in (data.get("adds") or {}).items()
        }
        removes = {
            eid: LWWElement.from_dict(raw)
            for eid, raw in (data.get("removes") or {}).items()
        }
        return cls(adds=adds, removes=removes)


class PNCounter(CRDT):
    """Counter supporting increments and decrements, tracked per node."""

    def __init__(
        self,
        increments: Dict[str, int] | None = None,
        decrements: Dict[str, int] | None = None,
    ):
        self.increments = increments or {}
        self.decrements = decrements or {}

    def increment(self, node_id: str, value: int = 1) -> None:
        """Add ``value`` to ``node_id``'s increment tally.

        Raises:
            ValueError: If ``value`` is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative for increment")
        self.increments[node_id] = self.increments.get(node_id, 0) + value

    def decrement(self, node_id: str, value: int = 1) -> None:
        """Add ``value`` to ``node_id``'s decrement tally.

        Raises:
            ValueError: If ``value`` is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative for decrement")
        self.decrements[node_id] = self.decrements.get(node_id, 0) + value

    def value(self) -> int:
        """Return total increments minus total decrements."""
        return sum(self.increments.values()) - sum(self.decrements.values())

    def merge(self, other: "PNCounter") -> "PNCounter":
        """Take the per-node maximum of both tallies; return self."""
        for nid, val in other.increments.items():
            self.increments[nid] = max(self.increments.get(nid, 0), val)
        for nid, val in other.decrements.items():
            self.decrements[nid] = max(self.decrements.get(nid, 0), val)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return ``{"inc": {...}, "dec": {...}}`` as copies of the tallies."""
        return {"inc": dict(self.increments), "dec": dict(self.decrements)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PNCounter":
        """Rebuild a counter from ``to_dict`` output; missing maps are empty."""
        return cls(
            increments=dict(data.get("inc") or {}),
            decrements=dict(data.get("dec") or {}),
        )


class GCounter(CRDT):
    """Grow-only counter tracked per node."""

    def __init__(self, counters: Dict[str, int] | None = None):
        self.counters = counters or {}

    def increment(self, node_id: str, value: int = 1) -> None:
        """Add ``value`` to ``node_id``'s tally.

        Raises:
            ValueError: If ``value`` is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative")
        self.counters[node_id] = self.counters.get(node_id, 0) + value

    def value(self) -> int:
        """Return the sum of all per-node tallies."""
        return sum(self.counters.values())

    def merge(self, other: "GCounter") -> "GCounter":
        """Take the per-node maximum of both counters; return self."""
        for nid, val in other.counters.items():
            self.counters[nid] = max(self.counters.get(nid, 0), val)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return a copy of the per-node tallies."""
        return dict(self.counters)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GCounter":
        """Rebuild a counter from ``to_dict`` output; falsy data is empty."""
        return cls(counters=dict(data or {}))


def _leading_zeros(value: int, width: int) -> int:
    """Return the number of leading zero bits of ``value`` in a ``width``-bit word."""
    if value == 0:
        return width
    return width - value.bit_length()


# Bias-correction constants for small register counts, from the HyperLogLog
# paper (Flajolet et al., 2007). The closed-form approximation used for all
# other sizes is only specified there for m >= 128.
_HLL_SMALL_M_ALPHA = {16: 0.673, 32: 0.697, 64: 0.709}


def _hll_alpha(m: int) -> float:
    """Return the HyperLogLog bias-correction constant for ``m`` registers."""
    return _HLL_SMALL_M_ALPHA.get(m, 0.7213 / (1 + 1.079 / m))


@dataclass
class HyperLogLog(CRDT):
    """HyperLogLog cardinality sketch that merges by register-wise maximum."""

    precision: int = 12
    registers: Tuple[int, ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        # An empty register tuple means "fresh sketch": allocate 2**precision
        # zeroed registers. Supplied registers are kept, normalised to a tuple.
        # NOTE(review): nothing checks that len(registers) == 1 << precision.
        if not self.registers:
            self.registers = tuple([0] * (1 << self.precision))
        else:
            self.registers = tuple(self.registers)

    @property
    def m(self) -> int:
        """Number of registers."""
        return len(self.registers)

    def add(self, value: Any) -> None:
        """Record ``value`` in the sketch; None is ignored.

        The value is stringified and hashed. The low bits of the hash
        select a register and the remaining bits supply the rank.
        """
        if value is None:
            return
        hashed = int(blake3_hex(str(value).encode()), 16)
        index = hashed & (self.m - 1)
        w = hashed >> self.precision
        # Assumes blake3_hex yields a 256-bit (64 hex char) digest — TODO confirm.
        rank = _leading_zeros(w, 256 - self.precision) + 1
        if rank > self.registers[index]:
            regs = list(self.registers)
            regs[index] = rank
            self.registers = tuple(regs)

    def estimate(self) -> float:
        """Return the estimated number of distinct values added.

        Uses the raw HyperLogLog estimator, switching to linear counting
        when the raw estimate is small and at least one register is zero.
        """
        m = self.m
        indicator = sum(2.0 ** (-r) for r in self.registers)
        raw = _hll_alpha(m) * (m ** 2) / indicator
        if raw <= 2.5 * m:
            zeros = self.registers.count(0)
            if zeros:
                return m * math.log(m / zeros)
        return raw

    def merge(self, other: "HyperLogLog") -> "HyperLogLog":
        """Take the register-wise maximum with ``other``; return self.

        Raises:
            CRDTMergeError: If the two sketches have different register counts.
        """
        if self.m != other.m:
            raise CRDTMergeError(
                "Cannot merge HyperLogLog instances with different precision"
            )
        self.registers = tuple(
            max(mine, theirs) for mine, theirs in zip(self.registers, other.registers)
        )
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return ``{"precision": ..., "registers": [...]}``."""
        return {"precision": self.precision, "registers": list(self.registers)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HyperLogLog":
        """Rebuild a sketch from ``to_dict`` output.

        Falsy data yields a default sketch. A missing ``precision`` defaults
        to 12 and missing ``registers`` yields a fresh zeroed sketch.
        """
        if not data:
            return cls()
        return cls(
            precision=int(data.get("precision", 12)),
            registers=tuple(int(x) for x in data.get("registers", [])),
        )