# uploader-bot/app/core/network/dht/membership.py
#
# 220 lines · 7.7 KiB · Python

from __future__ import annotations
import ipaddress
import time
from dataclasses import dataclass
from typing import Dict, Any, Iterable, List, Optional, Tuple
from app.core._crypto.signer import Signer
from .config import dht_config
from .crdt import LWWSet, HyperLogLog
from .keys import MembershipKey
from .store import DHTStore
@dataclass
class ReachabilityReceipt:
    """Record that ``issuer_id`` vouches for reaching ``target_id``.

    Instances are serialised with :meth:`as_dict` before being stored in the
    membership receipt set.
    """

    target_id: str  # node the receipt is about
    issuer_id: str  # node that produced the receipt
    asn: Optional[int]  # ASN carried by the receipt; None when unknown
    timestamp: float  # time the receipt was issued
    signature: str  # signature string as supplied (not verified here)

    def as_dict(self) -> Dict[str, Any]:
        """Return a fresh plain dict with one entry per field, in declaration order."""
        return dict(
            target_id=self.target_id,
            issuer_id=self.issuer_id,
            asn=self.asn,
            timestamp=self.timestamp,
            signature=self.signature,
        )
def _ip_first_octet(host: str | None) -> Optional[int]:
if not host:
return None
try:
ip = ipaddress.ip_address(host)
return int(str(ip).split(".")[0])
except Exception:
return None
class MembershipState:
    """Mergeable view of DHT membership, reachability receipts and size estimates.

    Members and receipts are kept in ``LWWSet`` instances and node ids are fed
    into a ``HyperLogLog`` sketch; :meth:`merge` folds another state into this
    one component by component. ``n_reports`` maps node ids to the population
    estimates they reported.
    """

    def __init__(self, node_id: str, signer: Signer):
        self.node_id = node_id  # local node id; passed as ``node_id`` to every local LWWSet op
        # NOTE(review): ``signer`` is stored but never used in this class --
        # receipt signatures are not verified before being recorded.
        self.signer = signer
        self.members = LWWSet()  # node_id -> member payload (see register_member)
        self.receipts = LWWSet()  # "<target_id>:<issuer_id>" -> receipt dict
        self.hll = HyperLogLog()  # cardinality sketch of registered node ids
        self.n_reports: Dict[str, float] = {}  # node_id -> reported population estimate
        self.logical_counter = 0  # bumped on every local LWWSet operation

    def _bump_counter(self) -> int:
        """Increment and return the local logical counter."""
        self.logical_counter += 1
        return self.logical_counter

    def register_member(
        self,
        node_id: str,
        public_key: str,
        ip: str | None,
        asn: Optional[int],
        metadata: Dict[str, Any] | None = None,
        timestamp: Optional[float] = None,
    ) -> None:
        """Add or refresh ``node_id`` in the membership set.

        Args:
            node_id: Id of the member being registered.
            public_key: Member's public key, stored verbatim.
            ip: Member's IP address, if known; its first IPv4 octet is stored too.
            asn: Member's ASN, if known.
            metadata: Extra fields stored under ``"meta"`` (defaults to ``{}``).
            timestamp: Stored as ``last_update``; the current wall-clock time is
                used when ``None``. Passed unchanged to ``LWWSet.add``.
        """
        payload = {
            "node_id": node_id,
            "public_key": public_key,
            "ip": ip,
            "asn": asn,
            "ip_first_octet": _ip_first_octet(ip),
            "meta": metadata or {},
            # ``is None`` (not ``or``) so an explicit 0.0 is honoured and matches
            # the timestamp handed to the LWWSet below.
            "last_update": time.time() if timestamp is None else timestamp,
        }
        self.members.add(
            node_id,
            payload,
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=timestamp,
        )
        self.hll.add(node_id)

    def forget_member(self, node_id: str) -> None:
        """Remove ``node_id`` from the membership set.

        The HyperLogLog sketch is left untouched, so the id still counts
        towards :meth:`n_estimate`'s local estimate.
        """
        self.members.remove(node_id, logical_counter=self._bump_counter(), node_id=self.node_id)

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Store ``receipt``, keeping one entry per (target, issuer) pair."""
        element_id = f"{receipt.target_id}:{receipt.issuer_id}"
        self.receipts.add(
            element_id,
            receipt.as_dict(),
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=receipt.timestamp,
        )

    def report_local_population(self) -> None:
        """Publish the local HyperLogLog estimate under this node's id."""
        self.n_reports[self.node_id] = float(self.hll.estimate())

    def merge(self, other: "MembershipState") -> "MembershipState":
        """Merge ``other`` into this state in place and return ``self``.

        Per-node reports keep the larger value; the logical counter keeps the
        larger of the two counters.
        """
        self.members.merge(other.members)
        self.receipts.merge(other.receipts)
        self.hll.merge(other.hll)
        for node_id, value in other.n_reports.items():
            self.n_reports[node_id] = max(self.n_reports.get(node_id, 0.0), value)
        self.logical_counter = max(self.logical_counter, other.logical_counter)
        return self

    def _unique_asn_for(self, node_id: str) -> Tuple[int, Iterable[int]]:
        """Return ``(count, asns)``: the distinct non-null ASNs of receipts targeting ``node_id``."""
        unique_asn = {
            entry.get("asn")
            for _rid, entry in self.receipts.elements().items()
            if entry.get("target_id") == node_id and entry.get("asn") is not None
        }
        return len(unique_asn), unique_asn

    def _asn_index(self) -> Dict[Any, set]:
        """Map each receipt target to its set of distinct non-null ASNs.

        Built in a single pass over the receipts so that callers needing the
        ratio for many nodes do not rescan every receipt once per node.
        """
        index: Dict[Any, set] = {}
        for _rid, entry in self.receipts.elements().items():
            asn = entry.get("asn")
            if asn is None:
                continue
            try:
                bucket = index.setdefault(entry.get("target_id"), set())
            except TypeError:
                # Unhashable target ids (malformed remote data) cannot be looked
                # up by member id, so they contribute to no node's ratio.
                continue
            bucket.add(asn)
        return index

    @staticmethod
    def _indexed_ratio(asn_index: Dict[Any, set], node_id: str) -> float:
        """Reachability ratio of ``node_id`` computed from a prebuilt ASN index."""
        if dht_config.min_receipts <= 0:
            return 1.0
        return min(1.0, len(asn_index.get(node_id, ())) / dht_config.min_receipts)

    def reachability_ratio(self, node_id: str) -> float:
        """Return distinct receipt ASNs for ``node_id`` divided by ``min_receipts``, capped at 1.0.

        Returns 1.0 when ``dht_config.min_receipts`` is not positive.
        """
        if dht_config.min_receipts <= 0:
            # Requirement disabled: no need to scan the receipts at all.
            return 1.0
        unique_count, _ = self._unique_asn_for(node_id)
        return min(1.0, unique_count / dht_config.min_receipts)

    def _collect_active(self, include_islands: bool, asn_index: Dict[Any, set]) -> List[Dict[str, Any]]:
        """Body of :meth:`active_members`, using a prebuilt ASN index."""
        now = time.time()
        result: List[Dict[str, Any]] = []
        for node_id, data in self.members.elements().items():
            last_update = data.get("last_update") or 0
            if now - last_update > dht_config.membership_ttl:
                continue  # stale entry
            reachability = self._indexed_ratio(asn_index, node_id)
            if not include_islands and reachability < dht_config.default_q:
                continue  # not vouched for by enough distinct ASNs
            enriched = dict(data)  # shallow copy so callers don't mutate stored payloads
            enriched["reachability_ratio"] = reachability
            result.append(enriched)
        return result

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return copies of fresh member payloads with a ``reachability_ratio`` key added.

        A member is fresh when ``now - last_update <= dht_config.membership_ttl``.
        Unless ``include_islands`` is true, members whose ratio is below
        ``dht_config.default_q`` are omitted.
        """
        return self._collect_active(include_islands, self._asn_index())

    def n_estimate(self) -> float:
        """Estimate the network size.

        Refreshes this node's own report first. Returns the maximum of the
        local HyperLogLog estimate and the reports from fresh members whose
        reachability ratio is at least ``dht_config.default_q``.
        """
        self.report_local_population()
        asn_index = self._asn_index()  # built once and shared by both filters below
        active_ids = {m["node_id"] for m in self._collect_active(True, asn_index)}
        filtered_reports = [
            value
            for node_id, value in self.n_reports.items()
            if node_id in active_ids and self._indexed_ratio(asn_index, node_id) >= dht_config.default_q
        ]
        local_estimate = float(self.hll.estimate())
        if filtered_reports:
            return max(max(filtered_reports), local_estimate)
        return local_estimate

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the state into a plain dict accepted by :meth:`from_dict`."""
        return {
            "members": self.members.to_dict(),
            "receipts": self.receipts.to_dict(),
            "hll": self.hll.to_dict(),
            "reports": dict(self.n_reports),
            "logical_counter": self.logical_counter,
        }

    @classmethod
    def from_dict(cls, node_id: str, signer: Signer, data: Dict[str, Any]) -> "MembershipState":
        """Rebuild a state from :meth:`to_dict` output.

        Missing or empty sections fall back to empty CRDTs, no reports and a
        zero counter; falsy ``data`` yields a fresh state.
        """
        inst = cls(node_id=node_id, signer=signer)
        if data:
            inst.members = LWWSet.from_dict(data.get("members") or {})
            inst.receipts = LWWSet.from_dict(data.get("receipts") or {})
            inst.hll = HyperLogLog.from_dict(data.get("hll") or {})
            inst.n_reports = {str(k): float(v) for k, v in (data.get("reports") or {}).items()}
            inst.logical_counter = int(data.get("logical_counter") or 0)
        return inst
class MembershipManager:
    """Owns the local :class:`MembershipState` and persists it to a ``DHTStore``.

    Every local mutation (register/update/remove member, record receipt) is
    followed by a write of the full state snapshot under this node's
    ``MembershipKey``.
    """

    def __init__(self, node_id: str, signer: Signer, store: DHTStore):
        self.node_id = node_id
        self.signer = signer
        self.store = store
        self.state = MembershipState(node_id=node_id, signer=signer)

    def _merge_remote(self, data: Dict[str, Any]) -> None:
        """Deserialise ``data`` and merge it into the local state."""
        remote_state = MembershipState.from_dict(self.node_id, self.signer, data)
        self.state.merge(remote_state)

    def ingest_snapshot(self, payload: Dict[str, Any]) -> None:
        """Merge a remote membership snapshot (``to_dict`` format) into the local state.

        NOTE(review): unlike the other mutators this does not call
        ``_persist()``; confirm that is intended.
        """
        self._merge_remote(payload)

    def register_local(
        self,
        public_key: str,
        ip: str | None,
        asn: Optional[int],
        metadata: Dict[str, Any] | None = None,
    ) -> None:
        """Register this node itself as a member and persist the state."""
        self.state.register_member(self.node_id, public_key=public_key, ip=ip, asn=asn, metadata=metadata)
        self._persist()

    def update_member(self, node_id: str, **kwargs) -> None:
        """Register or refresh ``node_id`` and persist the state.

        Recognised keyword arguments: ``public_key`` (falls back to
        ``metadata["public_key"]``), ``ip``, ``asn`` and ``metadata``. Any other
        keyword argument is ignored.
        """
        meta = kwargs.get("metadata") or {}
        self.state.register_member(
            node_id,
            public_key=kwargs.get("public_key", meta.get("public_key")),
            ip=kwargs.get("ip"),
            asn=kwargs.get("asn"),
            metadata=meta,
        )
        self._persist()

    def remove_member(self, node_id: str) -> None:
        """Remove ``node_id`` from membership and persist the state."""
        self.state.forget_member(node_id)
        self._persist()

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Record a reachability receipt and persist the state."""
        self.state.record_receipt(receipt)
        self._persist()

    def _merge_snapshots(self, a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
        """Store merge strategy: merge two serialised states and re-serialise the result."""
        return (
            MembershipState.from_dict(self.node_id, self.signer, a)
            .merge(MembershipState.from_dict(self.node_id, self.signer, b))
            .to_dict()
        )

    def _persist(self) -> None:
        """Write the current state snapshot to the store under this node's key."""
        key = MembershipKey(node_id=self.node_id)
        self.store.put(
            key=str(key),
            fingerprint=key.fingerprint(),
            value=self.state.to_dict(),
            logical_counter=self.state.logical_counter,
            merge_strategy=self._merge_snapshots,
        )

    def n_estimate(self) -> float:
        """Return the state's network-size estimate (see ``MembershipState.n_estimate``)."""
        return self.state.n_estimate()

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return fresh members.

        By default, members below the reachability threshold are excluded;
        pass ``include_islands=True`` to keep them.
        """
        return self.state.active_members(include_islands=include_islands)