279 lines
9.8 KiB
Python
279 lines
9.8 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Any, Iterable, Tuple
|
|
|
|
from app.core._utils.hash import blake3_hex
|
|
|
|
|
|
class CRDTMergeError(RuntimeError):
    """Signals that two CRDT instances cannot be merged with each other."""
|
|
|
|
|
|
class CRDT:
    """Interface shared by the CRDT types defined in this module.

    Every method here raises ``NotImplementedError``; concrete subclasses
    supply the real behaviour.
    """

    def merge(self, other: "CRDT") -> "CRDT":
        """Combine the state of ``other`` with this instance and return the result."""
        raise NotImplementedError

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of this instance's state."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CRDT":
        """Build an instance from a dict shaped like the output of ``to_dict``."""
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class LWWElement:
    """A value tagged with the metadata used to order competing writes."""

    value: Any
    logical_counter: int
    timestamp: float
    node_id: str

    def dominates(self, other: "LWWElement") -> bool:
        """Return True when this element wins over ``other``.

        ``logical_counter`` is compared first, then ``timestamp``; when both
        are level, the element with the greater ``node_id`` wins. Two
        elements with identical metadata do not dominate each other.
        """
        ranked_pairs = (
            (self.logical_counter, other.logical_counter),
            (self.timestamp, other.timestamp),
        )
        for mine, theirs in ranked_pairs:
            if mine > theirs:
                return True
            if mine < theirs:
                return False
        # Counter and timestamp are level: node-id ordering keeps the outcome
        # deterministic regardless of which side performs the comparison.
        return self.node_id > other.node_id
|
|
|
|
|
|
class LWWRegister(CRDT):
    """Register that keeps whichever ``LWWElement`` dominates the others.

    The register is empty (``element is None``) until a value is assigned or
    merged in.
    """

    def __init__(self, element: LWWElement | None = None):
        """Create a register, optionally pre-populated with ``element``."""
        self.element = element

    def assign(self, value: Any, logical_counter: int, node_id: str, timestamp: float | None = None) -> None:
        """Store ``value`` if its metadata dominates the current element.

        Args:
            value: Payload to store.
            logical_counter: Logical clock value of the write.
            node_id: Identifier of the writing node; used as the tie-breaker.
            timestamp: Time of the write. Only ``None`` is replaced by
                ``time.time()``; an explicit ``0.0`` is kept as given.
        """
        # ``timestamp or time.time()`` would discard a legitimate 0.0 stamp,
        # so test for None explicitly.
        stamp = time.time() if timestamp is None else timestamp
        candidate = LWWElement(
            value=value,
            logical_counter=logical_counter,
            timestamp=stamp,
            node_id=node_id,
        )
        if self.element is None or candidate.dominates(self.element):
            self.element = candidate

    def merge(self, other: "LWWRegister") -> "LWWRegister":
        """Adopt ``other``'s element when it dominates ours.

        The merge happens in place and ``self`` is returned. The adopted
        element object is shared with ``other``, not copied.
        """
        incoming = other.element
        if incoming is not None and (self.element is None or incoming.dominates(self.element)):
            self.element = incoming
        return self

    def value(self) -> Any:
        """Return the stored value, or ``None`` when the register is empty."""
        return None if self.element is None else self.element.value

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the register; an empty register becomes ``{}``."""
        current = self.element
        if current is None:
            return {}
        return {
            "value": current.value,
            "logical_counter": current.logical_counter,
            "timestamp": current.timestamp,
            "node_id": current.node_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWWRegister":
        """Rebuild a register from ``to_dict`` output.

        Empty or missing ``data`` yields an empty register.

        Raises:
            KeyError: if ``logical_counter``, ``timestamp`` or ``node_id`` is
                absent from a non-empty ``data``.
        """
        if not data:
            return cls()
        element = LWWElement(
            value=data.get("value"),
            logical_counter=int(data["logical_counter"]),
            timestamp=float(data["timestamp"]),
            node_id=str(data["node_id"]),
        )
        return cls(element=element)
|
|
|
|
|
|
class LWWSet(CRDT):
    """Element set with one add record and one remove record per element id.

    An element id is considered present when it has an add record and either
    no remove record or an add record that dominates the remove record.
    """

    def __init__(self, adds: Dict[str, LWWElement] | None = None, removes: Dict[str, LWWElement] | None = None):
        """Create a set from optional existing add/remove record maps."""
        self.adds: Dict[str, LWWElement] = adds or {}
        self.removes: Dict[str, LWWElement] = removes or {}

    @staticmethod
    def _resolve_timestamp(timestamp: float | None) -> float:
        """Return ``timestamp``, substituting the current time only for ``None``."""
        # ``timestamp or time.time()`` would discard a legitimate 0.0 stamp.
        return time.time() if timestamp is None else timestamp

    @staticmethod
    def _record(target: Dict[str, LWWElement], element_id: str, candidate: LWWElement) -> None:
        """Store ``candidate`` in ``target`` unless the existing record wins."""
        current = target.get(element_id)
        if current is None or candidate.dominates(current):
            target[element_id] = candidate

    @staticmethod
    def _element_to_dict(elem: LWWElement) -> Dict[str, Any]:
        """Serialize one record to a plain dict."""
        return {
            "value": elem.value,
            "logical_counter": elem.logical_counter,
            "timestamp": elem.timestamp,
            "node_id": elem.node_id,
        }

    @staticmethod
    def _element_from_dict(raw: Dict[str, Any]) -> LWWElement:
        """Rebuild one record from its serialized form."""
        return LWWElement(
            value=raw.get("value"),
            logical_counter=int(raw["logical_counter"]),
            timestamp=float(raw["timestamp"]),
            node_id=str(raw["node_id"]),
        )

    def _is_present(self, element_id: str) -> bool:
        """Return True when ``element_id`` is currently a member of the set."""
        added = self.adds.get(element_id)
        if added is None:
            return False
        removed = self.removes.get(element_id)
        # With identical metadata on both sides ``dominates`` is False, so the
        # remove record wins.
        return removed is None or added.dominates(removed)

    def add(self, element_id: str, value: Any, logical_counter: int, node_id: str, timestamp: float | None = None) -> None:
        """Record an add of ``element_id`` carrying ``value``.

        The record is kept only if it dominates any existing add record for
        the same id. ``timestamp`` defaults to ``time.time()`` when ``None``.
        """
        candidate = LWWElement(
            value=value,
            logical_counter=logical_counter,
            timestamp=self._resolve_timestamp(timestamp),
            node_id=node_id,
        )
        self._record(self.adds, element_id, candidate)

    def remove(self, element_id: str, logical_counter: int, node_id: str, timestamp: float | None = None) -> None:
        """Record a removal of ``element_id``.

        The record is kept only if it dominates any existing remove record
        for the same id. ``timestamp`` defaults to ``time.time()`` when ``None``.
        """
        tombstone = LWWElement(
            value=None,
            logical_counter=logical_counter,
            timestamp=self._resolve_timestamp(timestamp),
            node_id=node_id,
        )
        self._record(self.removes, element_id, tombstone)

    def lookup(self, element_id: str) -> Any | None:
        """Return the value stored for ``element_id``, or ``None`` if absent.

        A present element whose stored value is ``None`` is indistinguishable
        from an absent one through this method; ``elements()`` does list it.
        """
        if self._is_present(element_id):
            return self.adds[element_id].value
        return None

    def elements(self) -> Dict[str, Any]:
        """Return ``{element_id: value}`` for every element currently present.

        Membership is decided by comparing add and remove records, so entries
        whose stored value is ``None`` are included.
        """
        return {eid: elem.value for eid, elem in self.adds.items() if self._is_present(eid)}

    def merge(self, other: "LWWSet") -> "LWWSet":
        """Fold ``other``'s add and remove records into this set.

        The merge happens in place and ``self`` is returned. Adopted records
        are shared with ``other``, not copied.
        """
        for eid, elem in other.adds.items():
            self._record(self.adds, eid, elem)
        for eid, elem in other.removes.items():
            self._record(self.removes, eid, elem)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Serialize both record maps under the keys ``adds`` and ``removes``."""
        return {
            "adds": {eid: self._element_to_dict(elem) for eid, elem in self.adds.items()},
            "removes": {eid: self._element_to_dict(elem) for eid, elem in self.removes.items()},
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWWSet":
        """Rebuild a set from ``to_dict`` output.

        Empty or missing ``data``, and missing ``adds``/``removes`` keys, yield
        empty record maps.

        Raises:
            KeyError: if a record lacks ``logical_counter``, ``timestamp`` or
                ``node_id``.
        """
        payload = data or {}
        adds = {
            eid: cls._element_from_dict(raw)
            for eid, raw in (payload.get("adds") or {}).items()
        }
        removes = {
            eid: cls._element_from_dict(raw)
            for eid, raw in (payload.get("removes") or {}).items()
        }
        return cls(adds=adds, removes=removes)
|
|
|
|
|
|
class PNCounter(CRDT):
    """Counter that keeps per-node increment and decrement totals separately."""

    def __init__(self, increments: Dict[str, int] | None = None, decrements: Dict[str, int] | None = None):
        """Create a counter from optional per-node increment/decrement maps."""
        self.increments = increments or {}
        self.decrements = decrements or {}

    def increment(self, node_id: str, value: int = 1) -> None:
        """Add ``value`` to ``node_id``'s increment total.

        Raises:
            ValueError: if ``value`` is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative for increment")
        previous = self.increments.get(node_id, 0)
        self.increments[node_id] = previous + value

    def decrement(self, node_id: str, value: int = 1) -> None:
        """Add ``value`` to ``node_id``'s decrement total.

        Raises:
            ValueError: if ``value`` is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative for decrement")
        previous = self.decrements.get(node_id, 0)
        self.decrements[node_id] = previous + value

    def value(self) -> int:
        """Return the sum of all increments minus the sum of all decrements."""
        gained = sum(self.increments.values())
        lost = sum(self.decrements.values())
        return gained - lost

    def merge(self, other: "PNCounter") -> "PNCounter":
        """Take the per-node maximum of both maps, in place, and return ``self``."""
        paired_maps = (
            (self.increments, other.increments),
            (self.decrements, other.decrements),
        )
        for mine, theirs in paired_maps:
            for nid, count in theirs.items():
                known = mine.get(nid, 0)
                mine[nid] = max(known, count)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Serialize copies of both maps under the keys ``inc`` and ``dec``."""
        inc_copy = dict(self.increments)
        dec_copy = dict(self.decrements)
        return {"inc": inc_copy, "dec": dec_copy}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PNCounter":
        """Rebuild a counter from ``to_dict`` output; missing keys mean empty maps."""
        inc_raw = data.get("inc") or {}
        dec_raw = data.get("dec") or {}
        return cls(increments=dict(inc_raw), decrements=dict(dec_raw))
|
|
|
|
|
|
class GCounter(CRDT):
    """Counter made of per-node totals that can only be increased."""

    def __init__(self, counters: Dict[str, int] | None = None):
        """Create a counter from an optional per-node totals map."""
        self.counters = counters or {}

    def increment(self, node_id: str, value: int = 1) -> None:
        """Add ``value`` to ``node_id``'s total.

        Raises:
            ValueError: if ``value`` is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative")
        previous = self.counters.get(node_id, 0)
        self.counters[node_id] = previous + value

    def value(self) -> int:
        """Return the sum of every node's total."""
        per_node_totals = self.counters.values()
        return sum(per_node_totals)

    def merge(self, other: "GCounter") -> "GCounter":
        """Take the per-node maximum of both maps, in place, and return ``self``."""
        for nid, count in other.counters.items():
            known = self.counters.get(nid, 0)
            self.counters[nid] = max(known, count)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return a copy of the per-node totals map."""
        snapshot = dict(self.counters)
        return snapshot

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GCounter":
        """Rebuild a counter from ``to_dict`` output; falsy ``data`` means empty."""
        restored = dict(data or {})
        return cls(counters=restored)
|
|
|
|
|
|
def _leading_zeros(value: int, width: int) -> int:
|
|
if value == 0:
|
|
return width
|
|
return width - value.bit_length()
|
|
|
|
|
|
@dataclass
class HyperLogLog(CRDT):
    """HyperLogLog-style cardinality estimator that merges by register maximum.

    ``registers`` always holds ``2 ** precision`` integers. It is stored as a
    tuple and replaced wholesale whenever a register changes.
    """

    precision: int = 12
    registers: Tuple[int, ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Allocate zeroed registers, or validate the ones supplied.

        Raises:
            ValueError: if non-empty ``registers`` does not contain exactly
                ``2 ** precision`` entries.
        """
        expected = 1 << self.precision
        if not self.registers:
            self.registers = (0,) * expected
            return
        self.registers = tuple(self.registers)
        # ``add`` masks the hash with ``m - 1`` and shifts by ``precision``;
        # both only make sense when the register count matches the precision.
        if len(self.registers) != expected:
            raise ValueError(
                f"HyperLogLog with precision {self.precision} needs {expected} registers, "
                f"got {len(self.registers)}"
            )

    @property
    def m(self) -> int:
        """Number of registers."""
        return len(self.registers)

    def _alpha(self) -> float:
        """Return the bias-correction constant for the current register count."""
        # Published constants for the small register counts; the closed-form
        # approximation below is the one given for m >= 128.
        small_m = {16: 0.673, 32: 0.697, 64: 0.709}
        if self.m in small_m:
            return small_m[self.m]
        return 0.7213 / (1 + 1.079 / self.m)

    def add(self, value: Any) -> None:
        """Record ``value`` in the sketch; ``None`` is ignored.

        The value is stringified, hashed with ``blake3_hex``, and the digest
        is split into a register index (low ``precision`` bits) and a
        remainder whose leading-zero count determines the rank.
        """
        if value is None:
            return
        # NOTE(review): the 256 below assumes blake3_hex yields a 256-bit
        # digest -- confirm against the helper.
        digest = int(blake3_hex(str(value).encode()), 16)
        index = digest & (self.m - 1)
        remainder = digest >> self.precision
        rank = _leading_zeros(remainder, 256 - self.precision) + 1
        if rank > self.registers[index]:
            # Registers are a tuple, so build a new one with the slot replaced.
            self.registers = self.registers[:index] + (rank,) + self.registers[index + 1:]

    def estimate(self) -> float:
        """Return the estimated number of distinct values added.

        When the raw estimate is at most ``2.5 * m`` and some registers are
        still zero, the linear-counting formula ``m * ln(m / zeros)`` is used
        instead of the raw estimate.
        """
        m = self.m
        harmonic_sum = sum(2.0 ** (-r) for r in self.registers)
        raw = self._alpha() * (m ** 2) / harmonic_sum
        if raw <= 2.5 * m:
            zeros = self.registers.count(0)
            if zeros:
                return m * math.log(m / zeros)
        return raw

    def merge(self, other: "HyperLogLog") -> "HyperLogLog":
        """Take the element-wise maximum of both register sets, in place.

        Returns:
            ``self`` after the merge.

        Raises:
            CRDTMergeError: if the two sketches differ in precision or in
                register count.
        """
        if self.precision != other.precision or self.m != other.m:
            raise CRDTMergeError("Cannot merge HyperLogLog instances with different precision")
        self.registers = tuple(map(max, self.registers, other.registers))
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the precision and the registers (as a list)."""
        return {"precision": self.precision, "registers": list(self.registers)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HyperLogLog":
        """Rebuild a sketch from ``to_dict`` output.

        Empty or missing ``data`` yields a default sketch; a missing or empty
        ``registers`` entry yields zeroed registers for the given precision.

        Raises:
            ValueError: if the register count does not match the precision.
        """
        if not data:
            return cls()
        raw_registers = data.get("registers") or []
        return cls(
            precision=int(data.get("precision", 12)),
            registers=tuple(int(x) for x in raw_registers),
        )
|