diff --git a/copyparty/__main__.py b/copyparty/__main__.py index eedccee8..ad15f7e3 100755 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -213,7 +213,7 @@ def init_E(E: EnvParams) -> None: def get_srvname() -> str: try: - ret: str = unicode(socket.gethostname()).split(".")[0].lower() + ret: str = unicode(socket.gethostname()).split(".")[0] except: ret = "" diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py index 8b382977..f6558754 100644 --- a/copyparty/httpsrv.py +++ b/copyparty/httpsrv.py @@ -165,6 +165,10 @@ class HttpSrv(object): def listen(self, sck: socket.socket, nlisteners: int) -> None: if self.args.j != 1: # lost in the pickle; redefine + try: + sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except: + pass sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sck.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sck.settimeout(None) # < does not inherit, ^ does diff --git a/copyparty/mdns.py b/copyparty/mdns.py index 13bda0aa..0f93c62d 100644 --- a/copyparty/mdns.py +++ b/copyparty/mdns.py @@ -68,15 +68,20 @@ class MDNS(MCast): self.ttl = 300 self.running = True - zs = self.args.name.lower() + ".local." + zs = self.args.name + ".local." zs = zs.encode("ascii", "replace").decode("ascii", "replace") - self.hn = zs.replace("?", "_") + self.hn = "-".join(x for x in zs.split("?") if x) or ( + "vault-{}".format(random.randint(1, 255)) + ) + self.lhn = self.hn.lower() # requester ip -> (response deadline, srv, body): self.q: dict[str, tuple[float, MDNS_Sck, bytes]] = {} self.rx4 = CachedSet(0.42) # 3 probes @ 250..500..750 => 500ms span self.rx6 = CachedSet(0.42) self.svcs, self.sfqdns = self.build_svcs() + self.lsvcs = {k.lower(): v for k, v in self.svcs.items()} + self.lsfqdns = set([x.lower() for x in self.sfqdns]) self.probing = 0.0 self.unsolicited: list[float] = [] # scheduled announces on all nics @@ -211,8 +216,7 @@ class MDNS(MCast): sreply.add_answer(r) if not (have4 and have6) and not self.args.zm_noneg: - have = "AAAA" if have6 else "A" - ns = NSEC(self.hn, [have, "PTR", "SRV", "TXT"]) + ns = NSEC(self.hn, ["AAAA" if have6 else "A"]) r = RR(self.hn, QTYPE.NSEC, DC.F_IN, 120, ns) areply.add_ar(r) if len(sreply.pack()) < 1400: @@ -294,7 +298,8 @@ class MDNS(MCast): continue if self.probing < time.time(): - self.log("probe ok; starting announcements", 2) + t = "probe ok; announcing [{}]" + self.log(t.format(self.hn[:-1]), 2) self.probing = 0 def stop(self, panic=False) -> None: @@ -331,7 +336,7 @@ class MDNS(MCast): self.log(str(p)) # check for incoming probes for our hostname - cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.hn] + cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.lhn] if cips and self.sips.isdisjoint(cips): if not [x for x in cips if x not in ("::1", "127.0.0.1")]: # avahi broadcasting 127.0.0.1-only packets @@ -350,7 +355,7 @@ class MDNS(MCast): cips = [ U(x.rdata) for x in p.rr - if U(x.rname).lower() == self.hn and x.rclass == DC.F_IN + if U(x.rname).lower() == self.lhn and x.rclass == DC.F_IN ] if cips and self.sips.isdisjoint(cips): if not [x for x in cips if x not in ("::1", "127.0.0.1")]: @@ -370,7 +375,7 @@ class MDNS(MCast): # then a/aaaa records for r in p.questions: - if U(r.qname).lower() != self.hn: + if U(r.qname).lower() != self.lhn: continue # gvfs keeps repeating itself @@ -378,7 +383,7 @@ class MDNS(MCast): unicast = False for r in p.rr: rname = U(r.rname).lower() - if rname == self.hn: + if rname == self.lhn: if r.ttl > 60: found = True if r.rclass == DC.F_IN: @@ -396,7 +401,7 @@ class MDNS(MCast): # and service queries for r in p.questions: qname = U(r.qname).lower() - if qname in self.svcs or qname == "_services._dns-sd._udp.local.": + if qname in self.lsvcs or qname == "_services._dns-sd._udp.local.": self.q[cip] = (deadline, srv, srv.bp_svc) break # heed rfc-7.1 if there was an announce in the past 12sec @@ -405,7 +410,7 @@ class MDNS(MCast): if now < srv.last_tx + 12: for r in p.rr: rdata = U(r.rdata).lower() - if rdata in self.sfqdns: + if rdata in self.lsfqdns: if r.ttl > 2250: self.q.pop(cip, None) break diff --git a/copyparty/multicast.py b/copyparty/multicast.py index 092cc3b4..4acb5954 100644 --- a/copyparty/multicast.py +++ b/copyparty/multicast.py @@ -7,7 +7,7 @@ import time import ipaddress from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network -from .__init__ import TYPE_CHECKING +from .__init__ import TYPE_CHECKING, MACOS from .util import min_ex, spack if TYPE_CHECKING: @@ -72,7 +72,6 @@ class MCast(object): def create_servers(self) -> list[str]: bound: list[str] = [] ips = [x[0] for x in self.hub.tcpsrv.bound] - ips = list(set(ips)) if "::" in ips: ips = [x for x in ips if x != "::"] + list( @@ -99,8 +98,10 @@ class MCast(object): if not self.grp6: ips = [x for x in ips if ":" not in x] - # discard non-linklocal ipv6 + ips = list(set(ips)) all_selected = ips[:] + + # discard non-linklocal ipv6 ips = [x for x in ips if ":" not in x or x.startswith("fe80")] if not ips: @@ -196,19 +197,22 @@ class MCast(object): self.b2srv[bip] = srv self.b6.append(bip) - sck.bind((self.grp6 if srv.idx else "", self.port, 0, srv.idx)) + grp = self.grp6 if srv.idx and not MACOS else "" + sck.bind((grp, self.port, 0, srv.idx)) + bgrp = socket.inet_pton(socket.AF_INET6, self.grp6) dev = spack(b"@I", srv.idx) srv.mreq = bgrp + dev if srv.idx != socket.INADDR_ANY: sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev) - self.hop(srv) try: - sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1) sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) + sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1) except: - pass # macos + # macos + t = "failed to set IPv6 TTL/LOOP; announcements may not survive multiple switches/routers" + self.log(t, 3) else: if self.args.zmv: self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6) @@ -217,7 +221,8 @@ class MCast(object): self.b2srv[bip] = srv self.b4.append(bip) - sck.bind((self.grp4 if srv.idx else "", self.port)) + grp = self.grp4 if srv.idx and not MACOS else "" + sck.bind((grp, self.port)) bgrp = socket.inet_aton(self.grp4) dev = ( spack(b"=I", socket.INADDR_ANY) @@ -228,13 +233,15 @@ class MCast(object): if srv.idx != socket.INADDR_ANY: sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev) - self.hop(srv) try: - sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1) sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255) + sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1) except: - pass + # probably can't happen but dontcare if it does + t = "failed to set IPv4 TTL/LOOP; announcements may not survive multiple switches/routers" + self.log(t, 3) + self.hop(srv) self.b4.sort(reverse=True) self.b6.sort(reverse=True) diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py index 9fae7430..ac62ebbd 100644 --- a/copyparty/tcpsrv.py +++ b/copyparty/tcpsrv.py @@ -192,6 +192,10 @@ class TcpSrv(object): def _listen(self, ip: str, port: int) -> None: ipv = socket.AF_INET6 if ":" in ip else socket.AF_INET srv = socket.socket(ipv, socket.SOCK_STREAM) + try: + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except: + pass srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) srv.settimeout(None) # < does not inherit, ^ does