uploader-bot/app/core/network/dht/crdt.py

279 lines
9.8 KiB
Python

from __future__ import annotations
import math
import time
from dataclasses import dataclass, field
from typing import Dict, Any, Iterable, Tuple
from app.core._utils.hash import blake3_hex
class CRDTMergeError(RuntimeError):
    """Raised when two CRDT replicas are incompatible and cannot be merged."""
class CRDT:
    """Common interface for the replicated data types in this module.

    Every method here is a placeholder that raises ``NotImplementedError``;
    concrete types override all three.
    """

    def merge(self, other: "CRDT") -> "CRDT":
        """Combine *other*'s state into this replica and return the merged replica."""
        raise NotImplementedError()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this replica into a plain dictionary."""
        raise NotImplementedError()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CRDT":
        """Rebuild a replica from a dictionary produced by ``to_dict``."""
        raise NotImplementedError()
@dataclass
class LWWElement:
    """A value together with the stamp used for last-writer-wins resolution.

    Stamps are ordered by ``logical_counter`` first, then ``timestamp``, and
    finally ``node_id`` (string comparison) as a deterministic tie-breaker.
    """

    value: Any
    logical_counter: int
    timestamp: float
    node_id: str

    def dominates(self, other: "LWWElement") -> bool:
        """Return True when this element's stamp strictly wins over *other*'s.

        Two elements with identical counter, timestamp and node id do not
        dominate each other: the result is False in both directions.
        """
        ranked_pairs = (
            (self.logical_counter, other.logical_counter),
            (self.timestamp, other.timestamp),
        )
        for mine, theirs in ranked_pairs:
            if mine > theirs:
                return True
            if mine < theirs:
                return False
        # Counter and timestamp both tied: fall back to node-id ordering so
        # every replica resolves the conflict the same way.
        return self.node_id > other.node_id
class LWWRegister(CRDT):
    """Single-value register resolved by last-writer-wins.

    The register holds at most one ``LWWElement``. An incoming element
    replaces the stored one only when ``LWWElement.dominates`` says it
    strictly wins.
    """

    def __init__(self, element: LWWElement | None = None):
        # None means the register has never been assigned.
        self.element = element

    def assign(self, value: Any, logical_counter: int, node_id: str, timestamp: float | None = None) -> None:
        """Write *value* locally if its stamp dominates the current one.

        Args:
            value: Payload to store.
            logical_counter: Primary ordering key of the stamp.
            node_id: Identifier of the writer; final tie-breaker.
            timestamp: Secondary ordering key. Defaults to ``time.time()``
                only when omitted; an explicit ``0.0`` is kept as given.
        """
        # `timestamp or time.time()` would silently replace an explicit 0.0.
        if timestamp is None:
            timestamp = time.time()
        candidate = LWWElement(value=value, logical_counter=logical_counter, timestamp=timestamp, node_id=node_id)
        if self.element is None or candidate.dominates(self.element):
            self.element = candidate

    def merge(self, other: "LWWRegister") -> "LWWRegister":
        """Adopt *other*'s element if it dominates ours; mutates and returns ``self``."""
        incoming = other.element
        if incoming is not None and (self.element is None or incoming.dominates(self.element)):
            self.element = incoming
        return self

    def value(self) -> Any:
        """Return the stored value, or None if the register is unassigned."""
        return self.element.value if self.element else None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the register; an unassigned register becomes ``{}``."""
        if not self.element:
            return {}
        return {
            "value": self.element.value,
            "logical_counter": self.element.logical_counter,
            "timestamp": self.element.timestamp,
            "node_id": self.element.node_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWWRegister":
        """Inverse of ``to_dict``; empty or None *data* yields an unassigned register.

        Raises:
            KeyError: if a non-empty *data* lacks a stamp field.
        """
        if not data:
            return cls()
        element = LWWElement(
            value=data.get("value"),
            logical_counter=int(data["logical_counter"]),
            timestamp=float(data["timestamp"]),
            node_id=str(data["node_id"]),
        )
        return cls(element=element)
class LWWSet(CRDT):
    """Last-writer-wins element set with separate add and remove maps.

    ``adds`` and ``removes`` each keep, per element id, the entry that
    dominates every other entry seen for that id (``LWWElement.dominates``).
    An id is present when it has an add entry that strictly dominates its
    remove entry (or has no remove entry). An add and a remove with identical
    stamps therefore resolve to "removed".
    """

    def __init__(self, adds: Dict[str, LWWElement] | None = None, removes: Dict[str, LWWElement] | None = None):
        self.adds: Dict[str, LWWElement] = adds or {}
        self.removes: Dict[str, LWWElement] = removes or {}

    @staticmethod
    def _keep_dominant(target: Dict[str, LWWElement], element_id: str, elem: LWWElement) -> None:
        """Store *elem* under *element_id* if the slot is empty or *elem* dominates it."""
        existing = target.get(element_id)
        if existing is None or elem.dominates(existing):
            target[element_id] = elem

    def add(self, element_id: str, value: Any, logical_counter: int, node_id: str, timestamp: float | None = None) -> None:
        """Record an add of *element_id* carrying *value*.

        *timestamp* defaults to ``time.time()`` only when omitted; an explicit
        ``0.0`` is kept as given.
        """
        stamp = time.time() if timestamp is None else timestamp
        elem = LWWElement(value=value, logical_counter=logical_counter, timestamp=stamp, node_id=node_id)
        self._keep_dominant(self.adds, element_id, elem)

    def remove(self, element_id: str, logical_counter: int, node_id: str, timestamp: float | None = None) -> None:
        """Record a remove of *element_id* (a tombstone carrying no value).

        *timestamp* defaults to ``time.time()`` only when omitted; an explicit
        ``0.0`` is kept as given.
        """
        stamp = time.time() if timestamp is None else timestamp
        elem = LWWElement(value=None, logical_counter=logical_counter, timestamp=stamp, node_id=node_id)
        self._keep_dominant(self.removes, element_id, elem)

    def _is_present(self, element_id: str) -> bool:
        """Return True when the add entry for *element_id* beats its remove entry."""
        add = self.adds.get(element_id)
        if add is None:
            return False
        remove = self.removes.get(element_id)
        return remove is None or add.dominates(remove)

    def lookup(self, element_id: str) -> Any | None:
        """Return the value of *element_id* if present, else None.

        An element that is present but was added with value None also yields
        None here; ``elements()`` includes it.
        """
        if self._is_present(element_id):
            return self.adds[element_id].value
        return None

    def elements(self) -> Dict[str, Any]:
        """Return ``{element_id: value}`` for every currently present element.

        Presence is decided by the add/remove stamps, not by the value, so
        elements added with a value of None are included.
        """
        return {eid: elem.value for eid, elem in self.adds.items() if self._is_present(eid)}

    def merge(self, other: "LWWSet") -> "LWWSet":
        """Fold *other*'s adds and removes into this set; mutates and returns ``self``."""
        for eid, elem in other.adds.items():
            self._keep_dominant(self.adds, eid, elem)
        for eid, elem in other.removes.items():
            self._keep_dominant(self.removes, eid, elem)
        return self

    @staticmethod
    def _element_to_dict(elem: LWWElement) -> Dict[str, Any]:
        """Serialize one stamped entry."""
        return {
            "value": elem.value,
            "logical_counter": elem.logical_counter,
            "timestamp": elem.timestamp,
            "node_id": elem.node_id,
        }

    @staticmethod
    def _element_from_dict(raw: Dict[str, Any]) -> LWWElement:
        """Parse one stamped entry, coercing the stamp fields to their types."""
        return LWWElement(
            value=raw.get("value"),
            logical_counter=int(raw["logical_counter"]),
            timestamp=float(raw["timestamp"]),
            node_id=str(raw["node_id"]),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize both maps as ``{"adds": {...}, "removes": {...}}``."""
        return {
            "adds": {eid: self._element_to_dict(elem) for eid, elem in self.adds.items()},
            "removes": {eid: self._element_to_dict(elem) for eid, elem in self.removes.items()},
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWWSet":
        """Inverse of ``to_dict``; missing, empty or None maps become empty.

        Raises:
            KeyError: if an entry lacks a stamp field.
        """
        data = data or {}
        adds = {eid: cls._element_from_dict(raw) for eid, raw in (data.get("adds") or {}).items()}
        removes = {eid: cls._element_from_dict(raw) for eid, raw in (data.get("removes") or {}).items()}
        return cls(adds=adds, removes=removes)
class PNCounter(CRDT):
    """Positive-negative counter built from two per-node tallies.

    ``value()`` is the sum of all increments minus the sum of all
    decrements. Merging keeps, for each node, the larger tally on each side.
    """

    def __init__(self, increments: Dict[str, int] | None = None, decrements: Dict[str, int] | None = None):
        self.increments = increments if increments else {}
        self.decrements = decrements if decrements else {}

    @staticmethod
    def _bump(tally: Dict[str, int], node_id: str, amount: int) -> None:
        """Add *amount* to *node_id*'s entry in *tally* (missing entries start at 0)."""
        tally[node_id] = tally.get(node_id, 0) + amount

    def increment(self, node_id: str, value: int = 1) -> None:
        """Add *value* to *node_id*'s positive tally.

        Raises:
            ValueError: if *value* is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative for increment")
        self._bump(self.increments, node_id, value)

    def decrement(self, node_id: str, value: int = 1) -> None:
        """Add *value* to *node_id*'s negative tally.

        Raises:
            ValueError: if *value* is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative for decrement")
        self._bump(self.decrements, node_id, value)

    def value(self) -> int:
        """Return total increments minus total decrements."""
        positive = sum(self.increments.values())
        negative = sum(self.decrements.values())
        return positive - negative

    def merge(self, other: "PNCounter") -> "PNCounter":
        """Take the per-node maximum of both tallies; mutates and returns ``self``."""
        for mine, theirs in ((self.increments, other.increments), (self.decrements, other.decrements)):
            for nid, count in theirs.items():
                mine[nid] = max(mine.get(nid, 0), count)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"inc": {...}, "dec": {...}}`` using copies of the tallies."""
        return {"inc": dict(self.increments), "dec": dict(self.decrements)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PNCounter":
        """Inverse of ``to_dict``; missing or empty tallies become empty dicts."""
        inc = data.get("inc") or {}
        dec = data.get("dec") or {}
        return cls(increments=dict(inc), decrements=dict(dec))
class GCounter(CRDT):
    """Grow-only counter holding one tally per node.

    ``increment`` rejects negative amounts. The total is the sum of all
    tallies, and merging keeps the larger tally for each node.
    """

    def __init__(self, counters: Dict[str, int] | None = None):
        self.counters = counters if counters else {}

    def increment(self, node_id: str, value: int = 1) -> None:
        """Add *value* to *node_id*'s tally.

        Raises:
            ValueError: if *value* is negative.
        """
        if value < 0:
            raise ValueError("value must be non-negative")
        current = self.counters.get(node_id, 0)
        self.counters[node_id] = current + value

    def value(self) -> int:
        """Return the sum of all per-node tallies."""
        return sum(self.counters.values())

    def merge(self, other: "GCounter") -> "GCounter":
        """Keep the per-node maximum of both counters; mutates and returns ``self``."""
        mine = self.counters
        for nid, theirs in other.counters.items():
            mine[nid] = max(mine.get(nid, 0), theirs)
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as a flat ``{node_id: tally}`` copy."""
        return dict(self.counters)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GCounter":
        """Inverse of ``to_dict``; empty or None *data* yields an empty counter."""
        return cls(counters=dict(data) if data else {})
def _leading_zeros(value: int, width: int) -> int:
if value == 0:
return width
return width - value.bit_length()
@dataclass
class HyperLogLog(CRDT):
    """HyperLogLog cardinality sketch whose merge is a register-wise maximum.

    ``precision`` (p) selects ``m = 2 ** p`` registers. Each added value is
    hashed with ``blake3_hex``. The low ``p`` bits of the hash pick a register,
    and the remaining bits determine the rank stored there.
    """

    precision: int = 12
    registers: Tuple[int, ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Allocate zeroed registers, or normalize and validate supplied ones.

        Raises:
            ValueError: if ``registers`` is non-empty and its length is not
                ``2 ** precision``. ``add`` splits the hash on ``precision``
                bits, so any other length would corrupt the sketch.
        """
        expected = 1 << self.precision
        if not self.registers:
            self.registers = tuple([0] * expected)
            return
        self.registers = tuple(self.registers)
        if len(self.registers) != expected:
            raise ValueError(
                f"HyperLogLog with precision {self.precision} needs {expected} registers, "
                f"got {len(self.registers)}"
            )

    @property
    def m(self) -> int:
        """Number of registers."""
        return len(self.registers)

    def add(self, value: Any) -> None:
        """Record *value* (via its ``str()``) in the sketch; ``None`` is ignored."""
        if value is None:
            return
        digest = blake3_hex(str(value).encode())
        hashed = int(digest, 16)
        # Derive the hash width from the digest instead of hard-coding 256;
        # a 64-char hex digest gives 256 bits, matching the previous behaviour.
        # NOTE(review): assumes blake3_hex returns a bare hex string (no "0x").
        hash_bits = len(digest) * 4
        index = hashed & (self.m - 1)
        w = hashed >> self.precision
        rank = _leading_zeros(w, hash_bits - self.precision) + 1
        if rank > self.registers[index]:
            regs = list(self.registers)
            regs[index] = rank
            self.registers = tuple(regs)

    def _alpha(self) -> float:
        """Bias-correction constant for ``m`` registers (Flajolet et al., 2007)."""
        m = self.m
        # The closed-form approximation is only valid for m >= 128; the paper
        # gives explicit constants for the small sizes.
        if m == 16:
            return 0.673
        if m == 32:
            return 0.697
        if m == 64:
            return 0.709
        return 0.7213 / (1 + 1.079 / m)

    def estimate(self) -> float:
        """Return the estimated number of distinct values added.

        When the raw estimate is at most ``2.5 * m`` and some register is
        still zero, this uses linear counting (``m * ln(m / zeros)``) instead.
        """
        m = self.m
        indicator = sum(2.0 ** (-r) for r in self.registers)
        raw = self._alpha() * (m ** 2) / indicator
        if raw <= 2.5 * m:
            zeros = self.registers.count(0)
            if zeros:
                return m * math.log(m / zeros)
        return raw

    def merge(self, other: "HyperLogLog") -> "HyperLogLog":
        """Take the register-wise maximum with *other*; mutates and returns ``self``.

        Raises:
            CRDTMergeError: if the two sketches have different register counts.
        """
        if self.m != other.m:
            raise CRDTMergeError("Cannot merge HyperLogLog instances with different precision")
        self.registers = tuple(max(a, b) for a, b in zip(self.registers, other.registers))
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"precision": p, "registers": [...]}``."""
        return {"precision": self.precision, "registers": list(self.registers)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HyperLogLog":
        """Inverse of ``to_dict``.

        Empty or None *data* yields a default sketch. Missing or None
        registers yield zeroed registers for the given precision.

        Raises:
            ValueError: if the register count does not match the precision.
        """
        if not data:
            return cls()
        raw_registers = data.get("registers") or []
        return cls(precision=int(data.get("precision", 12)), registers=tuple(int(x) for x in raw_registers))