mdns ipv6 fixes; now works on ie11/safari, not linux:

* subscribe/announce on LL only
* add NSEC records if 4/6-only
This commit is contained in:
ed 2022-11-15 06:39:53 +00:00
parent 37c1cab726
commit 8829f56d4c
6 changed files with 91 additions and 25 deletions

View file

@ -19,6 +19,7 @@ from .stolen.dnslib import (
QTYPE, QTYPE,
A, A,
AAAA, AAAA,
NSEC,
SRV, SRV,
PTR, PTR,
TXT, TXT,
@ -41,11 +42,12 @@ class MDNS_Sck(MC_Sck):
self, self,
sck: socket.socket, sck: socket.socket,
idx: int, idx: int,
name: str,
grp: str, grp: str,
ip: str, ip: str,
net: Union[IPv4Network, IPv6Network], net: Union[IPv4Network, IPv6Network],
): ):
super(MDNS_Sck, self).__init__(sck, idx, grp, ip, net) super(MDNS_Sck, self).__init__(sck, idx, name, grp, ip, net)
self.bp_probe = b"" self.bp_probe = b""
self.bp_ip = b"" self.bp_ip = b""
@ -143,6 +145,16 @@ class MDNS(MCast):
sreply = DNSRecord(DNSHeader(0, 0x8400)) sreply = DNSRecord(DNSHeader(0, 0x8400))
bye = DNSRecord(DNSHeader(0, 0x8400)) bye = DNSRecord(DNSHeader(0, 0x8400))
have4 = have6 = False
for s2 in self.srv.values():
if srv.idx != s2.idx:
continue
if s2.v6:
have6 = True
else:
have4 = True
for ip in srv.ips: for ip in srv.ips:
if ":" in ip: if ":" in ip:
qt = QTYPE.AAAA qt = QTYPE.AAAA
@ -162,6 +174,12 @@ class MDNS(MCast):
sreply.add_answer(r120) sreply.add_answer(r120)
bye.add_answer(r0) bye.add_answer(r0)
if not have4 or not have6:
ns = NSEC(self.hn, ["AAAA" if have4 else "A"])
r = RR(self.hn, QTYPE.NSEC, DC.F_IN, 120, ns)
areply.add_ar(r)
sreply.add_ar(r)
for sclass, props in self.svcs.items(): for sclass, props in self.svcs.items():
sname = props["name"] sname = props["name"]
sport = props["port"] sport = props["port"]
@ -255,13 +273,13 @@ class MDNS(MCast):
rx: list[socket.socket] = rdy[0] # type: ignore rx: list[socket.socket] = rdy[0] # type: ignore
self.rx4.cln() self.rx4.cln()
self.rx6.cln() self.rx6.cln()
for srv in rx: for sck in rx:
buf, addr = srv.recvfrom(4096) buf, addr = sck.recvfrom(4096)
try: try:
self.eat(buf, addr) self.eat(buf, addr, sck)
except: except:
t = "{} \033[33m|{}| {}\n{}".format( t = "{} {} \033[33m|{}| {}\n{}".format(
addr, len(buf), repr(buf)[2:-1], min_ex() self.srv[sck].name, addr, len(buf), repr(buf)[2:-1], min_ex()
) )
self.log(t, 6) self.log(t, 6)
@ -279,9 +297,11 @@ class MDNS(MCast):
for srv in self.srv.values(): for srv in self.srv.values():
srv.sck.sendto(srv.bp_bye, (srv.grp, 5353)) srv.sck.sendto(srv.bp_bye, (srv.grp, 5353))
def eat(self, buf: bytes, addr: tuple[str, int]): self.srv = {}
def eat(self, buf: bytes, addr: tuple[str, int], sck: socket.socket):
cip = addr[0] cip = addr[0]
if cip.startswith("fe80") or cip.startswith("169.254"): if cip.startswith("169.254"):
return return
v6 = ":" in cip v6 = ":" in cip
@ -290,14 +310,15 @@ class MDNS(MCast):
return return
cache.add(buf) cache.add(buf)
srv: Optional[MDNS_Sck] = self.map_client(cip) # type: ignore srv: Optional[MDNS_Sck] = self.srv[sck] if v6 else self.map_client(cip) # type: ignore
if not srv: if not srv:
return return
now = time.time() now = time.time()
if self.args.zmv: if self.args.zmv:
self.log("[{}] \033[36m{} \033[0m|{}|".format(srv.ip, cip, len(buf)), "90") t = "{} [{}] \033[36m{} \033[0m|{}|"
self.log(t.format(srv.name, srv.ip, cip, len(buf)), "90")
p = DNSRecord.parse(buf) p = DNSRecord.parse(buf)
if self.args.zmvv: if self.args.zmvv:
@ -336,8 +357,8 @@ class MDNS(MCast):
else: else:
t += "Emergency stop; hostname '{}' got stolen" t += "Emergency stop; hostname '{}' got stolen"
t += "! Use --name to set another hostname.\n\nName taken by {}\n\nYour IPs: {}\n" t += " on {}! Use --name to set another hostname.\n\nName taken by {}\n\nYour IPs: {}\n"
self.log(t.format(self.args.name, cips, list(self.sips)), 1) self.log(t.format(self.args.name, srv.name, cips, list(self.sips)), 1)
self.stop(True) self.stop(True)
return return
@ -348,11 +369,17 @@ class MDNS(MCast):
# gvfs keeps repeating itself # gvfs keeps repeating itself
found = False found = False
unicast = False
for r in p.rr: for r in p.rr:
rname = U(r.rname).lower() rname = U(r.rname).lower()
if rname == self.hn and r.ttl > 60: if rname == self.hn:
found = True if r.ttl > 60:
break found = True
if r.rclass == DC.F_IN:
unicast = True
if unicast:
srv.sck.sendto(srv.bp_ip, (cip, 5353))
if not found: if not found:
self.q[cip] = (0, srv, srv.bp_ip) self.q[cip] = (0, srv, srv.bp_ip)

View file

@ -26,12 +26,14 @@ class MC_Sck(object):
self, self,
sck: socket.socket, sck: socket.socket,
idx: int, idx: int,
name: str,
grp: str, grp: str,
ip: str, ip: str,
net: Union[IPv4Network, IPv6Network], net: Union[IPv4Network, IPv6Network],
): ):
self.sck = sck self.sck = sck
self.idx = idx self.idx = idx
self.name = name
self.grp = grp self.grp = grp
self.mreq = b"" self.mreq = b""
self.ip = ip self.ip = ip
@ -55,7 +57,7 @@ class MCast(object):
self.port = port self.port = port
self.srv: dict[socket.socket, MC_Sck] = {} # listening sockets self.srv: dict[socket.socket, MC_Sck] = {} # listening sockets
self.sips: set[str] = set() # all listening ips self.sips: set[str] = set() # all listening ips (including failed attempts)
self.b2srv: dict[bytes, MC_Sck] = {} # binary-ip -> server socket self.b2srv: dict[bytes, MC_Sck] = {} # binary-ip -> server socket
self.b4: list[bytes] = [] # sorted list of binary-ips self.b4: list[bytes] = [] # sorted list of binary-ips
self.b6: list[bytes] = [] # sorted list of binary-ips self.b6: list[bytes] = [] # sorted list of binary-ips
@ -82,6 +84,7 @@ class MCast(object):
ips = [x for x in ips if x not in ("::1", "127.0.0.1")] ips = [x for x in ips if x not in ("::1", "127.0.0.1")]
# ip -> ip/prefix
ips = [ ips = [
[x for x in self.hub.tcpsrv.netdevs if x.startswith(y + "/")][0] [x for x in self.hub.tcpsrv.netdevs if x.startswith(y + "/")][0]
for y in ips for y in ips
@ -93,6 +96,10 @@ class MCast(object):
if not self.grp6: if not self.grp6:
ips = [x for x in ips if ":" not in x] ips = [x for x in ips if ":" not in x]
# discard non-linklocal ipv6
all_selected = ips[:]
ips = [x for x in ips if ":" not in x or x.startswith("fe80")]
if not ips: if not ips:
raise Exception("no server IP matches the mdns config") raise Exception("no server IP matches the mdns config")
@ -117,16 +124,39 @@ class MCast(object):
except: except:
pass pass
# most ipv6 clients expect multicast on linklocal ip only;
# add a/aaaa records for the other nic IPs
other_ips: set[str] = set()
if v6 and netdev not in ("?", ""):
for oip, onic in self.hub.tcpsrv.netdevs.items():
if (
onic.split(",")[0] == netdev
and oip in all_selected
and ":" in oip
):
other_ips.add(oip)
net = ipaddress.ip_network(ip, False) net = ipaddress.ip_network(ip, False)
ip = ip.split("/")[0] ip = ip.split("/")[0]
srv = self.Srv(sck, idx, self.grp6 if ":" in ip else self.grp4, ip, net) srv = self.Srv(
sck, idx, netdev, self.grp6 if ":" in ip else self.grp4, ip, net
)
for oth_ip in other_ips:
srv.ips[oth_ip.split("/")[0]] = ipaddress.ip_network(oth_ip, False)
# gvfs breaks if a linklocal ip appears in a dns reply
srv.ips = {k: v for k, v in srv.ips.items() if not k.startswith("fe80")}
if not srv.ips:
self.log("no routable IPs on {}; skipping [{}]".format(netdev, ip), 3)
continue
try: try:
self.setup_socket(srv) self.setup_socket(srv)
self.srv[sck] = srv self.srv[sck] = srv
bound.append(ip) bound.append(ip)
except: except:
self.log("announce failed on [{}]:\n{}".format(ip, min_ex())) t = "announce failed on {} [{}]:\n{}"
self.log(t.format(netdev, ip, min_ex()), 3)
if self.args.zm_msub: if self.args.zm_msub:
for s1 in self.srv.values(): for s1 in self.srv.values():
@ -145,18 +175,22 @@ class MCast(object):
if net1 == net2 and ip1 != ip2: if net1 == net2 and ip1 != ip2:
s1.ips[ip2] = net2 s1.ips[ip2] = net2
self.sips = set([x.ip for x in self.srv.values()]) self.sips = set([x.split("/")[0] for x in all_selected])
for srv in self.srv.values():
assert srv.ip in self.sips
return bound return bound
def setup_socket(self, srv: MC_Sck) -> None: def setup_socket(self, srv: MC_Sck) -> None:
sck = srv.sck sck = srv.sck
if srv.v6: if srv.v6:
if self.args.zmv: if self.args.zmv:
self.log("v6({}) idx({})".format(srv.ip, srv.idx), 6) self.log("v6({}) idx({}) {}".format(srv.ip, srv.idx, srv.ips), 6)
bip = socket.inet_pton(socket.AF_INET6, srv.ip) for ip in srv.ips:
self.b2srv[bip] = srv bip = socket.inet_pton(socket.AF_INET6, ip)
self.b6.append(bip) self.b2srv[bip] = srv
self.b6.append(bip)
sck.bind((self.grp6 if srv.idx else "", self.port, 0, srv.idx)) sck.bind((self.grp6 if srv.idx else "", self.port, 0, srv.idx))
bgrp = socket.inet_pton(socket.AF_INET6, self.grp6) bgrp = socket.inet_pton(socket.AF_INET6, self.grp6)

View file

@ -764,7 +764,7 @@ class NSEC(RD):
label = property(get_label, set_label) label = property(get_label, set_label)
def pack(self, buffer): def pack(self, buffer):
buffer.encode_name_nocompress(self.label) buffer.encode_name(self.label)
buffer.append(encode_type_bitmap(self.rrlist)) buffer.append(encode_type_bitmap(self.rrlist))
def __repr__(self): def __repr__(self):

View file

@ -397,7 +397,7 @@ class TcpSrv(object):
for nip in nic.ips: for nip in nic.ips:
ipa = nip.ip[0] if ":" in str(nip.ip) else nip.ip ipa = nip.ip[0] if ":" in str(nip.ip) else nip.ip
sip = "{}/{}".format(ipa, nip.network_prefix) sip = "{}/{}".format(ipa, nip.network_prefix)
if sip.startswith("fe80") or sip.startswith("169.254"): if sip.startswith("169.254"):
# browsers dont impl linklocal # browsers dont impl linklocal
continue continue

View file

@ -16,6 +16,10 @@ https://github.com/paulc/dnslib/
C: 2010-2017 Paul Chakravarti C: 2010-2017 Paul Chakravarti
L: BSD 2-Clause L: BSD 2-Clause
https://github.com/pydron/ifaddr/
C: 2014 Stefan C. Mueller
L: BSD-2-Clause
https://github.com/giampaolo/pyftpdlib/ https://github.com/giampaolo/pyftpdlib/
C: 2007 Giampaolo Rodola C: 2007 Giampaolo Rodola
L: MIT L: MIT

View file

@ -19,6 +19,7 @@ copyparty/httpconn.py,
copyparty/httpsrv.py, copyparty/httpsrv.py,
copyparty/ico.py, copyparty/ico.py,
copyparty/mdns.py, copyparty/mdns.py,
copyparty/multicast.py,
copyparty/mtag.py, copyparty/mtag.py,
copyparty/res, copyparty/res,
copyparty/res/COPYING.txt, copyparty/res/COPYING.txt,