mirror of
https://github.com/9001/copyparty.git
synced 2025-08-18 01:22:13 -06:00
zeroconf: detect network change and reannounce
This commit is contained in:
parent
504e168486
commit
577d23f460
|
@ -692,6 +692,7 @@ def add_zeroconf(ap):
|
|||
ap2.add_argument("-z", action="store_true", help="enable all zeroconf backends (mdns, ssdp)")
|
||||
ap2.add_argument("--z-on", metavar="NETS", type=u, default="", help="enable zeroconf ONLY on the comma-separated list of subnets and/or interface names/indexes\n └─example: \033[32meth0, wlo1, virhost0, 192.168.123.0/24, fd00:fda::/96\033[0m")
|
||||
ap2.add_argument("--z-off", metavar="NETS", type=u, default="", help="disable zeroconf on the comma-separated list of subnets and/or interface names/indexes")
|
||||
ap2.add_argument("--z-chk", metavar="SEC", type=int, default=10, help="check for network changes every SEC seconds (0=disable)")
|
||||
ap2.add_argument("-zv", action="store_true", help="verbose all zeroconf backends")
|
||||
ap2.add_argument("--mc-hop", metavar="SEC", type=int, default=0, help="rejoin multicast groups every SEC seconds (workaround for some switches/routers which cause mDNS to suddenly stop working after some time); try [\033[32m300\033[0m] or [\033[32m180\033[0m]")
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ class MDNS_Sck(MC_Sck):
|
|||
|
||||
|
||||
class MDNS(MCast):
|
||||
def __init__(self, hub: "SvcHub") -> None:
|
||||
def __init__(self, hub: "SvcHub", ngen: int) -> None:
|
||||
al = hub.args
|
||||
grp4 = "" if al.zm6 else MDNS4
|
||||
grp6 = "" if al.zm4 else MDNS6
|
||||
|
@ -67,7 +67,8 @@ class MDNS(MCast):
|
|||
hub, MDNS_Sck, al.zm_on, al.zm_off, grp4, grp6, 5353, hub.args.zmv
|
||||
)
|
||||
self.srv: dict[socket.socket, MDNS_Sck] = {}
|
||||
|
||||
self.logsrc = "mDNS-{}".format(ngen)
|
||||
self.ngen = ngen
|
||||
self.ttl = 300
|
||||
|
||||
zs = self.args.name + ".local."
|
||||
|
@ -90,7 +91,7 @@ class MDNS(MCast):
|
|||
self.defend: dict[MDNS_Sck, float] = {} # server -> deadline
|
||||
|
||||
def log(self, msg: str, c: Union[int, str] = 0) -> None:
|
||||
self.log_func("mDNS", msg, c)
|
||||
self.log_func(self.logsrc, msg, c)
|
||||
|
||||
def build_svcs(self) -> tuple[dict[str, dict[str, Any]], set[str]]:
|
||||
zms = self.args.zms
|
||||
|
@ -288,12 +289,15 @@ class MDNS(MCast):
|
|||
rx: list[socket.socket] = rdy[0] # type: ignore
|
||||
self.rx4.cln()
|
||||
self.rx6.cln()
|
||||
buf = b""
|
||||
addr = ("0", 0)
|
||||
for sck in rx:
|
||||
buf, addr = sck.recvfrom(4096)
|
||||
try:
|
||||
buf, addr = sck.recvfrom(4096)
|
||||
self.eat(buf, addr, sck)
|
||||
except:
|
||||
if not self.running:
|
||||
self.log("stopped", 2)
|
||||
return
|
||||
|
||||
t = "{} {} \033[33m|{}| {}\n{}".format(
|
||||
|
@ -310,14 +314,18 @@ class MDNS(MCast):
|
|||
self.log(t.format(self.hn[:-1]), 2)
|
||||
self.probing = 0
|
||||
|
||||
self.log("stopped", 2)
|
||||
|
||||
def stop(self, panic=False) -> None:
|
||||
self.running = False
|
||||
if not panic:
|
||||
for srv in self.srv.values():
|
||||
try:
|
||||
for srv in self.srv.values():
|
||||
try:
|
||||
if panic:
|
||||
srv.sck.close()
|
||||
else:
|
||||
srv.sck.sendto(srv.bp_bye, (srv.grp, 5353))
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
|
||||
self.srv = {}
|
||||
|
||||
|
|
|
@ -89,19 +89,22 @@ class SSDPr(object):
|
|||
class SSDPd(MCast):
|
||||
"""communicates with ssdp clients over multicast"""
|
||||
|
||||
def __init__(self, hub: "SvcHub") -> None:
|
||||
def __init__(self, hub: "SvcHub", ngen: int) -> None:
|
||||
al = hub.args
|
||||
vinit = al.zsv and not al.zmv
|
||||
super(SSDPd, self).__init__(
|
||||
hub, SSDP_Sck, al.zs_on, al.zs_off, GRP, "", 1900, vinit
|
||||
)
|
||||
self.srv: dict[socket.socket, SSDP_Sck] = {}
|
||||
self.logsrc = "SSDP-{}".format(ngen)
|
||||
self.ngen = ngen
|
||||
|
||||
self.rxc = CachedSet(0.7)
|
||||
self.txc = CachedSet(5) # win10: every 3 sec
|
||||
self.ptn_st = re.compile(b"\nst: *upnp:rootdevice", re.I)
|
||||
|
||||
def log(self, msg: str, c: Union[int, str] = 0) -> None:
|
||||
self.log_func("SSDP", msg, c)
|
||||
self.log_func(self.logsrc, msg, c)
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
|
@ -127,24 +130,34 @@ class SSDPd(MCast):
|
|||
|
||||
self.log("listening")
|
||||
while self.running:
|
||||
rdy = select.select(self.srv, [], [], 180)
|
||||
rdy = select.select(self.srv, [], [], self.args.z_chk or 180)
|
||||
rx: list[socket.socket] = rdy[0] # type: ignore
|
||||
self.rxc.cln()
|
||||
buf = b""
|
||||
addr = ("0", 0)
|
||||
for sck in rx:
|
||||
buf, addr = sck.recvfrom(4096)
|
||||
try:
|
||||
buf, addr = sck.recvfrom(4096)
|
||||
self.eat(buf, addr)
|
||||
except:
|
||||
if not self.running:
|
||||
return
|
||||
break
|
||||
|
||||
t = "{} {} \033[33m|{}| {}\n{}".format(
|
||||
self.srv[sck].name, addr, len(buf), repr(buf)[2:-1], min_ex()
|
||||
)
|
||||
self.log(t, 6)
|
||||
|
||||
self.log("stopped", 2)
|
||||
|
||||
def stop(self) -> None:
|
||||
self.running = False
|
||||
for srv in self.srv.values():
|
||||
try:
|
||||
srv.sck.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
self.srv = {}
|
||||
|
||||
def eat(self, buf: bytes, addr: tuple[str, int]) -> None:
|
||||
|
|
|
@ -237,6 +237,7 @@ class SvcHub(object):
|
|||
if not args.zms:
|
||||
args.zms = zms
|
||||
|
||||
self.zc_ngen = 0
|
||||
self.mdns: Optional["MDNS"] = None
|
||||
self.ssdp: Optional["SSDPd"] = None
|
||||
|
||||
|
@ -404,24 +405,10 @@ class SvcHub(object):
|
|||
|
||||
def run(self) -> None:
|
||||
self.tcpsrv.run()
|
||||
|
||||
if getattr(self.args, "zm", False):
|
||||
try:
|
||||
from .mdns import MDNS
|
||||
|
||||
self.mdns = MDNS(self)
|
||||
Daemon(self.mdns.run, "mdns")
|
||||
except:
|
||||
self.log("root", "mdns startup failed;\n" + min_ex(), 3)
|
||||
|
||||
if getattr(self.args, "zs", False):
|
||||
try:
|
||||
from .ssdp import SSDPd
|
||||
|
||||
self.ssdp = SSDPd(self)
|
||||
Daemon(self.ssdp.run, "ssdp")
|
||||
except:
|
||||
self.log("root", "ssdp startup failed;\n" + min_ex(), 3)
|
||||
if getattr(self.args, "z_chk", 0) and (
|
||||
getattr(self.args, "zm", False) or getattr(self.args, "zs", False)
|
||||
):
|
||||
Daemon(self.tcpsrv.netmon, "netmon")
|
||||
|
||||
Daemon(self.thr_httpsrv_up, "sig-hsrv-up2")
|
||||
|
||||
|
@ -453,6 +440,33 @@ class SvcHub(object):
|
|||
else:
|
||||
self.stop_thr()
|
||||
|
||||
def start_zeroconf(self) -> None:
|
||||
self.zc_ngen += 1
|
||||
|
||||
if getattr(self.args, "zm", False):
|
||||
try:
|
||||
from .mdns import MDNS
|
||||
|
||||
if self.mdns:
|
||||
self.mdns.stop(True)
|
||||
|
||||
self.mdns = MDNS(self, self.zc_ngen)
|
||||
Daemon(self.mdns.run, "mdns")
|
||||
except:
|
||||
self.log("root", "mdns startup failed;\n" + min_ex(), 3)
|
||||
|
||||
if getattr(self.args, "zs", False):
|
||||
try:
|
||||
from .ssdp import SSDPd
|
||||
|
||||
if self.ssdp:
|
||||
self.ssdp.stop()
|
||||
|
||||
self.ssdp = SSDPd(self, self.zc_ngen)
|
||||
Daemon(self.ssdp.run, "ssdp")
|
||||
except:
|
||||
self.log("root", "ssdp startup failed;\n" + min_ex(), 3)
|
||||
|
||||
def reload(self) -> str:
|
||||
if self.reloading:
|
||||
return "cannot reload; already in progress"
|
||||
|
|
|
@ -5,6 +5,7 @@ import os
|
|||
import re
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
from .__init__ import ANYWIN, PY2, TYPE_CHECKING, VT100, unicode
|
||||
from .stolen.qrcodegen import QrCode
|
||||
|
@ -46,6 +47,8 @@ class TcpSrv(object):
|
|||
self.stopping = False
|
||||
self.srv: list[socket.socket] = []
|
||||
self.bound: list[tuple[str, int]] = []
|
||||
self.netdevs: dict[str, Netdev] = {}
|
||||
self.netlist = ""
|
||||
self.nsrv = 0
|
||||
self.qr = ""
|
||||
pad = False
|
||||
|
@ -268,7 +271,11 @@ class TcpSrv(object):
|
|||
self.srv = srvs
|
||||
self.bound = bound
|
||||
self.nsrv = len(srvs)
|
||||
self._distribute_netdevs()
|
||||
|
||||
def _distribute_netdevs(self):
|
||||
self.hub.broker.say("set_netdevs", self.netdevs)
|
||||
self.hub.start_zeroconf()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.stopping = True
|
||||
|
@ -280,6 +287,17 @@ class TcpSrv(object):
|
|||
|
||||
self.log("tcpsrv", "ok bye")
|
||||
|
||||
def netmon(self):
|
||||
while not self.stopping:
|
||||
time.sleep(self.args.z_chk)
|
||||
netdevs = self.detect_interfaces(self.args.i)
|
||||
if not netdevs:
|
||||
continue
|
||||
|
||||
self.log("tcpsrv", "network change detected", 3)
|
||||
self.netdevs = netdevs
|
||||
self._distribute_netdevs()
|
||||
|
||||
def detect_interfaces(self, listen_ips: list[str]) -> dict[str, Netdev]:
|
||||
from .stolen.ifaddr import get_adapters
|
||||
|
||||
|
@ -300,6 +318,12 @@ class TcpSrv(object):
|
|||
except:
|
||||
pass
|
||||
|
||||
netlist = str(sorted(eps.items()))
|
||||
if netlist == self.netlist and self.netdevs:
|
||||
return {}
|
||||
|
||||
self.netlist = netlist
|
||||
|
||||
if "0.0.0.0" not in listen_ips and "::" not in listen_ips:
|
||||
eps = {k: v for k, v in eps.items() if k.split("/")[0] in listen_ips}
|
||||
|
||||
|
|
|
@ -379,6 +379,9 @@ class Netdev(object):
|
|||
def __str__(self):
|
||||
return "{}-{}{}".format(self.idx, self.name, self.desc)
|
||||
|
||||
def __repr__(self):
|
||||
return "'{}-{}'".format(self.idx, self.name)
|
||||
|
||||
def __lt__(self, rhs):
|
||||
return str(self) < str(rhs)
|
||||
|
||||
|
|
Loading…
Reference in a new issue