zeroconf: detect network change and reannounce

This commit is contained in:
ed 2023-01-18 21:27:27 +00:00
parent 504e168486
commit 577d23f460
6 changed files with 95 additions and 32 deletions

View file

@ -692,6 +692,7 @@ def add_zeroconf(ap):
ap2.add_argument("-z", action="store_true", help="enable all zeroconf backends (mdns, ssdp)")
ap2.add_argument("--z-on", metavar="NETS", type=u, default="", help="enable zeroconf ONLY on the comma-separated list of subnets and/or interface names/indexes\n └─example: \033[32meth0, wlo1, virhost0, 192.168.123.0/24, fd00:fda::/96\033[0m")
ap2.add_argument("--z-off", metavar="NETS", type=u, default="", help="disable zeroconf on the comma-separated list of subnets and/or interface names/indexes")
ap2.add_argument("--z-chk", metavar="SEC", type=int, default=10, help="check for network changes every SEC seconds (0=disable)")
ap2.add_argument("-zv", action="store_true", help="verbose all zeroconf backends")
ap2.add_argument("--mc-hop", metavar="SEC", type=int, default=0, help="rejoin multicast groups every SEC seconds (workaround for some switches/routers which cause mDNS to suddenly stop working after some time); try [\033[32m300\033[0m] or [\033[32m180\033[0m]")

View file

@ -59,7 +59,7 @@ class MDNS_Sck(MC_Sck):
class MDNS(MCast):
def __init__(self, hub: "SvcHub") -> None:
def __init__(self, hub: "SvcHub", ngen: int) -> None:
al = hub.args
grp4 = "" if al.zm6 else MDNS4
grp6 = "" if al.zm4 else MDNS6
@ -67,7 +67,8 @@ class MDNS(MCast):
hub, MDNS_Sck, al.zm_on, al.zm_off, grp4, grp6, 5353, hub.args.zmv
)
self.srv: dict[socket.socket, MDNS_Sck] = {}
self.logsrc = "mDNS-{}".format(ngen)
self.ngen = ngen
self.ttl = 300
zs = self.args.name + ".local."
@ -90,7 +91,7 @@ class MDNS(MCast):
self.defend: dict[MDNS_Sck, float] = {} # server -> deadline
def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("mDNS", msg, c)
self.log_func(self.logsrc, msg, c)
def build_svcs(self) -> tuple[dict[str, dict[str, Any]], set[str]]:
zms = self.args.zms
@ -288,12 +289,15 @@ class MDNS(MCast):
rx: list[socket.socket] = rdy[0] # type: ignore
self.rx4.cln()
self.rx6.cln()
buf = b""
addr = ("0", 0)
for sck in rx:
buf, addr = sck.recvfrom(4096)
try:
buf, addr = sck.recvfrom(4096)
self.eat(buf, addr, sck)
except:
if not self.running:
self.log("stopped", 2)
return
t = "{} {} \033[33m|{}| {}\n{}".format(
@ -310,11 +314,15 @@ class MDNS(MCast):
self.log(t.format(self.hn[:-1]), 2)
self.probing = 0
self.log("stopped", 2)
def stop(self, panic=False) -> None:
self.running = False
if not panic:
for srv in self.srv.values():
try:
if panic:
srv.sck.close()
else:
srv.sck.sendto(srv.bp_bye, (srv.grp, 5353))
except:
pass

View file

@ -89,19 +89,22 @@ class SSDPr(object):
class SSDPd(MCast):
"""communicates with ssdp clients over multicast"""
def __init__(self, hub: "SvcHub") -> None:
def __init__(self, hub: "SvcHub", ngen: int) -> None:
al = hub.args
vinit = al.zsv and not al.zmv
super(SSDPd, self).__init__(
hub, SSDP_Sck, al.zs_on, al.zs_off, GRP, "", 1900, vinit
)
self.srv: dict[socket.socket, SSDP_Sck] = {}
self.logsrc = "SSDP-{}".format(ngen)
self.ngen = ngen
self.rxc = CachedSet(0.7)
self.txc = CachedSet(5) # win10: every 3 sec
self.ptn_st = re.compile(b"\nst: *upnp:rootdevice", re.I)
def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("SSDP", msg, c)
self.log_func(self.logsrc, msg, c)
def run(self) -> None:
try:
@ -127,24 +130,34 @@ class SSDPd(MCast):
self.log("listening")
while self.running:
rdy = select.select(self.srv, [], [], 180)
rdy = select.select(self.srv, [], [], self.args.z_chk or 180)
rx: list[socket.socket] = rdy[0] # type: ignore
self.rxc.cln()
buf = b""
addr = ("0", 0)
for sck in rx:
buf, addr = sck.recvfrom(4096)
try:
buf, addr = sck.recvfrom(4096)
self.eat(buf, addr)
except:
if not self.running:
return
break
t = "{} {} \033[33m|{}| {}\n{}".format(
self.srv[sck].name, addr, len(buf), repr(buf)[2:-1], min_ex()
)
self.log(t, 6)
self.log("stopped", 2)
def stop(self) -> None:
self.running = False
for srv in self.srv.values():
try:
srv.sck.close()
except:
pass
self.srv = {}
def eat(self, buf: bytes, addr: tuple[str, int]) -> None:

View file

@ -237,6 +237,7 @@ class SvcHub(object):
if not args.zms:
args.zms = zms
self.zc_ngen = 0
self.mdns: Optional["MDNS"] = None
self.ssdp: Optional["SSDPd"] = None
@ -404,24 +405,10 @@ class SvcHub(object):
def run(self) -> None:
self.tcpsrv.run()
if getattr(self.args, "zm", False):
try:
from .mdns import MDNS
self.mdns = MDNS(self)
Daemon(self.mdns.run, "mdns")
except:
self.log("root", "mdns startup failed;\n" + min_ex(), 3)
if getattr(self.args, "zs", False):
try:
from .ssdp import SSDPd
self.ssdp = SSDPd(self)
Daemon(self.ssdp.run, "ssdp")
except:
self.log("root", "ssdp startup failed;\n" + min_ex(), 3)
if getattr(self.args, "z_chk", 0) and (
getattr(self.args, "zm", False) or getattr(self.args, "zs", False)
):
Daemon(self.tcpsrv.netmon, "netmon")
Daemon(self.thr_httpsrv_up, "sig-hsrv-up2")
@ -453,6 +440,33 @@ class SvcHub(object):
else:
self.stop_thr()
def start_zeroconf(self) -> None:
self.zc_ngen += 1
if getattr(self.args, "zm", False):
try:
from .mdns import MDNS
if self.mdns:
self.mdns.stop(True)
self.mdns = MDNS(self, self.zc_ngen)
Daemon(self.mdns.run, "mdns")
except:
self.log("root", "mdns startup failed;\n" + min_ex(), 3)
if getattr(self.args, "zs", False):
try:
from .ssdp import SSDPd
if self.ssdp:
self.ssdp.stop()
self.ssdp = SSDPd(self, self.zc_ngen)
Daemon(self.ssdp.run, "ssdp")
except:
self.log("root", "ssdp startup failed;\n" + min_ex(), 3)
def reload(self) -> str:
if self.reloading:
return "cannot reload; already in progress"

View file

@ -5,6 +5,7 @@ import os
import re
import socket
import sys
import time
from .__init__ import ANYWIN, PY2, TYPE_CHECKING, VT100, unicode
from .stolen.qrcodegen import QrCode
@ -46,6 +47,8 @@ class TcpSrv(object):
self.stopping = False
self.srv: list[socket.socket] = []
self.bound: list[tuple[str, int]] = []
self.netdevs: dict[str, Netdev] = {}
self.netlist = ""
self.nsrv = 0
self.qr = ""
pad = False
@ -268,7 +271,11 @@ class TcpSrv(object):
self.srv = srvs
self.bound = bound
self.nsrv = len(srvs)
self._distribute_netdevs()
def _distribute_netdevs(self):
self.hub.broker.say("set_netdevs", self.netdevs)
self.hub.start_zeroconf()
def shutdown(self) -> None:
self.stopping = True
@ -280,6 +287,17 @@ class TcpSrv(object):
self.log("tcpsrv", "ok bye")
def netmon(self):
while not self.stopping:
time.sleep(self.args.z_chk)
netdevs = self.detect_interfaces(self.args.i)
if not netdevs:
continue
self.log("tcpsrv", "network change detected", 3)
self.netdevs = netdevs
self._distribute_netdevs()
def detect_interfaces(self, listen_ips: list[str]) -> dict[str, Netdev]:
from .stolen.ifaddr import get_adapters
@ -300,6 +318,12 @@ class TcpSrv(object):
except:
pass
netlist = str(sorted(eps.items()))
if netlist == self.netlist and self.netdevs:
return {}
self.netlist = netlist
if "0.0.0.0" not in listen_ips and "::" not in listen_ips:
eps = {k: v for k, v in eps.items() if k.split("/")[0] in listen_ips}

View file

@ -379,6 +379,9 @@ class Netdev(object):
def __str__(self):
return "{}-{}{}".format(self.idx, self.name, self.desc)
def __repr__(self):
return "'{}-{}'".format(self.idx, self.name)
def __lt__(self, rhs):
return str(self) < str(rhs)