uploader-bot/app/core/network/dht/membership.py

236 lines
8.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import ipaddress
import time
from dataclasses import dataclass
from typing import Dict, Any, Iterable, List, Optional, Tuple
from app.core._crypto.signer import Signer
from .config import dht_config
from .crdt import LWWSet, HyperLogLog
from .keys import MembershipKey
from .store import DHTStore
@dataclass
class ReachabilityReceipt:
    """Report by ``issuer_id`` that it reached ``target_id``.

    Receipts are stored by ``MembershipState.record_receipt`` via
    :meth:`as_dict`; the distinct non-``None`` ``asn`` values attached to a
    target's receipts drive its reachability ratio.
    """

    target_id: str  # node whose reachability is being attested
    issuer_id: str  # node that produced the attestation
    asn: Optional[int]  # ASN carried by the receipt; None is skipped when counting ASNs
    timestamp: float  # used as the LWW timestamp when the receipt is recorded
    signature: str  # NOTE(review): no code in this file checks this signature

    def as_dict(self) -> Dict[str, Any]:
        """Return the receipt as a plain dict, keys in field-declaration order."""
        field_names = ("target_id", "issuer_id", "asn", "timestamp", "signature")
        return {name: getattr(self, name) for name in field_names}
def _ip_first_octet(host: str | None) -> Optional[int]:
if not host:
return None
try:
ip = ipaddress.ip_address(host)
return int(str(ip).split(".")[0])
except Exception:
return None
class MembershipState:
    """Mergeable (CRDT-based) view of DHT membership held by one node.

    Components:

    * ``members`` -- :class:`LWWSet` of member records keyed by node id
      (record layout documented on :meth:`register_member`).
    * ``receipts`` -- :class:`LWWSet` of reachability receipts keyed by
      ``"<target_id>:<issuer_id>"``.
    * ``hll`` -- :class:`HyperLogLog` sketch of node ids passed to
      :meth:`register_member`; :meth:`forget_member` does not touch it.
    * ``n_reports`` -- per-node network-size reports; :meth:`merge` keeps the
      per-node maximum.
    * ``logical_counter`` -- counter handed to LWW operations; incremented on
      each local mutation, and :meth:`merge` keeps the maximum.

    NOTE(review): :meth:`merge` accepts receipts and ``n_reports`` from the
    other state as-is; nothing in this class verifies receipt signatures or
    the plausibility of reported values.
    """

    def __init__(self, node_id: str, signer: Signer):
        self.node_id = node_id
        # Stored for owners of this object; no method of this class uses it.
        self.signer = signer
        self.members = LWWSet()
        self.receipts = LWWSet()
        self.hll = HyperLogLog()
        self.n_reports: Dict[str, float] = {}
        self.logical_counter = 0

    def _bump_counter(self) -> int:
        """Increment the logical counter and return the new value."""
        self.logical_counter += 1
        return self.logical_counter

    def register_member(
        self,
        node_id: str,
        public_key: str,
        ip: str | None,
        asn: Optional[int],
        metadata: Dict[str, Any] | None = None,
        timestamp: Optional[float] = None,
    ) -> None:
        """Add or refresh the membership record for ``node_id``.

        The stored record has keys ``node_id``, ``public_key``, ``ip``,
        ``asn``, ``ip_first_octet`` (derived from ``ip``), ``meta`` (``{}``
        when ``metadata`` is falsy) and ``last_update``. ``last_update`` is
        ``timestamp``, or the current ``time.time()`` when ``timestamp`` is
        falsy; ``timestamp`` itself is passed unchanged to the LWW set. The id
        is also added to the HyperLogLog sketch.
        """
        payload = {
            "node_id": node_id,
            "public_key": public_key,
            "ip": ip,
            "asn": asn,
            "ip_first_octet": _ip_first_octet(ip),
            "meta": metadata or {},
            "last_update": timestamp or time.time(),
        }
        self.members.add(
            node_id,
            payload,
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=timestamp,
        )
        self.hll.add(node_id)

    def forget_member(self, node_id: str) -> None:
        """Remove ``node_id`` from ``members`` (the HLL sketch is unaffected)."""
        self.members.remove(node_id, logical_counter=self._bump_counter(), node_id=self.node_id)

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Store ``receipt`` under the key ``"<target_id>:<issuer_id>"``.

        Because of that key, the set holds one element per (target, issuer)
        pair; which write survives is decided by the LWW set.
        """
        element_id = f"{receipt.target_id}:{receipt.issuer_id}"
        self.receipts.add(
            element_id,
            receipt.as_dict(),
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=receipt.timestamp,
        )

    def report_local_population(self) -> None:
        """Record this node's current HLL estimate as its own ``n_reports`` entry."""
        self.n_reports[self.node_id] = float(self.hll.estimate())

    def merge(self, other: "MembershipState") -> "MembershipState":
        """Merge ``other`` into this state in place and return ``self``.

        Each CRDT delegates to its own ``merge``; ``n_reports`` keep the
        per-node maximum and ``logical_counter`` keeps the larger value.
        """
        self.members.merge(other.members)
        self.receipts.merge(other.receipts)
        self.hll.merge(other.hll)
        for node_id, value in other.n_reports.items():
            self.n_reports[node_id] = max(self.n_reports.get(node_id, 0.0), value)
        self.logical_counter = max(self.logical_counter, other.logical_counter)
        return self

    def _asn_index(self) -> Dict[Any, set[int]]:
        """Map each receipt ``target_id`` to its distinct non-``None`` ASNs (one pass)."""
        index: Dict[Any, set[int]] = {}
        for _element_id, entry in self.receipts.elements().items():
            asn = entry.get("asn")
            if asn is not None:
                index.setdefault(entry.get("target_id"), set()).add(asn)
        return index

    @staticmethod
    def _ratio_from_count(unique_count: int) -> float:
        """Turn a distinct-ASN count into a ratio capped at 1.0.

        Returns 1.0 when ``dht_config.min_receipts`` is not positive (the
        check is disabled), else ``min(1, unique_count / min_receipts)``.
        """
        min_receipts = dht_config.min_receipts
        if min_receipts <= 0:
            return 1.0
        return min(1.0, unique_count / min_receipts)

    def _unique_asn_for(self, node_id: str) -> Tuple[int, set[int]]:
        """Return ``(count, asns)`` of distinct non-``None`` ASNs attesting ``node_id``."""
        unique_asn = self._asn_index().get(node_id, set())
        return len(unique_asn), unique_asn

    def reachability_ratio(self, node_id: str) -> float:
        """Fraction of the required distinct-ASN receipts ``node_id`` has, capped at 1.0."""
        unique_count, _ = self._unique_asn_for(node_id)
        return self._ratio_from_count(unique_count)

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return copies of member records seen within ``dht_config.membership_ttl`` seconds.

        Each returned record gains a ``reachability_ratio`` key. Unless
        ``include_islands`` is true, members whose ratio is below
        ``dht_config.default_q`` are omitted.
        """
        now = time.time()
        # Index receipts once instead of rescanning them for every member.
        asn_index = self._asn_index()
        result: List[Dict[str, Any]] = []
        for node_id, data in self.members.elements().items():
            last_update = data.get("last_update") or 0
            if now - last_update > dht_config.membership_ttl:
                continue
            reachability = self._ratio_from_count(len(asn_index.get(node_id, ())))
            if not include_islands and reachability < dht_config.default_q:
                continue
            enriched = dict(data)
            enriched["reachability_ratio"] = reachability
            result.append(enriched)
        return result

    def _qualified_reports(self, members: Iterable[Dict[str, Any]]) -> List[float]:
        """Return ``n_reports`` values of ``members`` whose ratio meets ``default_q``.

        ``members`` are records from :meth:`active_members`, which already carry
        ``reachability_ratio``; reusing it avoids rescanning the receipts.
        """
        qualified_ids = {
            m.get("node_id")
            for m in members
            if m["reachability_ratio"] >= dht_config.default_q
        }
        return [value for node_id, value in self.n_reports.items() if node_id in qualified_ids]

    def n_estimate(self) -> float:
        """Estimate the network size N.

        First refreshes this node's own report (side effect). The result is
        the larger of the local HLL estimate and the reports of active members
        whose reachability ratio meets ``dht_config.default_q``.
        """
        self.report_local_population()
        active = self.active_members(include_islands=True)
        filtered_reports = self._qualified_reports(active)
        local_estimate = float(self.hll.estimate())
        if filtered_reports:
            return max(max(filtered_reports), local_estimate)
        return local_estimate

    def n_estimate_trusted(self, allowed_ids: set[str]) -> float:
        """Estimate the network size using trusted nodes only.

        Takes the active members whose id is in ``allowed_ids`` and estimates
        N from how many there are and from their own N reports (when present
        and the reporter's reachability ratio meets ``default_q``). Like
        :meth:`n_estimate`, refreshes this node's own report first.
        """
        self.report_local_population()
        active_trusted = [
            m for m in self.active_members(include_islands=True) if m.get("node_id") in allowed_ids
        ]
        filtered_reports = self._qualified_reports(active_trusted)
        # For trusted nodes, rely on the actual number of active trusted members.
        local_estimate = float(len({m.get("node_id") for m in active_trusted}))
        if filtered_reports:
            return max(max(filtered_reports), local_estimate)
        return local_estimate

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the state; :meth:`from_dict` reads the same layout."""
        return {
            "members": self.members.to_dict(),
            "receipts": self.receipts.to_dict(),
            "hll": self.hll.to_dict(),
            "reports": dict(self.n_reports),
            "logical_counter": self.logical_counter,
        }

    @classmethod
    def from_dict(cls, node_id: str, signer: Signer, data: Dict[str, Any]) -> "MembershipState":
        """Rebuild a state from :meth:`to_dict` output.

        An empty/falsy ``data`` gives a fresh state. Missing sections are
        passed to the CRDT constructors as ``{}``; report keys are coerced to
        ``str`` and values to ``float``.
        """
        inst = cls(node_id=node_id, signer=signer)
        if data:
            inst.members = LWWSet.from_dict(data.get("members") or {})
            inst.receipts = LWWSet.from_dict(data.get("receipts") or {})
            inst.hll = HyperLogLog.from_dict(data.get("hll") or {})
            inst.n_reports = {str(k): float(v) for k, v in (data.get("reports") or {}).items()}
            inst.logical_counter = int(data.get("logical_counter") or 0)
        return inst
class MembershipManager:
    """Owns a :class:`MembershipState` and mirrors it into a :class:`DHTStore`.

    The local mutators (``register_local``, ``update_member``,
    ``remove_member``, ``record_receipt``) change the in-memory state and
    then call :meth:`_persist`, which writes the serialized state under this
    node's :class:`MembershipKey`.
    """

    def __init__(self, node_id: str, signer: Signer, store: DHTStore):
        self.node_id = node_id
        self.signer = signer
        self.store = store
        self.state = MembershipState(node_id=node_id, signer=signer)

    def _merge_remote(self, data: Dict[str, Any]) -> None:
        """Deserialize ``data`` (``MembershipState.to_dict`` layout) and merge it in."""
        remote_state = MembershipState.from_dict(self.node_id, self.signer, data)
        self.state.merge(remote_state)

    def ingest_snapshot(self, payload: Dict[str, Any]) -> None:
        """Merge an externally supplied serialized state into the local one.

        NOTE(review): unlike the local mutators this does not call
        ``_persist()``, so merged data reaches the store only on the next
        local mutation -- confirm this is intentional.
        """
        self._merge_remote(payload)

    def register_local(self, public_key: str, ip: str | None, asn: Optional[int], metadata: Dict[str, Any] | None = None) -> None:
        """Register (or refresh) this node itself and persist."""
        self.state.register_member(self.node_id, public_key=public_key, ip=ip, asn=asn, metadata=metadata)
        self._persist()

    def update_member(self, node_id: str, **kwargs) -> None:
        """Re-register ``node_id`` from keyword arguments and persist.

        Recognised keys: ``public_key`` (falls back to
        ``metadata["public_key"]``), ``ip``, ``asn`` and ``metadata``. The
        record is replaced wholesale, so an omitted ``ip``/``asn`` is stored
        as ``None``.

        NOTE(review): a ``timestamp`` kwarg is ignored, so ``last_update`` is
        always the current time -- confirm callers expect that.
        """
        meta = kwargs.get("metadata") or {}
        self.state.register_member(
            node_id,
            public_key=kwargs.get("public_key", meta.get("public_key")),
            ip=kwargs.get("ip"),
            asn=kwargs.get("asn"),
            metadata=meta,
        )
        self._persist()

    def remove_member(self, node_id: str) -> None:
        """Remove ``node_id`` from the member set and persist."""
        self.state.forget_member(node_id)
        self._persist()

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Record a reachability receipt and persist."""
        self.state.record_receipt(receipt)
        self._persist()

    def _persist(self) -> None:
        """Write the serialized state to the store under ``MembershipKey(node_id)``.

        ``merge_strategy`` (presumably invoked by the store when it must
        reconcile two values) deserializes both, merges them and
        re-serializes the result.
        """
        key = MembershipKey(node_id=self.node_id)

        def merge_strategy(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
            left = MembershipState.from_dict(self.node_id, self.signer, a)
            right = MembershipState.from_dict(self.node_id, self.signer, b)
            return left.merge(right).to_dict()

        self.store.put(
            key=str(key),
            fingerprint=key.fingerprint(),
            value=self.state.to_dict(),
            logical_counter=self.state.logical_counter,
            merge_strategy=merge_strategy,
        )

    def n_estimate(self) -> float:
        """Delegate to ``MembershipState.n_estimate`` (also refreshes the local report)."""
        return self.state.n_estimate()

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return active member records.

        Pass ``include_islands=True`` to also include members whose
        reachability ratio is below ``dht_config.default_q``; the default
        keeps the original filtering.
        """
        return self.state.active_members(include_islands=include_islands)