mdns: support running on macos

This commit is contained in:
ed 2022-11-17 20:18:24 +00:00
parent 5abe0c955c
commit 4ad4657774
5 changed files with 43 additions and 23 deletions

View file

@ -213,7 +213,7 @@ def init_E(E: EnvParams) -> None:
def get_srvname() -> str:
try:
ret: str = unicode(socket.gethostname()).split(".")[0].lower()
ret: str = unicode(socket.gethostname()).split(".")[0]
except:
ret = ""

View file

@ -165,6 +165,10 @@ class HttpSrv(object):
def listen(self, sck: socket.socket, nlisteners: int) -> None:
if self.args.j != 1:
# lost in the pickle; redefine
try:
sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except:
pass
sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sck.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sck.settimeout(None) # < does not inherit, ^ does

View file

@ -68,15 +68,20 @@ class MDNS(MCast):
self.ttl = 300
self.running = True
zs = self.args.name.lower() + ".local."
zs = self.args.name + ".local."
zs = zs.encode("ascii", "replace").decode("ascii", "replace")
self.hn = zs.replace("?", "_")
self.hn = "-".join(x for x in zs.split("?") if x) or (
"vault-{}".format(random.randint(1, 255))
)
self.lhn = self.hn.lower()
# requester ip -> (response deadline, srv, body):
self.q: dict[str, tuple[float, MDNS_Sck, bytes]] = {}
self.rx4 = CachedSet(0.42) # 3 probes @ 250..500..750 => 500ms span
self.rx6 = CachedSet(0.42)
self.svcs, self.sfqdns = self.build_svcs()
self.lsvcs = {k.lower(): v for k, v in self.svcs.items()}
self.lsfqdns = set([x.lower() for x in self.sfqdns])
self.probing = 0.0
self.unsolicited: list[float] = [] # scheduled announces on all nics
@ -211,8 +216,7 @@ class MDNS(MCast):
sreply.add_answer(r)
if not (have4 and have6) and not self.args.zm_noneg:
have = "AAAA" if have6 else "A"
ns = NSEC(self.hn, [have, "PTR", "SRV", "TXT"])
ns = NSEC(self.hn, ["AAAA" if have6 else "A"])
r = RR(self.hn, QTYPE.NSEC, DC.F_IN, 120, ns)
areply.add_ar(r)
if len(sreply.pack()) < 1400:
@ -294,7 +298,8 @@ class MDNS(MCast):
continue
if self.probing < time.time():
self.log("probe ok; starting announcements", 2)
t = "probe ok; announcing [{}]"
self.log(t.format(self.hn[:-1]), 2)
self.probing = 0
def stop(self, panic=False) -> None:
@ -331,7 +336,7 @@ class MDNS(MCast):
self.log(str(p))
# check for incoming probes for our hostname
cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.hn]
cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.lhn]
if cips and self.sips.isdisjoint(cips):
if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
# avahi broadcasting 127.0.0.1-only packets
@ -350,7 +355,7 @@ class MDNS(MCast):
cips = [
U(x.rdata)
for x in p.rr
if U(x.rname).lower() == self.hn and x.rclass == DC.F_IN
if U(x.rname).lower() == self.lhn and x.rclass == DC.F_IN
]
if cips and self.sips.isdisjoint(cips):
if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
@ -370,7 +375,7 @@ class MDNS(MCast):
# then a/aaaa records
for r in p.questions:
if U(r.qname).lower() != self.hn:
if U(r.qname).lower() != self.lhn:
continue
# gvfs keeps repeating itself
@ -378,7 +383,7 @@ class MDNS(MCast):
unicast = False
for r in p.rr:
rname = U(r.rname).lower()
if rname == self.hn:
if rname == self.lhn:
if r.ttl > 60:
found = True
if r.rclass == DC.F_IN:
@ -396,7 +401,7 @@ class MDNS(MCast):
# and service queries
for r in p.questions:
qname = U(r.qname).lower()
if qname in self.svcs or qname == "_services._dns-sd._udp.local.":
if qname in self.lsvcs or qname == "_services._dns-sd._udp.local.":
self.q[cip] = (deadline, srv, srv.bp_svc)
break
# heed rfc-7.1 if there was an announce in the past 12sec
@ -405,7 +410,7 @@ class MDNS(MCast):
if now < srv.last_tx + 12:
for r in p.rr:
rdata = U(r.rdata).lower()
if rdata in self.sfqdns:
if rdata in self.lsfqdns:
if r.ttl > 2250:
self.q.pop(cip, None)
break

View file

@ -7,7 +7,7 @@ import time
import ipaddress
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
from .__init__ import TYPE_CHECKING
from .__init__ import TYPE_CHECKING, MACOS
from .util import min_ex, spack
if TYPE_CHECKING:
@ -72,7 +72,6 @@ class MCast(object):
def create_servers(self) -> list[str]:
bound: list[str] = []
ips = [x[0] for x in self.hub.tcpsrv.bound]
ips = list(set(ips))
if "::" in ips:
ips = [x for x in ips if x != "::"] + list(
@ -99,8 +98,10 @@ class MCast(object):
if not self.grp6:
ips = [x for x in ips if ":" not in x]
# discard non-linklocal ipv6
ips = list(set(ips))
all_selected = ips[:]
# discard non-linklocal ipv6
ips = [x for x in ips if ":" not in x or x.startswith("fe80")]
if not ips:
@ -196,19 +197,22 @@ class MCast(object):
self.b2srv[bip] = srv
self.b6.append(bip)
sck.bind((self.grp6 if srv.idx else "", self.port, 0, srv.idx))
grp = self.grp6 if srv.idx and not MACOS else ""
sck.bind((grp, self.port, 0, srv.idx))
bgrp = socket.inet_pton(socket.AF_INET6, self.grp6)
dev = spack(b"@I", srv.idx)
srv.mreq = bgrp + dev
if srv.idx != socket.INADDR_ANY:
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev)
self.hop(srv)
try:
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1)
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1)
except:
pass # macos
# macos
t = "failed to set IPv6 TTL/LOOP; announcements may not survive multiple switches/routers"
self.log(t, 3)
else:
if self.args.zmv:
self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6)
@ -217,7 +221,8 @@ class MCast(object):
self.b2srv[bip] = srv
self.b4.append(bip)
sck.bind((self.grp4 if srv.idx else "", self.port))
grp = self.grp4 if srv.idx and not MACOS else ""
sck.bind((grp, self.port))
bgrp = socket.inet_aton(self.grp4)
dev = (
spack(b"=I", socket.INADDR_ANY)
@ -228,13 +233,15 @@ class MCast(object):
if srv.idx != socket.INADDR_ANY:
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev)
self.hop(srv)
try:
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255)
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
except:
pass
# probably can't happen but dontcare if it does
t = "failed to set IPv4 TTL/LOOP; announcements may not survive multiple switches/routers"
self.log(t, 3)
self.hop(srv)
self.b4.sort(reverse=True)
self.b6.sort(reverse=True)

View file

@ -192,6 +192,10 @@ class TcpSrv(object):
def _listen(self, ip: str, port: int) -> None:
ipv = socket.AF_INET6 if ":" in ip else socket.AF_INET
srv = socket.socket(ipv, socket.SOCK_STREAM)
try:
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except:
pass
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
srv.settimeout(None) # < does not inherit, ^ does