from __future__ import annotations import ipaddress import time from dataclasses import dataclass from typing import Dict, Any, Iterable, List, Optional, Tuple from app.core._crypto.signer import Signer from .config import dht_config from .crdt import LWWSet, HyperLogLog from .keys import MembershipKey from .store import DHTStore @dataclass class ReachabilityReceipt: target_id: str issuer_id: str asn: Optional[int] timestamp: float signature: str def as_dict(self) -> Dict[str, Any]: return { "target_id": self.target_id, "issuer_id": self.issuer_id, "asn": self.asn, "timestamp": self.timestamp, "signature": self.signature, } def _ip_first_octet(host: str | None) -> Optional[int]: if not host: return None try: ip = ipaddress.ip_address(host) return int(str(ip).split(".")[0]) except Exception: return None class MembershipState: def __init__(self, node_id: str, signer: Signer): self.node_id = node_id self.signer = signer self.members = LWWSet() self.receipts = LWWSet() self.hll = HyperLogLog() self.n_reports: Dict[str, float] = {} self.logical_counter = 0 def _bump_counter(self) -> int: self.logical_counter += 1 return self.logical_counter def register_member( self, node_id: str, public_key: str, ip: str | None, asn: Optional[int], metadata: Dict[str, Any] | None = None, timestamp: Optional[float] = None, ) -> None: payload = { "node_id": node_id, "public_key": public_key, "ip": ip, "asn": asn, "ip_first_octet": _ip_first_octet(ip), "meta": metadata or {}, "last_update": timestamp or time.time(), } self.members.add(node_id, payload, logical_counter=self._bump_counter(), node_id=self.node_id, timestamp=timestamp) self.hll.add(node_id) def forget_member(self, node_id: str) -> None: self.members.remove(node_id, logical_counter=self._bump_counter(), node_id=self.node_id) def record_receipt(self, receipt: ReachabilityReceipt) -> None: element_id = f"{receipt.target_id}:{receipt.issuer_id}" self.receipts.add( element_id, receipt.as_dict(), logical_counter=self._bump_counter(), node_id=self.node_id, timestamp=receipt.timestamp, ) def report_local_population(self) -> None: self.n_reports[self.node_id] = float(self.hll.estimate()) def merge(self, other: "MembershipState") -> "MembershipState": self.members.merge(other.members) self.receipts.merge(other.receipts) self.hll.merge(other.hll) for node_id, value in other.n_reports.items(): self.n_reports[node_id] = max(self.n_reports.get(node_id, 0.0), value) self.logical_counter = max(self.logical_counter, other.logical_counter) return self def _unique_asn_for(self, node_id: str) -> Tuple[int, Iterable[int]]: receipts = [ entry for rid, entry in self.receipts.elements().items() if entry.get("target_id") == node_id ] unique_asn = {entry.get("asn") for entry in receipts if entry.get("asn") is not None} return len(unique_asn), unique_asn def reachability_ratio(self, node_id: str) -> float: unique_count, _ = self._unique_asn_for(node_id) if dht_config.min_receipts <= 0: return 1.0 return min(1.0, unique_count / dht_config.min_receipts) def active_members(self, include_islands: bool = False) -> List[Dict[str, Any]]: now = time.time() result = [] for node_id, data in self.members.elements().items(): last_update = data.get("last_update") or 0 if now - last_update > dht_config.membership_ttl: continue reachability = self.reachability_ratio(node_id) if not include_islands and reachability < dht_config.default_q: continue enriched = dict(data) enriched["reachability_ratio"] = reachability result.append(enriched) return result def n_estimate(self) -> float: self.report_local_population() active_ids = {m["node_id"] for m in self.active_members(include_islands=True)} filtered_reports = [ value for node_id, value in self.n_reports.items() if node_id in active_ids and self.reachability_ratio(node_id) >= dht_config.default_q ] local_estimate = float(self.hll.estimate()) if filtered_reports: return max(max(filtered_reports), local_estimate) return local_estimate def n_estimate_trusted(self, allowed_ids: set[str]) -> float: """Оценка размера сети только по trusted узлам. Берём активных участников, пересекаем с allowed_ids и оцениваем по их числу и по их N_local репортам (если доступны). """ self.report_local_population() active_trusted = {m["node_id"] for m in self.active_members(include_islands=True) if m.get("node_id") in allowed_ids} filtered_reports = [ value for node_id, value in self.n_reports.items() if node_id in active_trusted and self.reachability_ratio(node_id) >= dht_config.default_q ] # Для доверенных полагаемся на фактическое количество активных Trusted local_estimate = float(len(active_trusted)) if filtered_reports: return max(max(filtered_reports), local_estimate) return local_estimate def to_dict(self) -> Dict[str, Any]: return { "members": self.members.to_dict(), "receipts": self.receipts.to_dict(), "hll": self.hll.to_dict(), "reports": dict(self.n_reports), "logical_counter": self.logical_counter, } @classmethod def from_dict(cls, node_id: str, signer: Signer, data: Dict[str, Any]) -> "MembershipState": inst = cls(node_id=node_id, signer=signer) if data: inst.members = LWWSet.from_dict(data.get("members") or {}) inst.receipts = LWWSet.from_dict(data.get("receipts") or {}) inst.hll = HyperLogLog.from_dict(data.get("hll") or {}) inst.n_reports = {str(k): float(v) for k, v in (data.get("reports") or {}).items()} inst.logical_counter = int(data.get("logical_counter") or 0) return inst class MembershipManager: def __init__(self, node_id: str, signer: Signer, store: DHTStore): self.node_id = node_id self.signer = signer self.store = store self.state = MembershipState(node_id=node_id, signer=signer) def _merge_remote(self, data: Dict[str, Any]) -> None: remote_state = MembershipState.from_dict(self.node_id, self.signer, data) self.state.merge(remote_state) def ingest_snapshot(self, payload: Dict[str, Any]) -> None: self._merge_remote(payload) def register_local(self, public_key: str, ip: str | None, asn: Optional[int], metadata: Dict[str, Any] | None = None) -> None: self.state.register_member(self.node_id, public_key=public_key, ip=ip, asn=asn, metadata=metadata) self._persist() def update_member(self, node_id: str, **kwargs) -> None: meta = kwargs.get("metadata") or {} self.state.register_member( node_id, public_key=kwargs.get("public_key", meta.get("public_key")), ip=kwargs.get("ip"), asn=kwargs.get("asn"), metadata=meta, ) self._persist() def remove_member(self, node_id: str) -> None: self.state.forget_member(node_id) self._persist() def record_receipt(self, receipt: ReachabilityReceipt) -> None: self.state.record_receipt(receipt) self._persist() def _persist(self) -> None: key = MembershipKey(node_id=self.node_id) self.store.put( key=str(key), fingerprint=key.fingerprint(), value=self.state.to_dict(), logical_counter=self.state.logical_counter, merge_strategy=lambda a, b: MembershipState.from_dict(self.node_id, self.signer, a) .merge(MembershipState.from_dict(self.node_id, self.signer, b)) .to_dict(), ) def n_estimate(self) -> float: return self.state.n_estimate() def active_members(self) -> List[Dict[str, Any]]: return self.state.active_members()