mdns: support running on macos

This commit is contained in:
ed 2022-11-17 20:18:24 +00:00
parent 5abe0c955c
commit 4ad4657774
5 changed files with 43 additions and 23 deletions

View file

@ -213,7 +213,7 @@ def init_E(E: EnvParams) -> None:
def get_srvname() -> str: def get_srvname() -> str:
try: try:
ret: str = unicode(socket.gethostname()).split(".")[0].lower() ret: str = unicode(socket.gethostname()).split(".")[0]
except: except:
ret = "" ret = ""

View file

@ -165,6 +165,10 @@ class HttpSrv(object):
def listen(self, sck: socket.socket, nlisteners: int) -> None: def listen(self, sck: socket.socket, nlisteners: int) -> None:
if self.args.j != 1: if self.args.j != 1:
# lost in the pickle; redefine # lost in the pickle; redefine
try:
sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except:
pass
sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sck.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sck.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sck.settimeout(None) # < does not inherit, ^ does sck.settimeout(None) # < does not inherit, ^ does

View file

@ -68,15 +68,20 @@ class MDNS(MCast):
self.ttl = 300 self.ttl = 300
self.running = True self.running = True
zs = self.args.name.lower() + ".local." zs = self.args.name + ".local."
zs = zs.encode("ascii", "replace").decode("ascii", "replace") zs = zs.encode("ascii", "replace").decode("ascii", "replace")
self.hn = zs.replace("?", "_") self.hn = "-".join(x for x in zs.split("?") if x) or (
"vault-{}".format(random.randint(1, 255))
)
self.lhn = self.hn.lower()
# requester ip -> (response deadline, srv, body): # requester ip -> (response deadline, srv, body):
self.q: dict[str, tuple[float, MDNS_Sck, bytes]] = {} self.q: dict[str, tuple[float, MDNS_Sck, bytes]] = {}
self.rx4 = CachedSet(0.42) # 3 probes @ 250..500..750 => 500ms span self.rx4 = CachedSet(0.42) # 3 probes @ 250..500..750 => 500ms span
self.rx6 = CachedSet(0.42) self.rx6 = CachedSet(0.42)
self.svcs, self.sfqdns = self.build_svcs() self.svcs, self.sfqdns = self.build_svcs()
self.lsvcs = {k.lower(): v for k, v in self.svcs.items()}
self.lsfqdns = set([x.lower() for x in self.sfqdns])
self.probing = 0.0 self.probing = 0.0
self.unsolicited: list[float] = [] # scheduled announces on all nics self.unsolicited: list[float] = [] # scheduled announces on all nics
@ -211,8 +216,7 @@ class MDNS(MCast):
sreply.add_answer(r) sreply.add_answer(r)
if not (have4 and have6) and not self.args.zm_noneg: if not (have4 and have6) and not self.args.zm_noneg:
have = "AAAA" if have6 else "A" ns = NSEC(self.hn, ["AAAA" if have6 else "A"])
ns = NSEC(self.hn, [have, "PTR", "SRV", "TXT"])
r = RR(self.hn, QTYPE.NSEC, DC.F_IN, 120, ns) r = RR(self.hn, QTYPE.NSEC, DC.F_IN, 120, ns)
areply.add_ar(r) areply.add_ar(r)
if len(sreply.pack()) < 1400: if len(sreply.pack()) < 1400:
@ -294,7 +298,8 @@ class MDNS(MCast):
continue continue
if self.probing < time.time(): if self.probing < time.time():
self.log("probe ok; starting announcements", 2) t = "probe ok; announcing [{}]"
self.log(t.format(self.hn[:-1]), 2)
self.probing = 0 self.probing = 0
def stop(self, panic=False) -> None: def stop(self, panic=False) -> None:
@ -331,7 +336,7 @@ class MDNS(MCast):
self.log(str(p)) self.log(str(p))
# check for incoming probes for our hostname # check for incoming probes for our hostname
cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.hn] cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.lhn]
if cips and self.sips.isdisjoint(cips): if cips and self.sips.isdisjoint(cips):
if not [x for x in cips if x not in ("::1", "127.0.0.1")]: if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
# avahi broadcasting 127.0.0.1-only packets # avahi broadcasting 127.0.0.1-only packets
@ -350,7 +355,7 @@ class MDNS(MCast):
cips = [ cips = [
U(x.rdata) U(x.rdata)
for x in p.rr for x in p.rr
if U(x.rname).lower() == self.hn and x.rclass == DC.F_IN if U(x.rname).lower() == self.lhn and x.rclass == DC.F_IN
] ]
if cips and self.sips.isdisjoint(cips): if cips and self.sips.isdisjoint(cips):
if not [x for x in cips if x not in ("::1", "127.0.0.1")]: if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
@ -370,7 +375,7 @@ class MDNS(MCast):
# then a/aaaa records # then a/aaaa records
for r in p.questions: for r in p.questions:
if U(r.qname).lower() != self.hn: if U(r.qname).lower() != self.lhn:
continue continue
# gvfs keeps repeating itself # gvfs keeps repeating itself
@ -378,7 +383,7 @@ class MDNS(MCast):
unicast = False unicast = False
for r in p.rr: for r in p.rr:
rname = U(r.rname).lower() rname = U(r.rname).lower()
if rname == self.hn: if rname == self.lhn:
if r.ttl > 60: if r.ttl > 60:
found = True found = True
if r.rclass == DC.F_IN: if r.rclass == DC.F_IN:
@ -396,7 +401,7 @@ class MDNS(MCast):
# and service queries # and service queries
for r in p.questions: for r in p.questions:
qname = U(r.qname).lower() qname = U(r.qname).lower()
if qname in self.svcs or qname == "_services._dns-sd._udp.local.": if qname in self.lsvcs or qname == "_services._dns-sd._udp.local.":
self.q[cip] = (deadline, srv, srv.bp_svc) self.q[cip] = (deadline, srv, srv.bp_svc)
break break
# heed rfc-7.1 if there was an announce in the past 12sec # heed rfc-7.1 if there was an announce in the past 12sec
@ -405,7 +410,7 @@ class MDNS(MCast):
if now < srv.last_tx + 12: if now < srv.last_tx + 12:
for r in p.rr: for r in p.rr:
rdata = U(r.rdata).lower() rdata = U(r.rdata).lower()
if rdata in self.sfqdns: if rdata in self.lsfqdns:
if r.ttl > 2250: if r.ttl > 2250:
self.q.pop(cip, None) self.q.pop(cip, None)
break break

View file

@ -7,7 +7,7 @@ import time
import ipaddress import ipaddress
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
from .__init__ import TYPE_CHECKING from .__init__ import TYPE_CHECKING, MACOS
from .util import min_ex, spack from .util import min_ex, spack
if TYPE_CHECKING: if TYPE_CHECKING:
@ -72,7 +72,6 @@ class MCast(object):
def create_servers(self) -> list[str]: def create_servers(self) -> list[str]:
bound: list[str] = [] bound: list[str] = []
ips = [x[0] for x in self.hub.tcpsrv.bound] ips = [x[0] for x in self.hub.tcpsrv.bound]
ips = list(set(ips))
if "::" in ips: if "::" in ips:
ips = [x for x in ips if x != "::"] + list( ips = [x for x in ips if x != "::"] + list(
@ -99,8 +98,10 @@ class MCast(object):
if not self.grp6: if not self.grp6:
ips = [x for x in ips if ":" not in x] ips = [x for x in ips if ":" not in x]
# discard non-linklocal ipv6 ips = list(set(ips))
all_selected = ips[:] all_selected = ips[:]
# discard non-linklocal ipv6
ips = [x for x in ips if ":" not in x or x.startswith("fe80")] ips = [x for x in ips if ":" not in x or x.startswith("fe80")]
if not ips: if not ips:
@ -196,19 +197,22 @@ class MCast(object):
self.b2srv[bip] = srv self.b2srv[bip] = srv
self.b6.append(bip) self.b6.append(bip)
sck.bind((self.grp6 if srv.idx else "", self.port, 0, srv.idx)) grp = self.grp6 if srv.idx and not MACOS else ""
sck.bind((grp, self.port, 0, srv.idx))
bgrp = socket.inet_pton(socket.AF_INET6, self.grp6) bgrp = socket.inet_pton(socket.AF_INET6, self.grp6)
dev = spack(b"@I", srv.idx) dev = spack(b"@I", srv.idx)
srv.mreq = bgrp + dev srv.mreq = bgrp + dev
if srv.idx != socket.INADDR_ANY: if srv.idx != socket.INADDR_ANY:
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev) sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev)
self.hop(srv)
try: try:
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1)
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1)
except: except:
pass # macos # macos
t = "failed to set IPv6 TTL/LOOP; announcements may not survive multiple switches/routers"
self.log(t, 3)
else: else:
if self.args.zmv: if self.args.zmv:
self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6) self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6)
@ -217,7 +221,8 @@ class MCast(object):
self.b2srv[bip] = srv self.b2srv[bip] = srv
self.b4.append(bip) self.b4.append(bip)
sck.bind((self.grp4 if srv.idx else "", self.port)) grp = self.grp4 if srv.idx and not MACOS else ""
sck.bind((grp, self.port))
bgrp = socket.inet_aton(self.grp4) bgrp = socket.inet_aton(self.grp4)
dev = ( dev = (
spack(b"=I", socket.INADDR_ANY) spack(b"=I", socket.INADDR_ANY)
@ -228,13 +233,15 @@ class MCast(object):
if srv.idx != socket.INADDR_ANY: if srv.idx != socket.INADDR_ANY:
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev) sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev)
self.hop(srv)
try: try:
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255) sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255)
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
except: except:
pass # probably can't happen but dontcare if it does
t = "failed to set IPv4 TTL/LOOP; announcements may not survive multiple switches/routers"
self.log(t, 3)
self.hop(srv)
self.b4.sort(reverse=True) self.b4.sort(reverse=True)
self.b6.sort(reverse=True) self.b6.sort(reverse=True)

View file

@ -192,6 +192,10 @@ class TcpSrv(object):
def _listen(self, ip: str, port: int) -> None: def _listen(self, ip: str, port: int) -> None:
ipv = socket.AF_INET6 if ":" in ip else socket.AF_INET ipv = socket.AF_INET6 if ":" in ip else socket.AF_INET
srv = socket.socket(ipv, socket.SOCK_STREAM) srv = socket.socket(ipv, socket.SOCK_STREAM)
try:
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except:
pass
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
srv.settimeout(None) # < does not inherit, ^ does srv.settimeout(None) # < does not inherit, ^ does