From 577d23f46057031f921bd641ed12a672e26deec8 Mon Sep 17 00:00:00 2001 From: ed Date: Wed, 18 Jan 2023 21:27:27 +0000 Subject: [PATCH] zeroconf: detect network change and reannounce --- copyparty/__main__.py | 1 + copyparty/mdns.py | 26 ++++++++++++++-------- copyparty/ssdp.py | 23 +++++++++++++++----- copyparty/svchub.py | 50 +++++++++++++++++++++++++++---------------- copyparty/tcpsrv.py | 24 +++++++++++++++++++++ copyparty/util.py | 3 +++ 6 files changed, 95 insertions(+), 32 deletions(-) diff --git a/copyparty/__main__.py b/copyparty/__main__.py index d11ab733..05e11463 100755 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -692,6 +692,7 @@ def add_zeroconf(ap): ap2.add_argument("-z", action="store_true", help="enable all zeroconf backends (mdns, ssdp)") ap2.add_argument("--z-on", metavar="NETS", type=u, default="", help="enable zeroconf ONLY on the comma-separated list of subnets and/or interface names/indexes\n └─example: \033[32meth0, wlo1, virhost0, 192.168.123.0/24, fd00:fda::/96\033[0m") ap2.add_argument("--z-off", metavar="NETS", type=u, default="", help="disable zeroconf on the comma-separated list of subnets and/or interface names/indexes") + ap2.add_argument("--z-chk", metavar="SEC", type=int, default=10, help="check for network changes every SEC seconds (0=disable)") ap2.add_argument("-zv", action="store_true", help="verbose all zeroconf backends") ap2.add_argument("--mc-hop", metavar="SEC", type=int, default=0, help="rejoin multicast groups every SEC seconds (workaround for some switches/routers which cause mDNS to suddenly stop working after some time); try [\033[32m300\033[0m] or [\033[32m180\033[0m]") diff --git a/copyparty/mdns.py b/copyparty/mdns.py index ebd2d80c..928d5fd9 100644 --- a/copyparty/mdns.py +++ b/copyparty/mdns.py @@ -59,7 +59,7 @@ class MDNS_Sck(MC_Sck): class MDNS(MCast): - def __init__(self, hub: "SvcHub") -> None: + def __init__(self, hub: "SvcHub", ngen: int) -> None: al = hub.args grp4 = "" if al.zm6 else MDNS4 grp6 = "" if al.zm4 else MDNS6 @@ -67,7 +67,8 @@ class MDNS(MCast): hub, MDNS_Sck, al.zm_on, al.zm_off, grp4, grp6, 5353, hub.args.zmv ) self.srv: dict[socket.socket, MDNS_Sck] = {} - + self.logsrc = "mDNS-{}".format(ngen) + self.ngen = ngen self.ttl = 300 zs = self.args.name + ".local." @@ -90,7 +91,7 @@ class MDNS(MCast): self.defend: dict[MDNS_Sck, float] = {} # server -> deadline def log(self, msg: str, c: Union[int, str] = 0) -> None: - self.log_func("mDNS", msg, c) + self.log_func(self.logsrc, msg, c) def build_svcs(self) -> tuple[dict[str, dict[str, Any]], set[str]]: zms = self.args.zms @@ -288,12 +289,15 @@ class MDNS(MCast): rx: list[socket.socket] = rdy[0] # type: ignore self.rx4.cln() self.rx6.cln() + buf = b"" + addr = ("0", 0) for sck in rx: - buf, addr = sck.recvfrom(4096) try: + buf, addr = sck.recvfrom(4096) self.eat(buf, addr, sck) except: if not self.running: + self.log("stopped", 2) return t = "{} {} \033[33m|{}| {}\n{}".format( @@ -310,14 +314,18 @@ class MDNS(MCast): self.log(t.format(self.hn[:-1]), 2) self.probing = 0 + self.log("stopped", 2) + def stop(self, panic=False) -> None: self.running = False - if not panic: - for srv in self.srv.values(): - try: + for srv in self.srv.values(): + try: + if panic: + srv.sck.close() + else: srv.sck.sendto(srv.bp_bye, (srv.grp, 5353)) - except: - pass + except: + pass self.srv = {} diff --git a/copyparty/ssdp.py b/copyparty/ssdp.py index d918b933..ef23f156 100644 --- a/copyparty/ssdp.py +++ b/copyparty/ssdp.py @@ -89,19 +89,22 @@ class SSDPr(object): class SSDPd(MCast): """communicates with ssdp clients over multicast""" - def __init__(self, hub: "SvcHub") -> None: + def __init__(self, hub: "SvcHub", ngen: int) -> None: al = hub.args vinit = al.zsv and not al.zmv super(SSDPd, self).__init__( hub, SSDP_Sck, al.zs_on, al.zs_off, GRP, "", 1900, vinit ) self.srv: dict[socket.socket, SSDP_Sck] = {} + self.logsrc = "SSDP-{}".format(ngen) + self.ngen = ngen + self.rxc = CachedSet(0.7) self.txc = CachedSet(5) # win10: every 3 sec self.ptn_st = re.compile(b"\nst: *upnp:rootdevice", re.I) def log(self, msg: str, c: Union[int, str] = 0) -> None: - self.log_func("SSDP", msg, c) + self.log_func(self.logsrc, msg, c) def run(self) -> None: try: @@ -127,24 +130,34 @@ class SSDPd(MCast): self.log("listening") while self.running: - rdy = select.select(self.srv, [], [], 180) + rdy = select.select(self.srv, [], [], self.args.z_chk or 180) rx: list[socket.socket] = rdy[0] # type: ignore self.rxc.cln() + buf = b"" + addr = ("0", 0) for sck in rx: - buf, addr = sck.recvfrom(4096) try: + buf, addr = sck.recvfrom(4096) self.eat(buf, addr) except: if not self.running: - return + break t = "{} {} \033[33m|{}| {}\n{}".format( self.srv[sck].name, addr, len(buf), repr(buf)[2:-1], min_ex() ) self.log(t, 6) + self.log("stopped", 2) + def stop(self) -> None: self.running = False + for srv in self.srv.values(): + try: + srv.sck.close() + except: + pass + self.srv = {} def eat(self, buf: bytes, addr: tuple[str, int]) -> None: diff --git a/copyparty/svchub.py b/copyparty/svchub.py index fe01b607..3029d446 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -237,6 +237,7 @@ class SvcHub(object): if not args.zms: args.zms = zms + self.zc_ngen = 0 self.mdns: Optional["MDNS"] = None self.ssdp: Optional["SSDPd"] = None @@ -404,24 +405,10 @@ class SvcHub(object): def run(self) -> None: self.tcpsrv.run() - - if getattr(self.args, "zm", False): - try: - from .mdns import MDNS - - self.mdns = MDNS(self) - Daemon(self.mdns.run, "mdns") - except: - self.log("root", "mdns startup failed;\n" + min_ex(), 3) - - if getattr(self.args, "zs", False): - try: - from .ssdp import SSDPd - - self.ssdp = SSDPd(self) - Daemon(self.ssdp.run, "ssdp") - except: - self.log("root", "ssdp startup failed;\n" + min_ex(), 3) + if getattr(self.args, "z_chk", 0) and ( + getattr(self.args, "zm", False) or getattr(self.args, "zs", False) + ): + Daemon(self.tcpsrv.netmon, "netmon") Daemon(self.thr_httpsrv_up, "sig-hsrv-up2") @@ -453,6 +440,33 @@ class SvcHub(object): else: self.stop_thr() + def start_zeroconf(self) -> None: + self.zc_ngen += 1 + + if getattr(self.args, "zm", False): + try: + from .mdns import MDNS + + if self.mdns: + self.mdns.stop(True) + + self.mdns = MDNS(self, self.zc_ngen) + Daemon(self.mdns.run, "mdns") + except: + self.log("root", "mdns startup failed;\n" + min_ex(), 3) + + if getattr(self.args, "zs", False): + try: + from .ssdp import SSDPd + + if self.ssdp: + self.ssdp.stop() + + self.ssdp = SSDPd(self, self.zc_ngen) + Daemon(self.ssdp.run, "ssdp") + except: + self.log("root", "ssdp startup failed;\n" + min_ex(), 3) + def reload(self) -> str: if self.reloading: return "cannot reload; already in progress" diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py index 122ee6a8..5ad25ed7 100644 --- a/copyparty/tcpsrv.py +++ b/copyparty/tcpsrv.py @@ -5,6 +5,7 @@ import os import re import socket import sys +import time from .__init__ import ANYWIN, PY2, TYPE_CHECKING, VT100, unicode from .stolen.qrcodegen import QrCode @@ -46,6 +47,8 @@ class TcpSrv(object): self.stopping = False self.srv: list[socket.socket] = [] self.bound: list[tuple[str, int]] = [] + self.netdevs: dict[str, Netdev] = {} + self.netlist = "" self.nsrv = 0 self.qr = "" pad = False @@ -268,7 +271,11 @@ class TcpSrv(object): self.srv = srvs self.bound = bound self.nsrv = len(srvs) + self._distribute_netdevs() + + def _distribute_netdevs(self): self.hub.broker.say("set_netdevs", self.netdevs) + self.hub.start_zeroconf() def shutdown(self) -> None: self.stopping = True @@ -280,6 +287,17 @@ class TcpSrv(object): self.log("tcpsrv", "ok bye") + def netmon(self): + while not self.stopping: + time.sleep(self.args.z_chk) + netdevs = self.detect_interfaces(self.args.i) + if not netdevs: + continue + + self.log("tcpsrv", "network change detected", 3) + self.netdevs = netdevs + self._distribute_netdevs() + def detect_interfaces(self, listen_ips: list[str]) -> dict[str, Netdev]: from .stolen.ifaddr import get_adapters @@ -300,6 +318,12 @@ class TcpSrv(object): except: pass + netlist = str(sorted(eps.items())) + if netlist == self.netlist and self.netdevs: + return {} + + self.netlist = netlist + if "0.0.0.0" not in listen_ips and "::" not in listen_ips: eps = {k: v for k, v in eps.items() if k.split("/")[0] in listen_ips} diff --git a/copyparty/util.py b/copyparty/util.py index c386c03d..c816d9be 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -379,6 +379,9 @@ class Netdev(object): def __str__(self): return "{}-{}{}".format(self.idx, self.name, self.desc) + def __repr__(self): + return "'{}-{}'".format(self.idx, self.name) + def __lt__(self, rhs): return str(self) < str(rhs)