236 lines
8.7 KiB
Python
236 lines
8.7 KiB
Python
from __future__ import annotations
|
||
|
||
import ipaddress
|
||
import time
|
||
from dataclasses import dataclass
|
||
from typing import Dict, Any, Iterable, List, Optional, Tuple
|
||
|
||
from app.core._crypto.signer import Signer
|
||
from .config import dht_config
|
||
from .crdt import LWWSet, HyperLogLog
|
||
from .keys import MembershipKey
|
||
from .store import DHTStore
|
||
|
||
|
||
@dataclass
class ReachabilityReceipt:
    """Reachability receipt for ``target_id`` issued by ``issuer_id``.

    ``MembershipState`` stores receipts keyed by ``"<target_id>:<issuer_id>"``
    and counts the distinct non-``None`` ``asn`` values among the receipts
    for a target when computing its reachability ratio.

    NOTE(review): ``signature`` is carried as-is; nothing in this module
    verifies it.
    """

    target_id: str
    issuer_id: str
    asn: Optional[int]
    timestamp: float
    signature: str

    def as_dict(self) -> Dict[str, Any]:
        """Return the receipt as a plain dict keyed by field name (declaration order)."""
        field_names = ("target_id", "issuer_id", "asn", "timestamp", "signature")
        return {name: getattr(self, name) for name in field_names}
|
||
|
||
|
||
def _ip_first_octet(host: str | None) -> Optional[int]:
|
||
if not host:
|
||
return None
|
||
try:
|
||
ip = ipaddress.ip_address(host)
|
||
return int(str(ip).split(".")[0])
|
||
except Exception:
|
||
return None
|
||
|
||
|
||
class MembershipState:
    """Mergeable (CRDT-based) membership view of the DHT.

    Holds:

    * ``members`` -- ``LWWSet`` of member payloads keyed by node id;
    * ``receipts`` -- ``LWWSet`` of reachability receipts keyed by
      ``"<target_id>:<issuer_id>"``;
    * ``hll`` -- ``HyperLogLog`` sketch fed with every node id passed to
      :meth:`register_member` (:meth:`forget_member` does not remove it);
    * ``n_reports`` -- per-node network-size estimates, merged with ``max``;
    * ``logical_counter`` -- incremented for every local ``LWWSet`` mutation
      and raised to the larger of both sides on :meth:`merge`.

    NOTE(review): ``signer`` is stored but not used by this class, and receipt
    signatures are not verified here -- confirm verification happens upstream.
    """

    def __init__(self, node_id: str, signer: Signer):
        self.node_id = node_id
        self.signer = signer
        self.members = LWWSet()
        self.receipts = LWWSet()
        self.hll = HyperLogLog()
        self.n_reports: Dict[str, float] = {}
        self.logical_counter = 0

    def _bump_counter(self) -> int:
        """Increment and return the logical counter passed to ``LWWSet`` mutations."""
        self.logical_counter += 1
        return self.logical_counter

    def register_member(
        self,
        node_id: str,
        public_key: str,
        ip: str | None,
        asn: Optional[int],
        metadata: Dict[str, Any] | None = None,
        timestamp: Optional[float] = None,
    ) -> None:
        """Add (or re-add) the member entry for ``node_id`` and feed the HLL sketch.

        Args:
            node_id: Key of the entry in ``members``; also stored in the payload.
            public_key: Stored verbatim.
            ip: Stored verbatim; also used to derive ``ip_first_octet``.
            asn: Stored verbatim.
            metadata: Stored under ``"meta"``; a falsy value becomes ``{}``.
            timestamp: Stored as ``last_update``; ``None`` means "now"
                (``time.time()``). The raw value (possibly ``None``) is
                forwarded unchanged to ``LWWSet.add``.
        """
        # Explicit None check: ``timestamp or time.time()`` would discard a
        # legitimate 0.0 while still forwarding 0.0 to LWWSet.add.
        last_update = time.time() if timestamp is None else timestamp
        payload = {
            "node_id": node_id,
            "public_key": public_key,
            "ip": ip,
            "asn": asn,
            "ip_first_octet": _ip_first_octet(ip),
            "meta": metadata or {},
            "last_update": last_update,
        }
        self.members.add(
            node_id,
            payload,
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=timestamp,
        )
        self.hll.add(node_id)

    def forget_member(self, node_id: str) -> None:
        """Remove ``node_id`` from ``members`` (the HLL sketch is left untouched)."""
        self.members.remove(node_id, logical_counter=self._bump_counter(), node_id=self.node_id)

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Store ``receipt`` under the key ``"<target_id>:<issuer_id>"``.

        One slot exists per (target, issuer) pair; presumably a newer receipt
        for the same pair supersedes the older one per ``LWWSet`` semantics.
        """
        element_id = f"{receipt.target_id}:{receipt.issuer_id}"
        self.receipts.add(
            element_id,
            receipt.as_dict(),
            logical_counter=self._bump_counter(),
            node_id=self.node_id,
            timestamp=receipt.timestamp,
        )

    def report_local_population(self) -> None:
        """Record this node's current HLL estimate in ``n_reports``."""
        self.n_reports[self.node_id] = float(self.hll.estimate())

    def merge(self, other: "MembershipState") -> "MembershipState":
        """Merge ``other`` into ``self`` in place and return ``self`` for chaining.

        CRDT components are merged with their own ``merge``; size reports
        keep the per-node maximum; the logical counter takes the maximum.
        """
        self.members.merge(other.members)
        self.receipts.merge(other.receipts)
        self.hll.merge(other.hll)
        for node_id, value in other.n_reports.items():
            self.n_reports[node_id] = max(self.n_reports.get(node_id, 0.0), value)
        self.logical_counter = max(self.logical_counter, other.logical_counter)
        return self

    def _unique_asn_for(self, node_id: str) -> Tuple[int, Iterable[int]]:
        """Return ``(count, asns)`` of distinct non-``None`` ASNs in receipts targeting ``node_id``."""
        unique_asn = {
            entry.get("asn")
            for entry in self.receipts.elements().values()
            if entry.get("target_id") == node_id and entry.get("asn") is not None
        }
        return len(unique_asn), unique_asn

    def _asn_index(self) -> Dict[Any, set]:
        """Map every receipt ``target_id`` to its set of distinct non-``None`` ASNs.

        One pass over all receipts; lets callers that need ratios for many
        nodes avoid rescanning the receipts once per node.
        """
        index: Dict[Any, set] = {}
        for entry in self.receipts.elements().values():
            asn = entry.get("asn")
            if asn is None:
                continue
            try:
                index.setdefault(entry.get("target_id"), set()).add(asn)
            except TypeError:
                # Unhashable target/ASN (malformed remote data) can never
                # match a member key, so it contributes nothing.
                continue
        return index

    @staticmethod
    def _ratio_for_count(unique_count: int) -> float:
        """Convert a distinct-ASN count into a reachability ratio capped at ``1.0``.

        Returns ``1.0`` when ``dht_config.min_receipts`` is not positive.
        """
        if dht_config.min_receipts <= 0:
            return 1.0
        return min(1.0, unique_count / dht_config.min_receipts)

    def reachability_ratio(self, node_id: str) -> float:
        """Distinct receipt ASNs for ``node_id`` divided by ``dht_config.min_receipts``, capped at 1.0."""
        unique_count, _ = self._unique_asn_for(node_id)
        return self._ratio_for_count(unique_count)

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return copies of non-expired member payloads annotated with their reachability.

        A member is expired when ``time.time() - last_update`` exceeds
        ``dht_config.membership_ttl`` (a missing/falsy ``last_update`` counts
        as 0). Each returned dict gains a ``"reachability_ratio"`` key.
        Members whose ratio is below ``dht_config.default_q`` ("islands") are
        dropped unless ``include_islands`` is true.
        """
        now = time.time()
        asn_index = self._asn_index()
        result: List[Dict[str, Any]] = []
        for node_id, data in self.members.elements().items():
            last_update = data.get("last_update") or 0
            if now - last_update > dht_config.membership_ttl:
                continue
            reachability = self._ratio_for_count(len(asn_index.get(node_id, ())))
            if not include_islands and reachability < dht_config.default_q:
                continue
            enriched = dict(data)
            enriched["reachability_ratio"] = reachability
            result.append(enriched)
        return result

    def _max_reachable_report(self, candidate_ids: set, local_estimate: float) -> float:
        """Return the max of ``local_estimate`` and the size reports of reachable candidates.

        A report counts only if its node id is in ``candidate_ids`` and that
        node's reachability ratio is at least ``dht_config.default_q``.
        """
        asn_index = self._asn_index()
        filtered_reports = [
            value
            for node_id, value in self.n_reports.items()
            if node_id in candidate_ids
            and self._ratio_for_count(len(asn_index.get(node_id, ()))) >= dht_config.default_q
        ]
        if filtered_reports:
            return max(max(filtered_reports), local_estimate)
        return local_estimate

    def n_estimate(self) -> float:
        """Estimate network size from the local HLL and reports of reachable active members.

        Refreshes this node's own entry in ``n_reports`` first.
        """
        self.report_local_population()
        active_ids = {m["node_id"] for m in self.active_members(include_islands=True)}
        local_estimate = float(self.hll.estimate())
        return self._max_reachable_report(active_ids, local_estimate)

    def n_estimate_trusted(self, allowed_ids: set[str]) -> float:
        """Estimate network size using only trusted nodes.

        Takes the active members, intersects them with ``allowed_ids``, and
        estimates from their count and from their N_local reports (when
        available).
        """
        self.report_local_population()
        active_trusted = {
            m["node_id"]
            for m in self.active_members(include_islands=True)
            if m.get("node_id") in allowed_ids
        }
        # For trusted nodes, rely on the actual number of active trusted members.
        local_estimate = float(len(active_trusted))
        return self._max_reachable_report(active_trusted, local_estimate)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the state to a plain dict (inverse of :meth:`from_dict`)."""
        return {
            "members": self.members.to_dict(),
            "receipts": self.receipts.to_dict(),
            "hll": self.hll.to_dict(),
            "reports": dict(self.n_reports),
            "logical_counter": self.logical_counter,
        }

    @classmethod
    def from_dict(cls, node_id: str, signer: Signer, data: Dict[str, Any]) -> "MembershipState":
        """Rebuild a state from :meth:`to_dict` output; falsy ``data`` yields an empty state."""
        inst = cls(node_id=node_id, signer=signer)
        if data:
            inst.members = LWWSet.from_dict(data.get("members") or {})
            inst.receipts = LWWSet.from_dict(data.get("receipts") or {})
            inst.hll = HyperLogLog.from_dict(data.get("hll") or {})
            inst.n_reports = {str(k): float(v) for k, v in (data.get("reports") or {}).items()}
            inst.logical_counter = int(data.get("logical_counter") or 0)
        return inst
|
||
|
||
|
||
class MembershipManager:
    """Owns this node's ``MembershipState`` and persists it to a ``DHTStore``.

    Every local mutator (``register_local``, ``update_member``,
    ``remove_member``, ``record_receipt``) writes the full state back to the
    store via :meth:`_persist`.
    """

    def __init__(self, node_id: str, signer: Signer, store: DHTStore):
        self.node_id = node_id
        self.signer = signer
        self.store = store
        self.state = MembershipState(node_id=node_id, signer=signer)

    def _merge_remote(self, data: Dict[str, Any]) -> None:
        """Deserialize ``data`` as a ``MembershipState`` and merge it into ``self.state``."""
        remote_state = MembershipState.from_dict(self.node_id, self.signer, data)
        self.state.merge(remote_state)

    def ingest_snapshot(self, payload: Dict[str, Any]) -> None:
        """Merge a serialized remote membership snapshot into the local state.

        NOTE(review): unlike the mutators below, this does not call
        ``_persist``; confirm callers persist afterwards if that is required.
        """
        self._merge_remote(payload)

    def register_local(self, public_key: str, ip: str | None, asn: Optional[int], metadata: Dict[str, Any] | None = None) -> None:
        """Register (or refresh) this node in the membership set and persist."""
        self.state.register_member(self.node_id, public_key=public_key, ip=ip, asn=asn, metadata=metadata)
        self._persist()

    def update_member(self, node_id: str, **kwargs) -> None:
        """Rewrite the entry for ``node_id`` from ``kwargs`` and persist.

        Recognized keys: ``public_key`` (falls back to
        ``metadata["public_key"]``), ``ip``, ``asn``, ``metadata``. The
        payload is built only from these arguments, so any field not supplied
        is stored as ``None`` rather than kept from the previous entry.
        """
        meta = kwargs.get("metadata") or {}
        self.state.register_member(
            node_id,
            public_key=kwargs.get("public_key", meta.get("public_key")),
            ip=kwargs.get("ip"),
            asn=kwargs.get("asn"),
            metadata=meta,
        )
        self._persist()

    def remove_member(self, node_id: str) -> None:
        """Remove ``node_id`` from the membership set and persist."""
        self.state.forget_member(node_id)
        self._persist()

    def record_receipt(self, receipt: ReachabilityReceipt) -> None:
        """Store a reachability receipt and persist."""
        self.state.record_receipt(receipt)
        self._persist()

    def _merge_payloads(self, a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
        """Merge two serialized states (used as the store's merge strategy)."""
        left = MembershipState.from_dict(self.node_id, self.signer, a)
        right = MembershipState.from_dict(self.node_id, self.signer, b)
        return left.merge(right).to_dict()

    def _persist(self) -> None:
        """Write the full serialized state under this node's ``MembershipKey``."""
        key = MembershipKey(node_id=self.node_id)
        self.store.put(
            key=str(key),
            fingerprint=key.fingerprint(),
            value=self.state.to_dict(),
            logical_counter=self.state.logical_counter,
            merge_strategy=self._merge_payloads,
        )

    def n_estimate(self) -> float:
        """Delegate to :meth:`MembershipState.n_estimate`."""
        return self.state.n_estimate()

    def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]:
        """Return active members; pass ``include_islands=True`` to keep poorly reachable ones.

        The default (``False``) matches the previous behavior of this method.
        """
        return self.state.active_members(include_islands=include_islands)
|