220 lines
7.7 KiB
Python
220 lines
7.7 KiB
Python
from __future__ import annotations
|
|
|
|
import ipaddress
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Any, Iterable, List, Optional, Tuple
|
|
|
|
from app.core._crypto.signer import Signer
|
|
from .config import dht_config
|
|
from .crdt import LWWSet, HyperLogLog
|
|
from .keys import MembershipKey
|
|
from .store import DHTStore
|
|
|
|
|
|
@dataclass
class ReachabilityReceipt:
    """Receipt issued by ``issuer_id`` about ``target_id``.

    Attributes are stored as given; no validation or signature checking
    happens in this class.
    """

    target_id: str
    issuer_id: str
    # NOTE(review): presumably the issuer's autonomous system number; may be
    # absent (None) -- confirm against the code that builds receipts.
    asn: Optional[int]
    timestamp: float
    signature: str

    def as_dict(self) -> Dict[str, Any]:
        """Return the five receipt fields as a plain ``dict``.

        Keys appear in declaration order: ``target_id``, ``issuer_id``,
        ``asn``, ``timestamp``, ``signature``.
        """
        field_names = ("target_id", "issuer_id", "asn", "timestamp", "signature")
        return {name: getattr(self, name) for name in field_names}
|
|
|
|
|
|
def _ip_first_octet(host: str | None) -> Optional[int]:
|
|
if not host:
|
|
return None
|
|
try:
|
|
ip = ipaddress.ip_address(host)
|
|
return int(str(ip).split(".")[0])
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
class MembershipState:
    """Mergeable membership view held by a single node.

    The state consists of:

    * ``members``  -- an ``LWWSet`` keyed by node id holding member payloads,
    * ``receipts`` -- an ``LWWSet`` keyed by ``"<target_id>:<issuer_id>"``
      holding reachability receipt dicts,
    * ``hll``      -- a ``HyperLogLog`` fed with every registered node id,
    * ``n_reports`` -- population estimates reported per node id,
    * ``logical_counter`` -- a local counter bumped on every mutation.

    NOTE(review): the exact conflict-resolution rules live in ``LWWSet`` /
    ``HyperLogLog`` (defined in ``.crdt``) and are not visible here.
    """

    def __init__(self, node_id: str, signer: Signer):
        """Create an empty state owned by ``node_id``.

        ``signer`` is stored but not used by any method of this class.
        """
        self.node_id = node_id
        self.signer = signer
        self.members = LWWSet()
        self.receipts = LWWSet()
        self.hll = HyperLogLog()
        self.n_reports: Dict[str, float] = {}
        self.logical_counter = 0

    def _bump_counter(self) -> int:
        """Increment the logical counter and return its new value."""
        self.logical_counter += 1
        return self.logical_counter

    def register_member(
        self,
        node_id: str,
        public_key: str,
        ip: str | None,
        asn: Optional[int],
        metadata: Dict[str, Any] | None = None,
        timestamp: Optional[float] = None,
    ) -> None:
        """Add or refresh a member entry and feed its id to the HLL sketch.

        Args:
            node_id: Id of the member being registered.
            public_key: Stored verbatim in the payload.
            ip: Stored verbatim; its first IPv4 octet (if any) is stored as
                ``ip_first_octet``.
            asn: Stored verbatim in the payload.
            metadata: Stored under ``"meta"``; a falsy value becomes ``{}``.
            timestamp: Explicit update time. When ``None`` the payload's
                ``last_update`` is the current ``time.time()``. The value is
                forwarded unchanged (possibly ``None``) to ``LWWSet.add``.
        """
        # ``is None`` rather than truthiness so an explicit 0.0 is honoured.
        last_update = time.time() if timestamp is None else timestamp
        payload = {
            "node_id": node_id,
            "public_key": public_key,
            "ip": ip,
            "asn": asn,
            "ip_first_octet": _ip_first_octet(ip),
            "meta": metadata or {},
            "last_update": last_update,
        }
        self.members.add(
            node_id,
            payload,
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=timestamp,
        )
        self.hll.add(node_id)

    def forget_member(self, node_id: str) -> None:
        """Remove ``node_id`` from the member set (the HLL is not changed)."""
        self.members.remove(
            node_id,
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
        )

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Store ``receipt`` under the key ``"<target_id>:<issuer_id>"``.

        Because the key combines target and issuer, a later receipt from the
        same issuer for the same target shares the earlier one's key.
        """
        element_id = f"{receipt.target_id}:{receipt.issuer_id}"
        self.receipts.add(
            element_id,
            receipt.as_dict(),
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=receipt.timestamp,
        )

    def report_local_population(self) -> None:
        """Record this node's own HLL estimate in ``n_reports``."""
        self.n_reports[self.node_id] = float(self.hll.estimate())

    def merge(self, other: "MembershipState") -> "MembershipState":
        """Merge ``other`` into this state in place and return ``self``.

        Sets and the HLL are merged via their own ``merge`` methods;
        population reports and the logical counter keep the larger value.
        """
        self.members.merge(other.members)
        self.receipts.merge(other.receipts)
        self.hll.merge(other.hll)
        for node_id, value in other.n_reports.items():
            self.n_reports[node_id] = max(self.n_reports.get(node_id, 0.0), value)
        self.logical_counter = max(self.logical_counter, other.logical_counter)
        return self

    def _unique_asn_for(self, node_id: str) -> Tuple[int, Iterable[int]]:
        """Return ``(count, asns)`` of distinct ASNs in receipts for ``node_id``.

        Only receipts whose ``target_id`` equals ``node_id`` and whose
        ``asn`` is not ``None`` contribute.
        """
        unique_asn = {
            entry.get("asn")
            for entry in self.receipts.elements().values()
            if entry.get("target_id") == node_id and entry.get("asn") is not None
        }
        return len(unique_asn), unique_asn

    def reachability_ratio(self, node_id: str) -> float:
        """Return the fraction of required distinct-ASN receipts, capped at 1.0.

        When ``dht_config.min_receipts`` is zero or negative every node is
        considered fully reachable (1.0).
        """
        required = dht_config.min_receipts
        if required <= 0:
            # Nothing is required, so skip the receipt scan entirely.
            return 1.0
        unique_count, _ = self._unique_asn_for(node_id)
        return min(1.0, unique_count / required)

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return payload copies of members that are fresh and reachable.

        A member is skipped when its ``last_update`` is older than
        ``dht_config.membership_ttl`` seconds (a missing/falsy value counts
        as 0). Unless ``include_islands`` is true, members whose
        reachability ratio is below ``dht_config.default_q`` are skipped too.
        Each returned dict carries an added ``"reachability_ratio"`` key.
        """
        now = time.time()
        result = []
        for node_id, data in self.members.elements().items():
            last_update = data.get("last_update") or 0
            if now - last_update > dht_config.membership_ttl:
                continue
            reachability = self.reachability_ratio(node_id)
            if not include_islands and reachability < dht_config.default_q:
                continue
            # Copy so the stored payload is not mutated by the enrichment.
            enriched = dict(data)
            enriched["reachability_ratio"] = reachability
            result.append(enriched)
        return result

    def n_estimate(self) -> float:
        """Return the population estimate for the network.

        Refreshes this node's own report first, then takes the maximum of
        the local HLL estimate and all reports from nodes that are active
        (fresh, islands included) and meet ``dht_config.default_q``.
        """
        self.report_local_population()
        active_ids = {m["node_id"] for m in self.active_members(include_islands=True)}
        filtered_reports = [
            value
            for node_id, value in self.n_reports.items()
            if node_id in active_ids
            and self.reachability_ratio(node_id) >= dht_config.default_q
        ]
        local_estimate = float(self.hll.estimate())
        if filtered_reports:
            return max(max(filtered_reports), local_estimate)
        return local_estimate

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the state into a plain dict (see ``from_dict``)."""
        return {
            "members": self.members.to_dict(),
            "receipts": self.receipts.to_dict(),
            "hll": self.hll.to_dict(),
            "reports": dict(self.n_reports),
            "logical_counter": self.logical_counter,
        }

    @classmethod
    def from_dict(cls, node_id: str, signer: Signer, data: Dict[str, Any]) -> "MembershipState":
        """Rebuild a state from ``to_dict`` output.

        A falsy ``data`` yields an empty state; missing or falsy sections
        fall back to empty containers and a counter of 0.
        """
        inst = cls(node_id=node_id, signer=signer)
        if data:
            inst.members = LWWSet.from_dict(data.get("members") or {})
            inst.receipts = LWWSet.from_dict(data.get("receipts") or {})
            inst.hll = HyperLogLog.from_dict(data.get("hll") or {})
            inst.n_reports = {
                str(k): float(v) for k, v in (data.get("reports") or {}).items()
            }
            inst.logical_counter = int(data.get("logical_counter") or 0)
        return inst
|
|
|
|
|
|
class MembershipManager:
    """Facade that mutates a ``MembershipState`` and persists it to a store.

    Every local mutation (register, update, remove, record receipt) is
    followed by a ``store.put`` of the full serialised state.
    """

    def __init__(self, node_id: str, signer: Signer, store: DHTStore):
        """Create a manager with a fresh, empty ``MembershipState``."""
        self.node_id = node_id
        self.signer = signer
        self.store = store
        self.state = MembershipState(node_id=node_id, signer=signer)

    def _merge_remote(self, data: Dict[str, Any]) -> None:
        """Deserialise ``data`` and merge it into the local state."""
        remote_state = MembershipState.from_dict(self.node_id, self.signer, data)
        self.state.merge(remote_state)

    def ingest_snapshot(self, payload: Dict[str, Any]) -> None:
        """Merge a serialised snapshot into the local state.

        The merged state is not persisted by this call.
        """
        self._merge_remote(payload)

    def register_local(
        self,
        public_key: str,
        ip: str | None,
        asn: Optional[int],
        metadata: Dict[str, Any] | None = None,
    ) -> None:
        """Register this node itself as a member, then persist."""
        self.state.register_member(
            self.node_id,
            public_key=public_key,
            ip=ip,
            asn=asn,
            metadata=metadata,
        )
        self._persist()

    def update_member(self, node_id: str, **kwargs) -> None:
        """Register or refresh ``node_id`` from keyword arguments, then persist.

        Recognised keywords: ``metadata``, ``public_key``, ``ip``, ``asn``.
        When ``public_key`` is not passed it falls back to
        ``metadata["public_key"]`` (or ``None``). Other keywords are ignored.
        """
        meta = kwargs.get("metadata") or {}
        self.state.register_member(
            node_id,
            public_key=kwargs.get("public_key", meta.get("public_key")),
            ip=kwargs.get("ip"),
            asn=kwargs.get("asn"),
            metadata=meta,
        )
        self._persist()

    def remove_member(self, node_id: str) -> None:
        """Remove ``node_id`` from the member set, then persist."""
        self.state.forget_member(node_id)
        self._persist()

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Record a reachability receipt, then persist."""
        self.state.record_receipt(receipt)
        self._persist()

    def _persist(self) -> None:
        """Write the serialised state to the store under this node's key.

        The store is handed a merge strategy that deserialises both dicts,
        merges the second into the first and returns the serialised result.
        """

        def _merge_snapshots(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
            left = MembershipState.from_dict(self.node_id, self.signer, a)
            right = MembershipState.from_dict(self.node_id, self.signer, b)
            return left.merge(right).to_dict()

        key = MembershipKey(node_id=self.node_id)
        self.store.put(
            key=str(key),
            fingerprint=key.fingerprint(),
            value=self.state.to_dict(),
            logical_counter=self.state.logical_counter,
            merge_strategy=_merge_snapshots,
        )

    def n_estimate(self) -> float:
        """Return the state's current population estimate."""
        return self.state.n_estimate()

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return the state's active members.

        Args:
            include_islands: Forwarded to ``MembershipState.active_members``;
                when true, members below the reachability threshold are
                kept. Defaults to ``False``, matching the previous
                zero-argument behaviour.
        """
        return self.state.active_members(include_islands=include_islands)
|
|
|