diff --git a/copyparty/__init__.py b/copyparty/__init__.py index 04cffaf5..b5c50ad4 100644 --- a/copyparty/__init__.py +++ b/copyparty/__init__.py @@ -12,14 +12,14 @@ except: TYPE_CHECKING = False if True: - from typing import Any + from typing import Any, Callable PY2 = sys.version_info < (3,) -if PY2: +if not PY2: + unicode: Callable[[str], str] = str +else: sys.dont_write_bytecode = True unicode = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable -else: - unicode = str WINDOWS: Any = ( [int(x) for x in platform.version().split(".")] diff --git a/copyparty/__main__.py b/copyparty/__main__.py index 7547464f..4672a10f 100755 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -9,11 +9,13 @@ __license__ = "MIT" __url__ = "https://github.com/9001/copyparty/" import argparse +import base64 import filecmp import locale import os import re import shutil +import socket import sys import threading import time @@ -209,6 +211,31 @@ def init_E(E: EnvParams) -> None: raise +def get_srvname() -> str: + try: + ret: str = unicode(socket.gethostname()).split(".")[0].lower() + except: + ret = "" + + if ret not in ["", "localhost"]: + return ret + + fp = os.path.join(E.cfg, "name.txt") + lprint("using hostname from {}\n".format(fp)) + try: + with open(fp, "rb") as f: + ret = f.read().decode("utf-8", "replace").strip() + except: + ret = "" + while len(ret) < 7: + ret += base64.b32encode(os.urandom(4))[:7].decode("utf-8").lower() + ret = re.sub("[234567=]", "", ret)[:7] + with open(fp, "wb") as f: + f.write(ret.encode("utf-8") + b"\n") + + return ret + + def ensure_locale() -> None: for x in [ "en_US.UTF-8", @@ -431,6 +458,8 @@ def run_argparse( tty = os.environ.get("TERM", "").lower() == "linux" + srvname = get_srvname() + sects = [ [ "accounts", @@ -584,6 +613,7 @@ def run_argparse( ap2.add_argument("-mcr", metavar="SEC", type=int, default=60, help="md-editor mod-chk rate") ap2.add_argument("--urlform", metavar="MODE", type=u, default="print,get", help="how to handle url-form POSTs; see --help-urlform") ap2.add_argument("--wintitle", metavar="TXT", type=u, default="cpp @ $pub", help="window title, for example [\033[32m$ip-10.1.2.\033[0m] or [\033[32m$ip-]") + ap2.add_argument("--name", metavar="TXT", type=str, default=srvname, help="server name (displayed topleft in browser and in mDNS)") ap2.add_argument("--license", action="store_true", help="show licenses and exit") ap2.add_argument("--version", action="store_true", help="show versions and exit") @@ -630,6 +660,21 @@ def run_argparse( ap2.add_argument("--ssl-dbg", action="store_true", help="dump some tls info") ap2.add_argument("--ssl-log", metavar="PATH", type=u, help="log master secrets for later decryption in wireshark") + ap2 = ap.add_argument_group("Zeroconf options") + ap2.add_argument("--zm", action="store_true", help="announce the enabled protocols over mDNS (multicast DNS-SD) -- compatible with KDE, gnome, macOS, ...") + ap2.add_argument("--zm4", action="store_true", help="IPv4 only -- try this if some clients don't work") + ap2.add_argument("--zm6", action="store_true", help="IPv6 only") + ap2.add_argument("--zmv", action="store_true", help="verbose mdns") + ap2.add_argument("--zmvv", action="store_true", help="verboser mdns") + ap2.add_argument("--zms", metavar="dhf", type=str, default="", help="list of services to announce -- d=webdav h=http f=ftp s=smb -- lowercase=plaintext uppercase=TLS -- default: all enabled services except http/https (\033[32mDdfs\033[0m if \033[33m--ftp\033[0m and \033[33m--smb\033[0m is set)") + ap2.add_argument("--zm-ld", metavar="PATH", type=str, default="", help="link a specific folder for webdav shares") + ap2.add_argument("--zm-lh", metavar="PATH", type=str, default="", help="link a specific folder for http shares") + ap2.add_argument("--zm-lf", metavar="PATH", type=str, default="", help="link a specific folder for ftp shares") + ap2.add_argument("--zm-ls", metavar="PATH", type=str, default="", help="link a specific folder for smb shares") + ap2.add_argument("--zm-mnic", action="store_true", help="merge NICs which share subnets; assume that same subnet means same network") + ap2.add_argument("--zm-msub", action="store_true", help="merge subnets on each NIC -- always enabled for ipv6 -- reduces network load, but gnome-gvfs clients may stop working") + ap2.add_argument("--mc-hop", metavar="SEC", type=int, default=0, help="rejoin multicast groups every SEC seconds (workaround for some switches/routers which cause mDNS to suddenly stop working after some time); try [\033[32m300\033[0m] or [\033[32m180\033[0m]") + ap2 = ap.add_argument_group('FTP options') ap2.add_argument("--ftp", metavar="PORT", type=int, help="enable FTP server on PORT, for example \033[32m3921") ap2.add_argument("--ftps", metavar="PORT", type=int, help="enable FTPS server on PORT, for example \033[32m3990") @@ -898,6 +943,7 @@ def main(argv: Optional[list[str]] = None) -> None: for fmtr in [RiceFormatter, RiceFormatter, Dodge11874, BasicDodge11874]: try: al = run_argparse(argv, fmtr, retry, nc) + break except SystemExit: raise except: diff --git a/copyparty/httpcli.py b/copyparty/httpcli.py index b3af9a12..fa49955c 100644 --- a/copyparty/httpcli.py +++ b/copyparty/httpcli.py @@ -11,7 +11,6 @@ import itertools import json import os import re -import socket import stat import string import threading # typechk @@ -27,11 +26,6 @@ try: except: pass -try: - from ipaddress import IPv6Address -except: - pass - from .__init__ import ANYWIN, PY2, TYPE_CHECKING, EnvParams, unicode from .authsrv import VFS # typechk from .bos import bos @@ -3030,7 +3024,7 @@ class HttpCli(object): try: if not self.args.nih: - srv_info.append(unicode(socket.gethostname()).split(".")[0]) + srv_info.append(self.args.name) except: self.log("#wow #whoa") diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py index ffb8c7d6..8b382977 100644 --- a/copyparty/httpsrv.py +++ b/copyparty/httpsrv.py @@ -11,11 +11,6 @@ import time import queue -try: - from ipaddress import IPv6Address -except: - pass - try: import jinja2 except ImportError: diff --git a/copyparty/mdns.py b/copyparty/mdns.py new file mode 100644 index 00000000..d9c3663c --- /dev/null +++ b/copyparty/mdns.py @@ -0,0 +1,414 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import random +import select +import socket +import time +from ipaddress import IPv4Network, IPv6Network + +from .__init__ import TYPE_CHECKING +from .__init__ import unicode as U +from .util import CachedSet, Daemon, min_ex +from .multicast import MC_Sck, MCast +from .stolen.dnslib import ( + RR, + DNSHeader, + DNSRecord, + DNSQuestion, + QTYPE, + A, + AAAA, + SRV, + PTR, + TXT, +) +from .stolen.dnslib import CLASS as DC + +if TYPE_CHECKING: + from .svchub import SvcHub + +if True: # pylint: disable=using-constant-test + from typing import Any, Optional, Union + + +MDNS4 = "224.0.0.251" +MDNS6 = "ff02::fb" + + +class MDNS_Sck(MC_Sck): + def __init__( + self, + sck: socket.socket, + idx: int, + grp: str, + ip: str, + net: Union[IPv4Network, IPv6Network], + ): + super(MDNS_Sck, self).__init__(sck, idx, grp, ip, net) + + self.bp_probe = b"" + self.bp_ip = b"" + self.bp_svc = b"" + self.bp_bye = b"" + + self.last_tx = 0.0 + + +class MDNS(MCast): + def __init__(self, hub: "SvcHub") -> None: + grp4 = "" if hub.args.zm6 else MDNS4 + grp6 = "" if hub.args.zm4 else MDNS6 + super(MDNS, self).__init__(hub, MDNS_Sck, grp4, grp6, 5353) + self.srv: dict[socket.socket, MDNS_Sck] = {} + + self.ttl = 300 + self.running = True + + zs = self.args.name.lower() + ".local." + zs = zs.encode("ascii", "replace").decode("ascii", "replace") + self.hn = zs.replace("?", "_") + + # requester ip -> (response deadline, srv, body): + self.q: dict[str, tuple[float, MDNS_Sck, bytes]] = {} + self.rx4 = CachedSet(0.42) # 3 probes @ 250..500..750 => 500ms span + self.rx6 = CachedSet(0.42) + self.svcs, self.sfqdns = self.build_svcs() + + self.probing = 0.0 + self.unsolicited: list[float] = [] # scheduled announces on all nics + self.defend: dict[MDNS_Sck, float] = {} # server -> deadline + + def log(self, msg: str, c: Union[int, str] = 0) -> None: + self.log_func("mDNS", msg, c) + + def build_svcs(self) -> tuple[dict[str, dict[str, Any]], set[str]]: + zms = self.args.zms + http = {"port": 80 if 80 in self.args.p else self.args.p[0]} + https = {"port": 443 if 443 in self.args.p else self.args.p[0]} + webdav = http.copy() + webdavs = https.copy() + webdav["u"] = webdavs["u"] = "u" # KDE requires username + ftp = {"port": (self.args.ftp if "f" in zms else self.args.ftps)} + smb = {"port": self.args.smb_port} + + # some gvfs require path + zs = self.args.zm_ld or "/" + if zs: + webdav["path"] = zs + webdavs["path"] = zs + + if self.args.zm_lh: + http["path"] = self.args.zm_lh + https["path"] = self.args.zm_lh + + if self.args.zm_lf: + ftp["path"] = self.args.zm_lf + + if self.args.zm_ls: + smb["path"] = self.args.zm_ls + + svcs: dict[str, dict[str, Any]] = {} + + if "d" in zms: + svcs["_webdav._tcp.local."] = webdav + + if "D" in zms: + svcs["_webdavs._tcp.local."] = webdavs + + if "h" in zms: + svcs["_http._tcp.local."] = http + + if "H" in zms: + svcs["_https._tcp.local."] = https + + if "f" in zms.lower(): + svcs["_ftp._tcp.local."] = ftp + + if "s" in zms.lower(): + svcs["_smb._tcp.local."] = smb + + sfqdns: set[str] = set() + for k, v in svcs.items(): + name = "{}-c-{}".format(self.args.name, k.split(".")[0][1:]) + v["name"] = name + sfqdns.add("{}.{}".format(name, k)) + + return svcs, sfqdns + + def build_replies(self) -> None: + for srv in self.srv.values(): + probe = DNSRecord(DNSHeader(0, 0), q=DNSQuestion(self.hn, QTYPE.ANY)) + areply = DNSRecord(DNSHeader(0, 0x8400)) + sreply = DNSRecord(DNSHeader(0, 0x8400)) + bye = DNSRecord(DNSHeader(0, 0x8400)) + + for ip in srv.ips: + if ":" in ip: + qt = QTYPE.AAAA + ar = {"rclass": DC.F_IN, "rdata": AAAA(ip)} + else: + qt = QTYPE.A + ar = {"rclass": DC.F_IN, "rdata": A(ip)} + + r0 = RR(self.hn, qt, ttl=0, **ar) + r120 = RR(self.hn, qt, ttl=120, **ar) + # rfc-10: + # SHOULD rr ttl 120sec for A/AAAA/SRV + # (and recommend 75min for all others) + + probe.add_auth(r120) + areply.add_answer(r120) + sreply.add_answer(r120) + bye.add_answer(r0) + + for sclass, props in self.svcs.items(): + sname = props["name"] + sport = props["port"] + sfqdn = sname + "." + sclass + + k = "_services._dns-sd._udp.local." + r = RR(k, QTYPE.PTR, DC.IN, 4500, PTR(sclass)) + sreply.add_answer(r) + + r = RR(sclass, QTYPE.PTR, DC.IN, 4500, PTR(sfqdn)) + sreply.add_answer(r) + + r = RR(sfqdn, QTYPE.SRV, DC.F_IN, 120, SRV(0, 0, sport, self.hn)) + sreply.add_answer(r) + areply.add_answer(r) + + r = RR(sfqdn, QTYPE.SRV, DC.F_IN, 0, SRV(0, 0, sport, self.hn)) + bye.add_answer(r) + + txts = [] + for k in ("u", "path"): + if k not in props: + continue + + zb = "{}={}".format(k, props[k]).encode("utf-8") + if len(zb) > 255: + t = "value too long for mdns: [{}]" + raise Exception(t.format(props[k])) + + txts.append(zb) + + # gvfs really wants txt even if they're empty + r = RR(sfqdn, QTYPE.TXT, DC.F_IN, 4500, TXT(txts)) + sreply.add_answer(r) + + srv.bp_probe = probe.pack() + srv.bp_ip = areply.pack() + srv.bp_svc = sreply.pack() + srv.bp_bye = bye.pack() + + # since all replies are small enough to fit in one packet, + # always send full replies rather than just a/aaaa records + srv.bp_ip = srv.bp_svc + + def send_probes(self) -> None: + slp = random.random() * 0.25 + for _ in range(3): + time.sleep(slp) + slp = 0.25 + if not self.running: + break + + if self.args.zmv: + self.log("sending hostname probe...") + + # ipv4: need to probe each ip (each server) + # ipv6: only need to probe each set of looped nics + probed6: set[str] = set() + for srv in self.srv.values(): + if srv.ip in probed6: + continue + + try: + srv.sck.sendto(srv.bp_probe, (srv.grp, 5353)) + if srv.v6: + for ip in srv.ips: + probed6.add(ip) + except Exception as ex: + self.log("sendto failed: {} ({})".format(srv.ip, ex), "90") + + def run(self) -> None: + bound = self.create_servers() + if not bound: + self.log("failed to announce copyparty services on the network", 3) + return + + self.build_replies() + Daemon(self.send_probes) + zf = time.time() + 2 + self.probing = zf # cant unicast so give everyone an extra sec + self.unsolicited = [zf, zf + 1, zf + 3, zf + 7] # rfc-8.3 + last_hop = time.time() + ihop = self.args.mc_hop + while self.running: + timeout = ( + 0.02 + random.random() * 0.07 + if self.probing or self.q or self.defend or self.unsolicited + else (last_hop + ihop if ihop else 180) + ) + rdy = select.select(self.srv, [], [], timeout) + rx: list[socket.socket] = rdy[0] # type: ignore + self.rx4.cln() + self.rx6.cln() + for srv in rx: + buf, addr = srv.recvfrom(4096) + try: + self.eat(buf, addr) + except: + t = "{} \033[33m|{}| {}\n{}".format( + addr, len(buf), repr(buf)[2:-1], min_ex() + ) + self.log(t, 6) + + if not self.probing: + self.process() + continue + + if self.probing < time.time(): + self.log("probe ok; starting announcements", 2) + self.probing = 0 + + def stop(self, panic=False) -> None: + self.running = False + if not panic: + for srv in self.srv.values(): + srv.sck.sendto(srv.bp_bye, (srv.grp, 5353)) + + def eat(self, buf: bytes, addr: tuple[str, int]): + cip = addr[0] + if cip.startswith("fe80") or cip.startswith("169.254"): + return + + v6 = ":" in cip + cache = self.rx6 if v6 else self.rx4 + if buf in cache.c: + return + + cache.add(buf) + srv: Optional[MDNS_Sck] = self.map_client(cip) # type: ignore + if not srv: + return + + now = time.time() + + if self.args.zmv: + self.log("[{}] \033[36m{} \033[0m|{}|".format(srv.ip, cip, len(buf)), "90") + + p = DNSRecord.parse(buf) + if self.args.zmvv: + self.log(str(p)) + + # check for incoming probes for our hostname + cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.hn] + if cips and self.sips.isdisjoint(cips): + if not [x for x in cips if x not in ("::1", "127.0.0.1")]: + # avahi broadcasting 127.0.0.1-only packets + return + + self.log("someone trying to steal our hostname: {}".format(cips), 3) + # immediately unicast + if not self.probing: + srv.sck.sendto(srv.bp_ip, (cip, 5353)) + + # and schedule multicast + self.defend[srv] = self.defend.get(srv, now + 0.1) + return + + # check for someone rejecting our probe / hijacking our hostname + cips = [ + U(x.rdata) + for x in p.rr + if U(x.rname).lower() == self.hn and x.rclass == DC.F_IN + ] + if cips and self.sips.isdisjoint(cips): + if not [x for x in cips if x not in ("::1", "127.0.0.1")]: + # avahi broadcasting 127.0.0.1-only packets + return + + t = "mdns zeroconf: " + if self.probing: + t += "Cannot start; hostname '{}' is occupied" + else: + t += "Emergency stop; hostname '{}' got stolen" + + t += "! Use --name to set another hostname.\n\nName taken by {}\n\nYour IPs: {}\n" + self.log(t.format(self.args.name, cips, list(self.sips)), 1) + self.stop(True) + return + + # then a/aaaa records + for r in p.questions: + if U(r.qname).lower() != self.hn: + continue + + # gvfs keeps repeating itself + found = False + for r in p.rr: + rname = U(r.rname).lower() + if rname == self.hn and r.ttl > 60: + found = True + break + + if not found: + self.q[cip] = (0, srv, srv.bp_ip) + return + + deadline = now + (0.5 if p.header.tc else 0.02) # rfc-7.2 + + # and service queries + for r in p.questions: + qname = U(r.qname).lower() + if qname in self.svcs or qname == "_services._dns-sd._udp.local.": + self.q[cip] = (deadline, srv, srv.bp_svc) + break + # heed rfc-7.1 if there was an announce in the past 12sec + # (workaround gvfs race-condition where it occasionally + # doesn't read/decode the full response...) + if now < srv.last_tx + 12: + for r in p.rr: + rdata = U(r.rdata).lower() + if rdata in self.sfqdns: + if r.ttl > 2250: + self.q.pop(cip, None) + break + + def process(self) -> None: + tx = set() + now = time.time() + cooldown = 0.9 # rfc-6: 1 + if self.unsolicited and self.unsolicited[0] < now: + self.unsolicited.pop(0) + cooldown = 0.1 + for srv in self.srv.values(): + tx.add(srv) + + for srv, deadline in list(self.defend.items()): + if now < deadline: + continue + + if self._tx(srv, srv.bp_ip, 0.02): # rfc-6: 0.25 + self.defend.pop(srv) + + for cip, (deadline, srv, msg) in list(self.q.items()): + if now < deadline: + continue + + self.q.pop(cip) + self._tx(srv, msg, cooldown) + + for srv in tx: + self._tx(srv, srv.bp_svc, cooldown) + + def _tx(self, srv: MDNS_Sck, msg: bytes, cooldown: float) -> bool: + now = time.time() + if now < srv.last_tx + cooldown: + return False + + srv.sck.sendto(msg, (srv.grp, 5353)) + srv.last_tx = now + return True diff --git a/copyparty/multicast.py b/copyparty/multicast.py new file mode 100644 index 00000000..b63b8974 --- /dev/null +++ b/copyparty/multicast.py @@ -0,0 +1,252 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import socket +import time +import ipaddress +from ipaddress import IPv4Network, IPv6Network, IPv4Address, IPv6Address + +from .__init__ import TYPE_CHECKING +from .util import min_ex, spack + +if TYPE_CHECKING: + from .svchub import SvcHub + +if True: # pylint: disable=using-constant-test + from typing import Optional, Union + +if not hasattr(socket, "IPPROTO_IPV6"): + setattr(socket, "IPPROTO_IPV6", 41) + + +class MC_Sck(object): + """there is one socket for each server ip""" + + def __init__( + self, + sck: socket.socket, + idx: int, + grp: str, + ip: str, + net: Union[IPv4Network, IPv6Network], + ): + self.sck = sck + self.idx = idx + self.grp = grp + self.mreq = b"" + self.ip = ip + self.net = net + self.ips = {ip: net} + self.v6 = ":" in ip + + +class MCast(object): + def __init__( + self, hub: "SvcHub", Srv: type[MC_Sck], mc_grp_4: str, mc_grp_6: str, port: int + ) -> None: + """disable ipv%d by setting mc_grp_%d empty""" + self.hub = hub + self.Srv = Srv + self.args = hub.args + self.asrv = hub.asrv + self.log_func = hub.log + self.grp4 = mc_grp_4 + self.grp6 = mc_grp_6 + self.port = port + + self.srv: dict[socket.socket, MC_Sck] = {} # listening sockets + self.sips: set[str] = set() # all listening ips + self.b2srv: dict[bytes, MC_Sck] = {} # binary-ip -> server socket + self.b4: list[bytes] = [] # sorted list of binary-ips + self.b6: list[bytes] = [] # sorted list of binary-ips + self.cscache: dict[str, Optional[MC_Sck]] = {} # client ip -> server cache + + def log(self, msg: str, c: Union[int, str] = 0) -> None: + self.log_func("multicast", msg, c) + + def create_servers(self) -> list[str]: + bound: list[str] = [] + ips = [x[0] for x in self.hub.tcpsrv.bound] + ips = list(set(ips)) + + if "::" in ips: + ips = [x for x in ips if x != "::"] + list( + [x.split("/")[0] for x in self.hub.tcpsrv.netdevs if ":" in x] + ) + ips.append("0.0.0.0") + + if "0.0.0.0" in ips: + ips = [x for x in ips if x != "0.0.0.0"] + list( + [x.split("/")[0] for x in self.hub.tcpsrv.netdevs if ":" not in x] + ) + + ips = [x for x in ips if x not in ("::1", "127.0.0.1")] + + ips = [ + [x for x in self.hub.tcpsrv.netdevs if x.startswith(y + "/")][0] + for y in ips + ] + + if not self.grp4: + ips = [x for x in ips if ":" in x] + + if not self.grp6: + ips = [x for x in ips if ":" not in x] + + if not ips: + raise Exception("no server IP matches the mdns config") + + for ip in ips: + v6 = ":" in ip + netdev = "?" + try: + netdev = self.hub.tcpsrv.netdevs[ip].split(",")[0] + idx = socket.if_nametoindex(netdev) + except: + idx = socket.INADDR_ANY + t = "using INADDR_ANY for ip [{}], netdev [{}]" + if not self.srv and ip not in ["::", "0.0.0.0"]: + self.log(t.format(ip, netdev), 3) + + ipv = socket.AF_INET6 if v6 else socket.AF_INET + sck = socket.socket(ipv, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + sck.settimeout(None) + sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except: + pass + + net = ipaddress.ip_network(ip, False) + ip = ip.split("/")[0] + srv = self.Srv(sck, idx, self.grp6 if ":" in ip else self.grp4, ip, net) + + try: + self.setup_socket(srv) + self.srv[sck] = srv + bound.append(ip) + except: + self.log("announce failed on [{}]:\n{}".format(ip, min_ex())) + + if self.args.zm_msub: + for s1 in self.srv.values(): + for s2 in self.srv.values(): + if s1.idx != s2.idx: + continue + + if s1.ip not in s2.ips: + s2.ips[s1.ip] = s1.net + + if self.args.zm_mnic: + for s1 in self.srv.values(): + for s2 in self.srv.values(): + for ip1, net1 in list(s1.ips.items()): + for ip2, net2 in list(s2.ips.items()): + if net1 == net2 and ip1 != ip2: + s1.ips[ip2] = net2 + + self.sips = set([x.ip for x in self.srv.values()]) + return bound + + def setup_socket(self, srv: MC_Sck) -> None: + sck = srv.sck + if srv.v6: + if self.args.zmv: + self.log("v6({}) idx({})".format(srv.ip, srv.idx), 6) + + bip = socket.inet_pton(socket.AF_INET6, srv.ip) + self.b2srv[bip] = srv + self.b6.append(bip) + + sck.bind((self.grp6 if srv.idx else "", self.port, 0, srv.idx)) + bgrp = socket.inet_pton(socket.AF_INET6, self.grp6) + dev = spack(b"@I", srv.idx) + srv.mreq = bgrp + dev + if srv.idx != socket.INADDR_ANY: + sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev) + + self.hop(srv) + try: + sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1) + sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) + except: + pass # macos + else: + if self.args.zmv: + self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6) + + bip = socket.inet_aton(srv.ip) + self.b2srv[bip] = srv + self.b4.append(bip) + + sck.bind((self.grp4 if srv.idx else "", self.port)) + bgrp = socket.inet_aton(self.grp4) + dev = ( + spack(b"=I", socket.INADDR_ANY) + if srv.idx == socket.INADDR_ANY + else socket.inet_aton(srv.ip) + ) + srv.mreq = bgrp + dev + if srv.idx != socket.INADDR_ANY: + sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev) + + self.hop(srv) + try: + sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1) + sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255) + except: + pass + + self.b4.sort(reverse=True) + self.b6.sort(reverse=True) + + def hop(self, srv: MC_Sck) -> None: + """rejoin to keepalive on routers/switches without igmp-snooping""" + sck = srv.sck + req = srv.mreq + if ":" in srv.ip: + try: + sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_LEAVE_GROUP, req) + # linux does leaves/joins twice with 0.2~1.05s spacing + time.sleep(1.2) + except: + pass + + sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, req) + else: + try: + sck.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, req) + time.sleep(1.2) + except: + pass + + sck.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, req) + + def map_client(self, cip: str) -> Optional[MC_Sck]: + try: + return self.cscache[cip] + except: + pass + + ret: Optional[MC_Sck] = None + v6 = ":" in cip + ci = IPv6Address(cip) if v6 else IPv4Address(cip) + for x in self.b6 if v6 else self.b4: + srv = self.b2srv[x] + if any([x for x in srv.ips.values() if ci in x]): + ret = srv + break + + if not ret and cip in ("127.0.0.1", "::1"): + # just give it something + ret = list(self.srv.values())[0] + + if not ret: + t = "could not map client {} to known subnet; maybe forwarded from another network?" + self.log(t.format(cip), 3) + + self.cscache[cip] = ret + if len(self.cscache) > 9000: + self.cscache = {} + + return ret diff --git a/copyparty/stolen/dnslib/README.md b/copyparty/stolen/dnslib/README.md new file mode 100644 index 00000000..40f6cebf --- /dev/null +++ b/copyparty/stolen/dnslib/README.md @@ -0,0 +1,5 @@ +`dnslib` but heavily simplified/feature-stripped + +L: MIT +Copyright (c) 2010 - 2017 Paul Chakravarti +https://github.com/paulc/dnslib/ diff --git a/copyparty/stolen/dnslib/__init__.py b/copyparty/stolen/dnslib/__init__.py new file mode 100644 index 00000000..831a2be2 --- /dev/null +++ b/copyparty/stolen/dnslib/__init__.py @@ -0,0 +1,11 @@ +# coding: utf-8 + +""" +L: MIT +Copyright (c) 2010 - 2017 Paul Chakravarti +https://github.com/paulc/dnslib/tree/0.9.23 +""" + +from .dns import * + +version = "0.9.23" diff --git a/copyparty/stolen/dnslib/bimap.py b/copyparty/stolen/dnslib/bimap.py new file mode 100644 index 00000000..cf2b6c94 --- /dev/null +++ b/copyparty/stolen/dnslib/bimap.py @@ -0,0 +1,41 @@ +# coding: utf-8 + +import types + + +class BimapError(Exception): + pass + + +class Bimap(object): + def __init__(self, name, forward, error=AttributeError): + self.name = name + self.error = error + self.forward = forward.copy() + self.reverse = dict([(v, k) for (k, v) in list(forward.items())]) + + def get(self, k, default=None): + try: + return self.forward[k] + except KeyError: + return default or str(k) + + def __getitem__(self, k): + try: + return self.forward[k] + except KeyError: + if isinstance(self.error, types.FunctionType): + return self.error(self.name, k, True) + else: + raise self.error("%s: Invalid forward lookup: [%s]" % (self.name, k)) + + def __getattr__(self, k): + try: + if k == "__wrapped__": + raise AttributeError() + return self.reverse[k] + except KeyError: + if isinstance(self.error, types.FunctionType): + return self.error(self.name, k, False) + else: + raise self.error("%s: Invalid reverse lookup: [%s]" % (self.name, k)) diff --git a/copyparty/stolen/dnslib/bit.py b/copyparty/stolen/dnslib/bit.py new file mode 100644 index 00000000..b12af1a4 --- /dev/null +++ b/copyparty/stolen/dnslib/bit.py @@ -0,0 +1,15 @@ +# coding: utf-8 + +from __future__ import print_function + + +def get_bits(data, offset, bits=1): + mask = ((1 << bits) - 1) << offset + return (data & mask) >> offset + + +def set_bits(data, value, offset, bits=1): + mask = ((1 << bits) - 1) << offset + clear = 0xFFFF ^ mask + data = (data & clear) | ((value << offset) & mask) + return data diff --git a/copyparty/stolen/dnslib/buffer.py b/copyparty/stolen/dnslib/buffer.py new file mode 100644 index 00000000..1fcda2c0 --- /dev/null +++ b/copyparty/stolen/dnslib/buffer.py @@ -0,0 +1,56 @@ +# coding: utf-8 + +import binascii +import struct + + +class BufferError(Exception): + pass + + +class Buffer(object): + def __init__(self, data=b""): + self.data = bytearray(data) + self.offset = 0 + + def remaining(self): + return len(self.data) - self.offset + + def get(self, length): + if length > self.remaining(): + raise BufferError( + "Not enough bytes [offset=%d,remaining=%d,requested=%d]" + % (self.offset, self.remaining(), length) + ) + start = self.offset + end = self.offset + length + self.offset += length + return bytes(self.data[start:end]) + + def hex(self): + return binascii.hexlify(self.data) + + def pack(self, fmt, *args): + self.offset += struct.calcsize(fmt) + self.data += struct.pack(fmt, *args) + + def append(self, s): + self.offset += len(s) + self.data += s + + def update(self, ptr, fmt, *args): + s = struct.pack(fmt, *args) + self.data[ptr : ptr + len(s)] = s + + def unpack(self, fmt): + try: + data = self.get(struct.calcsize(fmt)) + return struct.unpack(fmt, data) + except struct.error: + raise BufferError( + "Error unpacking struct '%s' <%s>" + % (fmt, binascii.hexlify(data).decode()) + ) + + def __len__(self): + return len(self.data) diff --git a/copyparty/stolen/dnslib/dns.py b/copyparty/stolen/dnslib/dns.py new file mode 100644 index 00000000..66edd1c8 --- /dev/null +++ b/copyparty/stolen/dnslib/dns.py @@ -0,0 +1,781 @@ +# coding: utf-8 + +from __future__ import print_function + +import binascii +import random + +from itertools import chain + +from .bit import get_bits, set_bits +from .bimap import Bimap, BimapError +from .buffer import BufferError +from .label import DNSLabel, DNSBuffer +from .ranges import H, I, IP4, IP6, check_bytes + + +class DNSError(Exception): + pass + + +def unknown_qtype(name, key, forward): + if forward: + try: + return "TYPE%d" % (key,) + except: + raise DNSError("%s: Invalid forward lookup: [%s]" % (name, key)) + else: + if key.startswith("TYPE"): + try: + return int(key[4:]) + except: + pass + raise DNSError("%s: Invalid reverse lookup: [%s]" % (name, key)) + + +QTYPE = Bimap( + "QTYPE", + {1: "A", 12: "PTR", 16: "TXT", 28: "AAAA", 33: "SRV", 47: "NSEC", 255: "ANY"}, + unknown_qtype, +) + +CLASS = Bimap("CLASS", {1: "IN", 254: "None", 255: "*"}, DNSError) + +QR = Bimap("QR", {0: "QUERY", 1: "RESPONSE"}, DNSError) + +RCODE = Bimap( + "RCODE", + { + 0: "NOERROR", + 1: "FORMERR", + 2: "SERVFAIL", + 3: "NXDOMAIN", + 4: "NOTIMP", + 5: "REFUSED", + 6: "YXDOMAIN", + 7: "YXRRSET", + 8: "NXRRSET", + 9: "NOTAUTH", + 10: "NOTZONE", + }, + DNSError, +) + +OPCODE = Bimap( + "OPCODE", {0: "QUERY", 1: "IQUERY", 2: "STATUS", 4: "NOTIFY", 5: "UPDATE"}, DNSError +) + + +def label(label, origin=None): + if label.endswith("."): + return DNSLabel(label) + else: + return (origin if isinstance(origin, DNSLabel) else DNSLabel(origin)).add(label) + + +class DNSRecord(object): + @classmethod + def parse(cls, packet): + buffer = DNSBuffer(packet) + try: + header = DNSHeader.parse(buffer) + questions = [] + rr = [] + auth = [] + ar = [] + for i in range(header.q): + questions.append(DNSQuestion.parse(buffer)) + for i in range(header.a): + rr.append(RR.parse(buffer)) + for i in range(header.auth): + auth.append(RR.parse(buffer)) + for i in range(header.ar): + ar.append(RR.parse(buffer)) + return cls(header, questions, rr, auth=auth, ar=ar) + except (BufferError, BimapError) as e: + raise DNSError( + "Error unpacking DNSRecord [offset=%d]: %s" % (buffer.offset, e) + ) + + @classmethod + def question(cls, qname, qtype="A", qclass="IN"): + return DNSRecord( + q=DNSQuestion(qname, getattr(QTYPE, qtype), getattr(CLASS, qclass)) + ) + + def __init__( + self, header=None, questions=None, rr=None, q=None, a=None, auth=None, ar=None + ): + self.header = header or DNSHeader() + self.questions = questions or [] + self.rr = rr or [] + self.auth = auth or [] + self.ar = ar or [] + + if q: + self.questions.append(q) + if a: + self.rr.append(a) + self.set_header_qa() + + def reply(self, ra=1, aa=1): + return DNSRecord( + DNSHeader(id=self.header.id, bitmap=self.header.bitmap, qr=1, ra=ra, aa=aa), + q=self.q, + ) + + def add_question(self, *q): + self.questions.extend(q) + self.set_header_qa() + + def add_answer(self, *rr): + self.rr.extend(rr) + self.set_header_qa() + + def add_auth(self, *auth): + self.auth.extend(auth) + self.set_header_qa() + + def add_ar(self, *ar): + self.ar.extend(ar) + self.set_header_qa() + + def set_header_qa(self): + self.header.q = len(self.questions) + self.header.a = len(self.rr) + self.header.auth = len(self.auth) + self.header.ar = len(self.ar) + + def get_q(self): + return self.questions[0] if self.questions else DNSQuestion() + + q = property(get_q) + + def get_a(self): + return self.rr[0] if self.rr else RR() + + a = property(get_a) + + def pack(self): + self.set_header_qa() + buffer = DNSBuffer() + self.header.pack(buffer) + for q in self.questions: + q.pack(buffer) + for rr in self.rr: + rr.pack(buffer) + for auth in self.auth: + auth.pack(buffer) + for ar in self.ar: + ar.pack(buffer) + return buffer.data + + def truncate(self): + return DNSRecord(DNSHeader(id=self.header.id, bitmap=self.header.bitmap, tc=1)) + + def format(self, prefix="", sort=False): + s = sorted if sort else lambda x: x + sections = [repr(self.header)] + sections.extend(s([repr(q) for q in self.questions])) + sections.extend(s([repr(rr) for rr in self.rr])) + sections.extend(s([repr(rr) for rr in self.auth])) + sections.extend(s([repr(rr) for rr in self.ar])) + return prefix + ("\n" + prefix).join(sections) + + short = format + + def __repr__(self): + return self.format() + + __str__ = __repr__ + + +class DNSHeader(object): + id = H("id") + bitmap = H("bitmap") + q = H("q") + a = H("a") + auth = H("auth") + ar = H("ar") + + @classmethod + def parse(cls, buffer): + try: + (id, bitmap, q, a, auth, ar) = buffer.unpack("!HHHHHH") + return cls(id, bitmap, q, a, auth, ar) + except (BufferError, BimapError) as e: + raise DNSError( + "Error unpacking DNSHeader [offset=%d]: %s" % (buffer.offset, e) + ) + + def __init__(self, id=None, bitmap=None, q=0, a=0, auth=0, ar=0, **args): + if id is None: + self.id = random.randint(0, 65535) + else: + self.id = id + if bitmap is None: + self.bitmap = 0 + self.rd = 1 + else: + self.bitmap = bitmap + self.q = q + self.a = a + self.auth = auth + self.ar = ar + for k, v in args.items(): + if k.lower() == "qr": + self.qr = v + elif k.lower() == "opcode": + self.opcode = v + elif k.lower() == "aa": + self.aa = v + elif k.lower() == "tc": + self.tc = v + elif k.lower() == "rd": + self.rd = v + elif k.lower() == "ra": + self.ra = v + elif k.lower() == "z": + self.z = v + elif k.lower() == "ad": + self.ad = v + elif k.lower() == "cd": + self.cd = v + elif k.lower() == "rcode": + self.rcode = v + + def get_qr(self): + return get_bits(self.bitmap, 15) + + def set_qr(self, val): + self.bitmap = set_bits(self.bitmap, val, 15) + + qr = property(get_qr, set_qr) + + def get_opcode(self): + return get_bits(self.bitmap, 11, 4) + + def set_opcode(self, val): + self.bitmap = set_bits(self.bitmap, val, 11, 4) + + opcode = property(get_opcode, set_opcode) + + def get_aa(self): + return get_bits(self.bitmap, 10) + + def set_aa(self, val): + self.bitmap = set_bits(self.bitmap, val, 10) + + aa = property(get_aa, set_aa) + + def get_tc(self): + return get_bits(self.bitmap, 9) + + def set_tc(self, val): + self.bitmap = set_bits(self.bitmap, val, 9) + + tc = property(get_tc, set_tc) + + def get_rd(self): + return get_bits(self.bitmap, 8) + + def set_rd(self, val): + self.bitmap = set_bits(self.bitmap, val, 8) + + rd = property(get_rd, set_rd) + + def get_ra(self): + return get_bits(self.bitmap, 7) + + def set_ra(self, val): + self.bitmap = set_bits(self.bitmap, val, 7) + + ra = property(get_ra, set_ra) + + def get_z(self): + return get_bits(self.bitmap, 6) + + def set_z(self, val): + self.bitmap = set_bits(self.bitmap, val, 6) + + z = property(get_z, set_z) + + def get_ad(self): + return get_bits(self.bitmap, 5) + + def set_ad(self, val): + self.bitmap = set_bits(self.bitmap, val, 5) + + ad = property(get_ad, set_ad) + + def get_cd(self): + return get_bits(self.bitmap, 4) + + def set_cd(self, val): + self.bitmap = set_bits(self.bitmap, val, 4) + + cd = property(get_cd, set_cd) + + def get_rcode(self): + return get_bits(self.bitmap, 0, 4) + + def set_rcode(self, val): + self.bitmap = set_bits(self.bitmap, val, 0, 4) + + rcode = property(get_rcode, set_rcode) + + def pack(self, buffer): + buffer.pack("!HHHHHH", self.id, self.bitmap, self.q, self.a, self.auth, self.ar) + + def __repr__(self): + f = [ + self.aa and "AA", + self.tc and "TC", + self.rd and "RD", + self.ra and "RA", + self.z and "Z", + self.ad and "AD", + self.cd and "CD", + ] + if OPCODE.get(self.opcode) == "UPDATE": + f1 = "zo" + f2 = "pr" + f3 = "up" + f4 = "ad" + else: + f1 = "q" + f2 = "a" + f3 = "ns" + f4 = "ar" + return ( + "" + % ( + self.id, + QR.get(self.qr), + OPCODE.get(self.opcode), + ",".join(filter(None, f)), + RCODE.get(self.rcode), + f1, + self.q, + f2, + self.a, + f3, + self.auth, + f4, + self.ar, + ) + ) + + __str__ = __repr__ + + +class DNSQuestion(object): + @classmethod + def parse(cls, buffer): + try: + qname = buffer.decode_name() + qtype, qclass = buffer.unpack("!HH") + return cls(qname, qtype, qclass) + except (BufferError, BimapError) as e: + raise DNSError( + "Error unpacking DNSQuestion [offset=%d]: %s" % (buffer.offset, e) + ) + + def __init__(self, qname=None, qtype=1, qclass=1): + self.qname = qname + self.qtype = qtype + self.qclass = qclass + + def set_qname(self, qname): + if isinstance(qname, DNSLabel): + self._qname = qname + else: + self._qname = DNSLabel(qname) + + def get_qname(self): + return self._qname + + qname = property(get_qname, set_qname) + + def pack(self, buffer): + buffer.encode_name(self.qname) + buffer.pack("!HH", self.qtype, self.qclass) + + def __repr__(self): + return "" % ( + self.qname, + QTYPE.get(self.qtype), + CLASS.get(self.qclass), + ) + + __str__ = __repr__ + + +class RR(object): + rtype = H("rtype") + rclass = H("rclass") + ttl = I("ttl") + rdlength = H("rdlength") + + @classmethod + def parse(cls, buffer): + try: + rname = buffer.decode_name() + rtype, rclass, ttl, rdlength = buffer.unpack("!HHIH") + if rdlength: + rdata = RDMAP.get(QTYPE.get(rtype), RD).parse(buffer, rdlength) + else: + rdata = "" + return cls(rname, rtype, rclass, ttl, rdata) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking RR [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, rname=None, rtype=1, rclass=1, ttl=0, rdata=None): + self.rname = rname + self.rtype = rtype + self.rclass = rclass + self.ttl = ttl + self.rdata = rdata + + def set_rname(self, rname): + if isinstance(rname, DNSLabel): + self._rname = rname + else: + self._rname = DNSLabel(rname) + + def get_rname(self): + return self._rname + + rname = property(get_rname, set_rname) + + def pack(self, buffer): + buffer.encode_name(self.rname) + buffer.pack("!HHI", self.rtype, self.rclass, self.ttl) + rdlength_ptr = buffer.offset + buffer.pack("!H", 0) + start = buffer.offset + self.rdata.pack(buffer) + end = buffer.offset + buffer.update(rdlength_ptr, "!H", end - start) + + def __repr__(self): + return "" % ( + self.rname, + QTYPE.get(self.rtype), + CLASS.get(self.rclass), + self.ttl, + self.rdata, + ) + + __str__ = __repr__ + + +class RD(object): + @classmethod + def parse(cls, buffer, length): + try: + data = buffer.get(length) + return cls(data) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking RD [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, data=b""): + check_bytes("data", data) + self.data = bytes(data) + + def pack(self, buffer): + buffer.append(self.data) + + def __repr__(self): + if len(self.data) > 0: + return "\\# %d %s" % ( + len(self.data), + binascii.hexlify(self.data).decode().upper(), + ) + else: + return "\\# 0" + + attrs = ("data",) + + +def _force_bytes(x): + if isinstance(x, bytes): + return x + else: + return x.encode() + + +class TXT(RD): + @classmethod + def parse(cls, buffer, length): + try: + data = list() + start_bo = buffer.offset + now_length = 0 + while buffer.offset < start_bo + length: + (txtlength,) = buffer.unpack("!B") + + if now_length + txtlength < length: + now_length += txtlength + data.append(buffer.get(txtlength)) + else: + raise DNSError( + "Invalid TXT record: len(%d) > RD len(%d)" % (txtlength, length) + ) + return cls(data) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking TXT [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, data): + if type(data) in (tuple, list): + self.data = [_force_bytes(x) for x in data] + else: + self.data = [_force_bytes(data)] + if any([len(x) > 255 for x in self.data]): + raise DNSError("TXT record too long: %s" % self.data) + + def pack(self, buffer): + for ditem in self.data: + if len(ditem) > 255: + raise DNSError("TXT record too long: %s" % ditem) + buffer.pack("!B", len(ditem)) + buffer.append(ditem) + + def __repr__(self): + return ",".join([repr(x) for x in self.data]) + + +class A(RD): + + data = IP4("data") + + @classmethod + def parse(cls, buffer, length): + try: + data = buffer.unpack("!BBBB") + return cls(data) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking A [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, data): + if type(data) in (tuple, list): + self.data = tuple(data) + else: + self.data = tuple(map(int, data.rstrip(".").split("."))) + + def pack(self, buffer): + buffer.pack("!BBBB", *self.data) + + def __repr__(self): + return "%d.%d.%d.%d" % self.data + + +def _parse_ipv6(a): + l, _, r = a.partition("::") + l_groups = list(chain(*[divmod(int(x, 16), 256) for x in l.split(":") if x])) + r_groups = list(chain(*[divmod(int(x, 16), 256) for x in r.split(":") if x])) + zeros = [0] * (16 - len(l_groups) - len(r_groups)) + return tuple(l_groups + zeros + r_groups) + + +def _format_ipv6(a): + left = [] + right = [] + current = "left" + for i in range(0, 16, 2): + group = (a[i] << 8) + a[i + 1] + if current == "left": + if group == 0 and i < 14: + if (a[i + 2] << 8) + a[i + 3] == 0: + current = "right" + else: + left.append("0") + else: + left.append("%x" % group) + else: + if group == 0 and len(right) == 0: + pass + else: + right.append("%x" % group) + if len(left) < 8: + return ":".join(left) + "::" + ":".join(right) + else: + return ":".join(left) + + +class AAAA(RD): + data = IP6("data") + + @classmethod + def parse(cls, buffer, length): + try: + data = buffer.unpack("!16B") + return cls(data) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking AAAA [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, data): + if type(data) in (tuple, list): + self.data = tuple(data) + else: + self.data = _parse_ipv6(data) + + def pack(self, buffer): + buffer.pack("!16B", *self.data) + + def __repr__(self): + return _format_ipv6(self.data) + + +class CNAME(RD): + @classmethod + def parse(cls, buffer, length): + try: + label = buffer.decode_name() + return cls(label) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking CNAME [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, label=None): + self.label = label + + def set_label(self, label): + if isinstance(label, DNSLabel): + self._label = label + else: + self._label = DNSLabel(label) + + def get_label(self): + return self._label + + label = property(get_label, set_label) + + def pack(self, buffer): + buffer.encode_name(self.label) + + def __repr__(self): + return "%s" % (self.label) + + attrs = ("label",) + + +class PTR(CNAME): + pass + + +class SRV(RD): + priority = H("priority") + weight = H("weight") + port = H("port") + + @classmethod + def parse(cls, buffer, length): + try: + priority, weight, port = buffer.unpack("!HHH") + target = buffer.decode_name() + return cls(priority, weight, port, target) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking SRV [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, priority=0, weight=0, port=0, target=None): + self.priority = priority + self.weight = weight + self.port = port + self.target = target + + def set_target(self, target): + if isinstance(target, DNSLabel): + self._target = target + else: + self._target = DNSLabel(target) + + def get_target(self): + return self._target + + target = property(get_target, set_target) + + def pack(self, buffer): + buffer.pack("!HHH", self.priority, self.weight, self.port) + buffer.encode_name(self.target) + + def __repr__(self): + return "%d %d %d %s" % (self.priority, self.weight, self.port, self.target) + + attrs = ("priority", "weight", "port", "target") + + +def decode_type_bitmap(type_bitmap): + rrlist = [] + buf = DNSBuffer(type_bitmap) + while buf.remaining(): + winnum, winlen = buf.unpack("BB") + bitmap = bytearray(buf.get(winlen)) + for (pos, value) in enumerate(bitmap): + for i in range(8): + if (value << i) & 0x80: + bitpos = (256 * winnum) + (8 * pos) + i + rrlist.append(QTYPE[bitpos]) + return rrlist + + +def encode_type_bitmap(rrlist): + rrlist = sorted([getattr(QTYPE, rr) for rr in rrlist]) + buf = DNSBuffer() + curWindow = rrlist[0] // 256 + bitmap = bytearray(32) + n = len(rrlist) - 1 + for i, rr in enumerate(rrlist): + v = rr - curWindow * 256 + bitmap[v // 8] |= 1 << (7 - v % 8) + + if i == n or rrlist[i + 1] >= (curWindow + 1) * 256: + while bitmap[-1] == 0: + bitmap = bitmap[:-1] + buf.pack("BB", curWindow, len(bitmap)) + buf.append(bitmap) + + if i != n: + curWindow = rrlist[i + 1] // 256 + bitmap = bytearray(32) + + return buf.data + + +class NSEC(RD): + @classmethod + def parse(cls, buffer, length): + try: + end = buffer.offset + length + name = buffer.decode_name() + rrlist = decode_type_bitmap(buffer.get(end - buffer.offset)) + return cls(name, rrlist) + except (BufferError, BimapError) as e: + raise DNSError("Error unpacking NSEC [offset=%d]: %s" % (buffer.offset, e)) + + def __init__(self, label, rrlist): + self.label = label + self.rrlist = rrlist + + def set_label(self, label): + if isinstance(label, DNSLabel): + self._label = label + else: + self._label = DNSLabel(label) + + def get_label(self): + return self._label + + label = property(get_label, set_label) + + def pack(self, buffer): + buffer.encode_name_nocompress(self.label) + buffer.append(encode_type_bitmap(self.rrlist)) + + def __repr__(self): + return "%s %s" % (self.label, " ".join(self.rrlist)) + + attrs = ("label", "rrlist") + + +RDMAP = {"A": A, "AAAA": AAAA, "TXT": TXT, "PTR": PTR, "SRV": SRV, "NSEC": NSEC} diff --git a/copyparty/stolen/dnslib/label.py b/copyparty/stolen/dnslib/label.py new file mode 100644 index 00000000..dc89e56a --- /dev/null +++ b/copyparty/stolen/dnslib/label.py @@ -0,0 +1,154 @@ +# coding: utf-8 + +from __future__ import print_function + +import fnmatch, re + +from .bit import get_bits, set_bits +from .buffer import Buffer, BufferError + +LDH = set(range(33, 127)) +ESCAPE = re.compile(r"\\([0-9][0-9][0-9])") + + +class DNSLabelError(Exception): + pass + + +class DNSLabel(object): + def __init__(self, label): + if type(label) == DNSLabel: + self.label = label.label + elif type(label) in (list, tuple): + self.label = tuple(label) + else: + if not label or label in (b".", "."): + self.label = () + elif type(label) is not bytes: + if type("") != type(b""): + + label = ESCAPE.sub(lambda m: chr(int(m[1])), label) + self.label = tuple(label.encode("idna").rstrip(b".").split(b".")) + else: + if type("") == type(b""): + + label = ESCAPE.sub(lambda m: chr(int(m.groups()[0])), label) + self.label = tuple(label.rstrip(b".").split(b".")) + + def add(self, name): + new = DNSLabel(name) + if self.label: + new.label += self.label + return new + + def idna(self): + return ".".join([s.decode("idna") for s in self.label]) + "." + + def _decode(self, s): + if set(s).issubset(LDH): + + return s.decode() + else: + + return "".join([(chr(c) if (c in LDH) else "\\%03d" % c) for c in s]) + + def __str__(self): + return ".".join([self._decode(bytearray(s)) for s in self.label]) + "." + + def __repr__(self): + return "" % str(self) + + def __hash__(self): + return hash(tuple(map(lambda x: x.lower(), self.label))) + + def __ne__(self, other): + return not self == other + + def __eq__(self, other): + if type(other) != DNSLabel: + return self.__eq__(DNSLabel(other)) + else: + return [l.lower() for l in self.label] == [l.lower() for l in other.label] + + def __len__(self): + return len(b".".join(self.label)) + + +class DNSBuffer(Buffer): + def __init__(self, data=b""): + super(DNSBuffer, self).__init__(data) + self.names = {} + + def decode_name(self, last=-1): + label = [] + done = False + while not done: + (length,) = self.unpack("!B") + if get_bits(length, 6, 2) == 3: + + self.offset -= 1 + pointer = get_bits(self.unpack("!H")[0], 0, 14) + save = self.offset + if last == save: + raise BufferError( + "Recursive pointer in DNSLabel [offset=%d,pointer=%d,length=%d]" + % (self.offset, pointer, len(self.data)) + ) + if pointer < self.offset: + self.offset = pointer + else: + + raise BufferError( + "Invalid pointer in DNSLabel [offset=%d,pointer=%d,length=%d]" + % (self.offset, pointer, len(self.data)) + ) + label.extend(self.decode_name(save).label) + self.offset = save + done = True + else: + if length > 0: + l = self.get(length) + try: + l.decode() + except UnicodeDecodeError: + raise BufferError("Invalid label <%s>" % l) + label.append(l) + else: + done = True + return DNSLabel(label) + + def encode_name(self, name): + if not isinstance(name, DNSLabel): + name = DNSLabel(name) + if len(name) > 253: + raise DNSLabelError("Domain label too long: %r" % name) + name = list(name.label) + while name: + if tuple(name) in self.names: + + pointer = self.names[tuple(name)] + pointer = set_bits(pointer, 3, 14, 2) + self.pack("!H", pointer) + return + else: + self.names[tuple(name)] = self.offset + element = name.pop(0) + if len(element) > 63: + raise DNSLabelError("Label component too long: %r" % element) + self.pack("!B", len(element)) + self.append(element) + self.append(b"\x00") + + def encode_name_nocompress(self, name): + if not isinstance(name, DNSLabel): + name = DNSLabel(name) + if len(name) > 253: + raise DNSLabelError("Domain label too long: %r" % name) + name = list(name.label) + while name: + element = name.pop(0) + if len(element) > 63: + raise DNSLabelError("Label component too long: %r" % element) + self.pack("!B", len(element)) + self.append(element) + self.append(b"\x00") diff --git a/copyparty/stolen/dnslib/lex.py b/copyparty/stolen/dnslib/lex.py new file mode 100644 index 00000000..9517bffa --- /dev/null +++ b/copyparty/stolen/dnslib/lex.py @@ -0,0 +1,105 @@ +# coding: utf-8 + +from __future__ import print_function + +import collections + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO + + +class Lexer(object): + + escape_chars = "\\" + escape = {"n": "\n", "t": "\t", "r": "\r"} + + def __init__(self, f, debug=False): + if hasattr(f, "read"): + self.f = f + elif type(f) == str: + self.f = StringIO(f) + elif type(f) == bytes: + self.f = StringIO(f.decode()) + else: + raise ValueError("Invalid input") + self.debug = debug + self.q = collections.deque() + self.state = self.lexStart + self.escaped = False + self.eof = False + + def __iter__(self): + return self.parse() + + def next_token(self): + if self.debug: + print("STATE", self.state) + (tok, self.state) = self.state() + return tok + + def parse(self): + while self.state is not None and not self.eof: + tok = self.next_token() + if tok: + yield tok + + def read(self, n=1): + s = "" + while self.q and n > 0: + s += self.q.popleft() + n -= 1 + s += self.f.read(n) + if s == "": + self.eof = True + if self.debug: + print("Read: >%s<" % repr(s)) + return s + + def peek(self, n=1): + s = "" + i = 0 + while len(self.q) > i and n > 0: + s += self.q[i] + i += 1 + n -= 1 + r = self.f.read(n) + if n > 0 and r == "": + self.eof = True + self.q.extend(r) + if self.debug: + print("Peek : >%s<" % repr(s + r)) + return s + r + + def pushback(self, s): + p = collections.deque(s) + p.extend(self.q) + self.q = p + + def readescaped(self): + c = self.read(1) + if c in self.escape_chars: + self.escaped = True + n = self.peek(3) + if n.isdigit(): + n = self.read(3) + if self.debug: + print("Escape: >%s<" % n) + return chr(int(n, 8)) + elif n[0] in "x": + x = self.read(3) + if self.debug: + print("Escape: >%s<" % x) + return chr(int(x[1:], 16)) + else: + c = self.read(1) + if self.debug: + print("Escape: >%s<" % c) + return self.escape.get(c, c) + else: + self.escaped = False + return c + + def lexStart(self): + return (None, None) diff --git a/copyparty/stolen/dnslib/ranges.py b/copyparty/stolen/dnslib/ranges.py new file mode 100644 index 00000000..4cf96f5c --- /dev/null +++ b/copyparty/stolen/dnslib/ranges.py @@ -0,0 +1,81 @@ +# coding: utf-8 + +import sys + +if sys.version < "3": + int_types = ( + int, + long, + ) + byte_types = (str, bytearray) +else: + int_types = (int,) + byte_types = (bytes, bytearray) + + +def check_instance(name, val, types): + if not isinstance(val, types): + raise ValueError( + "Attribute '%s' must be instance of %s [%s]" % (name, types, type(val)) + ) + + +def check_bytes(name, val): + return check_instance(name, val, byte_types) + + +def range_property(attr, min, max): + def getter(obj): + return getattr(obj, "_%s" % attr) + + def setter(obj, val): + if isinstance(val, int_types) and min <= val <= max: + setattr(obj, "_%s" % attr, val) + else: + raise ValueError( + "Attribute '%s' must be between %d-%d [%s]" % (attr, min, max, val) + ) + + return property(getter, setter) + + +def B(attr): + return range_property(attr, 0, 255) + + +def H(attr): + return range_property(attr, 0, 65535) + + +def I(attr): + return range_property(attr, 0, 4294967295) + + +def ntuple_range(attr, n, min, max): + f = lambda x: isinstance(x, int_types) and min <= x <= max + + def getter(obj): + return getattr(obj, "_%s" % attr) + + def setter(obj, val): + if len(val) != n: + raise ValueError( + "Attribute '%s' must be tuple with %d elements [%s]" % (attr, n, val) + ) + if all(map(f, val)): + setattr(obj, "_%s" % attr, val) + else: + raise ValueError( + "Attribute '%s' elements must be between %d-%d [%s]" + % (attr, min, max, val) + ) + + return property(getter, setter) + + +def IP4(attr): + return ntuple_range(attr, 4, 0, 255) + + +def IP6(attr): + return ntuple_range(attr, 16, 0, 255) diff --git a/copyparty/stolen/ifaddr/README.md b/copyparty/stolen/ifaddr/README.md new file mode 100644 index 00000000..2a25e632 --- /dev/null +++ b/copyparty/stolen/ifaddr/README.md @@ -0,0 +1,5 @@ +`ifaddr` with py2.7 support enabled by make-sfx.sh which strips py3 hints using strip_hints and removes the `^if True:` blocks + +L: BSD-2-Clause +Copyright (c) 2014 Stefan C. Mueller +https://github.com/pydron/ifaddr/ diff --git a/copyparty/stolen/ifaddr/__init__.py b/copyparty/stolen/ifaddr/__init__.py new file mode 100644 index 00000000..1c7ae7af --- /dev/null +++ b/copyparty/stolen/ifaddr/__init__.py @@ -0,0 +1,21 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +""" +L: BSD-2-Clause +Copyright (c) 2014 Stefan C. Mueller +https://github.com/pydron/ifaddr/tree/0.2.0 +""" + +import os + +from ._shared import Adapter, IP + +if os.name == "nt": + from ._win32 import get_adapters +elif os.name == "posix": + from ._posix import get_adapters +else: + raise RuntimeError("Unsupported Operating System: %s" % os.name) + +__all__ = ["Adapter", "IP", "get_adapters"] diff --git a/copyparty/stolen/ifaddr/_posix.py b/copyparty/stolen/ifaddr/_posix.py new file mode 100644 index 00000000..17e2966f --- /dev/null +++ b/copyparty/stolen/ifaddr/_posix.py @@ -0,0 +1,83 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import os +import ctypes.util +import ipaddress +import collections +import socket + +if True: # pylint: disable=using-constant-test + from typing import Iterable, Optional + +from . import _shared as shared +from ._shared import U + + +class ifaddrs(ctypes.Structure): + pass + + +ifaddrs._fields_ = [ + ("ifa_next", ctypes.POINTER(ifaddrs)), + ("ifa_name", ctypes.c_char_p), + ("ifa_flags", ctypes.c_uint), + ("ifa_addr", ctypes.POINTER(shared.sockaddr)), + ("ifa_netmask", ctypes.POINTER(shared.sockaddr)), +] + +libc = ctypes.CDLL(ctypes.util.find_library("socket" if os.uname()[0] == "SunOS" else "c"), use_errno=True) # type: ignore + + +def get_adapters(include_unconfigured: bool = False) -> Iterable[shared.Adapter]: + + addr0 = addr = ctypes.POINTER(ifaddrs)() + retval = libc.getifaddrs(ctypes.byref(addr)) + if retval != 0: + eno = ctypes.get_errno() + raise OSError(eno, os.strerror(eno)) + + ips = collections.OrderedDict() + + def add_ip(adapter_name: str, ip: Optional[shared.IP]) -> None: + if adapter_name not in ips: + index = None # type: Optional[int] + try: + # Mypy errors on this when the Windows CI runs: + # error: Module has no attribute "if_nametoindex" + index = socket.if_nametoindex(adapter_name) # type: ignore + except (OSError, AttributeError): + pass + ips[adapter_name] = shared.Adapter( + adapter_name, adapter_name, [], index=index + ) + if ip is not None: + ips[adapter_name].ips.append(ip) + + while addr: + name = addr[0].ifa_name.decode(encoding="UTF-8") + ip_addr = shared.sockaddr_to_ip(addr[0].ifa_addr) + if ip_addr: + if addr[0].ifa_netmask and not addr[0].ifa_netmask[0].sa_familiy: + addr[0].ifa_netmask[0].sa_familiy = addr[0].ifa_addr[0].sa_familiy + netmask = shared.sockaddr_to_ip(addr[0].ifa_netmask) + if isinstance(netmask, tuple): + netmaskStr = U(netmask[0]) + prefixlen = shared.ipv6_prefixlength(ipaddress.IPv6Address(netmaskStr)) + else: + if netmask is None: + t = "sockaddr_to_ip({}) returned None" + raise Exception(t.format(addr[0].ifa_netmask)) + + netmaskStr = U("0.0.0.0/" + netmask) + prefixlen = ipaddress.IPv4Network(netmaskStr).prefixlen + ip = shared.IP(ip_addr, prefixlen, name) + add_ip(name, ip) + else: + if include_unconfigured: + add_ip(name, None) + addr = addr[0].ifa_next + + libc.freeifaddrs(addr0) + + return ips.values() diff --git a/copyparty/stolen/ifaddr/_shared.py b/copyparty/stolen/ifaddr/_shared.py new file mode 100644 index 00000000..d5d02623 --- /dev/null +++ b/copyparty/stolen/ifaddr/_shared.py @@ -0,0 +1,202 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import sys +import ctypes +import socket +import ipaddress +import platform + +if True: # pylint: disable=using-constant-test + from typing import List, Optional, Tuple, Union, Callable + + +PY2 = sys.version_info < (3,) +if not PY2: + U: Callable[[str], str] = str +else: + U = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable + + +class Adapter(object): + """ + Represents a network interface device controller (NIC), such as a + network card. An adapter can have multiple IPs. + + On Linux aliasing (multiple IPs per physical NIC) is implemented + by creating 'virtual' adapters, each represented by an instance + of this class. Each of those 'virtual' adapters can have both + a IPv4 and an IPv6 IP address. + """ + + def __init__( + self, name: str, nice_name: str, ips: List["IP"], index: Optional[int] = None + ) -> None: + + #: Unique name that identifies the adapter in the system. + #: On Linux this is of the form of `eth0` or `eth0:1`, on + #: Windows it is a UUID in string representation, such as + #: `{846EE342-7039-11DE-9D20-806E6F6E6963}`. + self.name = name + + #: Human readable name of the adpater. On Linux this + #: is currently the same as :attr:`name`. On Windows + #: this is the name of the device. + self.nice_name = nice_name + + #: List of :class:`ifaddr.IP` instances in the order they were + #: reported by the system. + self.ips = ips + + #: Adapter index as used by some API (e.g. IPv6 multicast group join). + self.index = index + + def __repr__(self) -> str: + return "Adapter(name={name}, nice_name={nice_name}, ips={ips}, index={index})".format( + name=repr(self.name), + nice_name=repr(self.nice_name), + ips=repr(self.ips), + index=repr(self.index), + ) + + +if True: + # Type of an IPv4 address (a string in "xxx.xxx.xxx.xxx" format) + _IPv4Address = str + + # Type of an IPv6 address (a three-tuple `(ip, flowinfo, scope_id)`) + _IPv6Address = tuple[str, int, int] + + +class IP(object): + """ + Represents an IP address of an adapter. + """ + + def __init__( + self, ip: Union[_IPv4Address, _IPv6Address], network_prefix: int, nice_name: str + ) -> None: + + #: IP address. For IPv4 addresses this is a string in + #: "xxx.xxx.xxx.xxx" format. For IPv6 addresses this + #: is a three-tuple `(ip, flowinfo, scope_id)`, where + #: `ip` is a string in the usual collon separated + #: hex format. + self.ip = ip + + #: Number of bits of the IP that represent the + #: network. For a `255.255.255.0` netmask, this + #: number would be `24`. + self.network_prefix = network_prefix + + #: Human readable name for this IP. + #: On Linux is this currently the same as the adapter name. + #: On Windows this is the name of the network connection + #: as configured in the system control panel. + self.nice_name = nice_name + + @property + def is_IPv4(self) -> bool: + """ + Returns `True` if this IP is an IPv4 address and `False` + if it is an IPv6 address. + """ + return not isinstance(self.ip, tuple) + + @property + def is_IPv6(self) -> bool: + """ + Returns `True` if this IP is an IPv6 address and `False` + if it is an IPv4 address. + """ + return isinstance(self.ip, tuple) + + def __repr__(self) -> str: + return "IP(ip={ip}, network_prefix={network_prefix}, nice_name={nice_name})".format( + ip=repr(self.ip), + network_prefix=repr(self.network_prefix), + nice_name=repr(self.nice_name), + ) + + +if platform.system() == "Darwin" or "BSD" in platform.system(): + + # BSD derived systems use marginally different structures + # than either Linux or Windows. + # I still keep it in `shared` since we can use + # both structures equally. + + class sockaddr(ctypes.Structure): + _fields_ = [ + ("sa_len", ctypes.c_uint8), + ("sa_familiy", ctypes.c_uint8), + ("sa_data", ctypes.c_uint8 * 14), + ] + + class sockaddr_in(ctypes.Structure): + _fields_ = [ + ("sa_len", ctypes.c_uint8), + ("sa_familiy", ctypes.c_uint8), + ("sin_port", ctypes.c_uint16), + ("sin_addr", ctypes.c_uint8 * 4), + ("sin_zero", ctypes.c_uint8 * 8), + ] + + class sockaddr_in6(ctypes.Structure): + _fields_ = [ + ("sa_len", ctypes.c_uint8), + ("sa_familiy", ctypes.c_uint8), + ("sin6_port", ctypes.c_uint16), + ("sin6_flowinfo", ctypes.c_uint32), + ("sin6_addr", ctypes.c_uint8 * 16), + ("sin6_scope_id", ctypes.c_uint32), + ] + +else: + + class sockaddr(ctypes.Structure): # type: ignore + _fields_ = [("sa_familiy", ctypes.c_uint16), ("sa_data", ctypes.c_uint8 * 14)] + + class sockaddr_in(ctypes.Structure): # type: ignore + _fields_ = [ + ("sin_familiy", ctypes.c_uint16), + ("sin_port", ctypes.c_uint16), + ("sin_addr", ctypes.c_uint8 * 4), + ("sin_zero", ctypes.c_uint8 * 8), + ] + + class sockaddr_in6(ctypes.Structure): # type: ignore + _fields_ = [ + ("sin6_familiy", ctypes.c_uint16), + ("sin6_port", ctypes.c_uint16), + ("sin6_flowinfo", ctypes.c_uint32), + ("sin6_addr", ctypes.c_uint8 * 16), + ("sin6_scope_id", ctypes.c_uint32), + ] + + +def sockaddr_to_ip( + sockaddr_ptr: "ctypes.pointer[sockaddr]", +) -> Optional[Union[_IPv4Address, _IPv6Address]]: + if sockaddr_ptr: + if sockaddr_ptr[0].sa_familiy == socket.AF_INET: + ipv4 = ctypes.cast(sockaddr_ptr, ctypes.POINTER(sockaddr_in)) + ippacked = bytes(bytearray(ipv4[0].sin_addr)) + ip = U(ipaddress.ip_address(ippacked)) + return ip + elif sockaddr_ptr[0].sa_familiy == socket.AF_INET6: + ipv6 = ctypes.cast(sockaddr_ptr, ctypes.POINTER(sockaddr_in6)) + flowinfo = ipv6[0].sin6_flowinfo + ippacked = bytes(bytearray(ipv6[0].sin6_addr)) + ip = U(ipaddress.ip_address(ippacked)) + scope_id = ipv6[0].sin6_scope_id + return (ip, flowinfo, scope_id) + return None + + +def ipv6_prefixlength(address: ipaddress.IPv6Address) -> int: + prefix_length = 0 + for i in range(address.max_prefixlen): + if int(address) >> i & 1: + prefix_length = prefix_length + 1 + return prefix_length diff --git a/copyparty/stolen/ifaddr/_win32.py b/copyparty/stolen/ifaddr/_win32.py new file mode 100644 index 00000000..db0d3d40 --- /dev/null +++ b/copyparty/stolen/ifaddr/_win32.py @@ -0,0 +1,135 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import ctypes +from ctypes import wintypes + +if True: # pylint: disable=using-constant-test + from typing import Iterable, List + +from . import _shared as shared + +NO_ERROR = 0 +ERROR_BUFFER_OVERFLOW = 111 +MAX_ADAPTER_NAME_LENGTH = 256 +MAX_ADAPTER_DESCRIPTION_LENGTH = 128 +MAX_ADAPTER_ADDRESS_LENGTH = 8 +AF_UNSPEC = 0 + + +class SOCKET_ADDRESS(ctypes.Structure): + _fields_ = [ + ("lpSockaddr", ctypes.POINTER(shared.sockaddr)), + ("iSockaddrLength", wintypes.INT), + ] + + +class IP_ADAPTER_UNICAST_ADDRESS(ctypes.Structure): + pass + + +IP_ADAPTER_UNICAST_ADDRESS._fields_ = [ + ("Length", wintypes.ULONG), + ("Flags", wintypes.DWORD), + ("Next", ctypes.POINTER(IP_ADAPTER_UNICAST_ADDRESS)), + ("Address", SOCKET_ADDRESS), + ("PrefixOrigin", ctypes.c_uint), + ("SuffixOrigin", ctypes.c_uint), + ("DadState", ctypes.c_uint), + ("ValidLifetime", wintypes.ULONG), + ("PreferredLifetime", wintypes.ULONG), + ("LeaseLifetime", wintypes.ULONG), + ("OnLinkPrefixLength", ctypes.c_uint8), +] + + +class IP_ADAPTER_ADDRESSES(ctypes.Structure): + pass + + +IP_ADAPTER_ADDRESSES._fields_ = [ + ("Length", wintypes.ULONG), + ("IfIndex", wintypes.DWORD), + ("Next", ctypes.POINTER(IP_ADAPTER_ADDRESSES)), + ("AdapterName", ctypes.c_char_p), + ("FirstUnicastAddress", ctypes.POINTER(IP_ADAPTER_UNICAST_ADDRESS)), + ("FirstAnycastAddress", ctypes.c_void_p), + ("FirstMulticastAddress", ctypes.c_void_p), + ("FirstDnsServerAddress", ctypes.c_void_p), + ("DnsSuffix", ctypes.c_wchar_p), + ("Description", ctypes.c_wchar_p), + ("FriendlyName", ctypes.c_wchar_p), +] + + +iphlpapi = ctypes.windll.LoadLibrary("Iphlpapi") # type: ignore + + +def enumerate_interfaces_of_adapter( + nice_name: str, address: IP_ADAPTER_UNICAST_ADDRESS +) -> Iterable[shared.IP]: + + # Iterate through linked list and fill list + addresses = [] # type: List[IP_ADAPTER_UNICAST_ADDRESS] + while True: + addresses.append(address) + if not address.Next: + break + address = address.Next[0] + + for address in addresses: + ip = shared.sockaddr_to_ip(address.Address.lpSockaddr) + if ip is None: + t = "sockaddr_to_ip({}) returned None" + raise Exception(t.format(address.Address.lpSockaddr)) + + network_prefix = address.OnLinkPrefixLength + yield shared.IP(ip, network_prefix, nice_name) + + +def get_adapters(include_unconfigured: bool = False) -> Iterable[shared.Adapter]: + + # Call GetAdaptersAddresses() with error and buffer size handling + + addressbuffersize = wintypes.ULONG(15 * 1024) + retval = ERROR_BUFFER_OVERFLOW + while retval == ERROR_BUFFER_OVERFLOW: + addressbuffer = ctypes.create_string_buffer(addressbuffersize.value) + retval = iphlpapi.GetAdaptersAddresses( + wintypes.ULONG(AF_UNSPEC), + wintypes.ULONG(0), + None, + ctypes.byref(addressbuffer), + ctypes.byref(addressbuffersize), + ) + if retval != NO_ERROR: + raise ctypes.WinError() # type: ignore + + # Iterate through adapters fill array + address_infos = [] # type: List[IP_ADAPTER_ADDRESSES] + address_info = IP_ADAPTER_ADDRESSES.from_buffer(addressbuffer) + while True: + address_infos.append(address_info) + if not address_info.Next: + break + address_info = address_info.Next[0] + + # Iterate through unicast addresses + result = [] # type: List[shared.Adapter] + for adapter_info in address_infos: + + # We don't expect non-ascii characters here, so encoding shouldn't matter + name = adapter_info.AdapterName.decode() + nice_name = adapter_info.Description + index = adapter_info.IfIndex + + if adapter_info.FirstUnicastAddress: + ips = enumerate_interfaces_of_adapter( + adapter_info.FriendlyName, adapter_info.FirstUnicastAddress[0] + ) + ips = list(ips) + result.append(shared.Adapter(name, nice_name, ips, index=index)) + elif include_unconfigured: + result.append(shared.Adapter(name, nice_name, [], index=index)) + + return result diff --git a/copyparty/svchub.py b/copyparty/svchub.py index c68c5b1d..e864d898 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -195,10 +195,17 @@ class SvcHub(object): args.th_poke = min(args.th_poke, args.th_maxage, args.ac_maxage) + zms = "" + if not args.https_only: + zms += "d" + if not args.http_only: + zms += "D" + if args.ftp or args.ftps: from .ftpd import Ftpd self.ftpd = Ftpd(self) + zms += "f" if args.ftp else "F" if args.smb: # impacket.dcerpc is noisy about listen timeouts @@ -210,6 +217,12 @@ class SvcHub(object): self.smbd = SMB(self) socket.setdefaulttimeout(sto) self.smbd.start() + zms += "s" + + if not args.zms: + args.zms = zms + + self.mdns: Any = None # decide which worker impl to use if self.check_mp_enable(): @@ -359,6 +372,15 @@ class SvcHub(object): def run(self) -> None: self.tcpsrv.run() + if getattr(self.args, "zm", False): + try: + from .mdns import MDNS + + self.mdns = MDNS(self) + Daemon(self.mdns.run, "mdns") + except: + self.log("root", "mdns startup failed;\n" + min_ex(), 3) + Daemon(self.thr_httpsrv_up, "sig-hsrv-up2") sigs = [signal.SIGINT, signal.SIGTERM] @@ -464,6 +486,11 @@ class SvcHub(object): ret = 1 try: self.pr("OPYTHAT") + slp = 0.0 + if self.mdns: + Daemon(self.mdns.stop) + slp = time.time() + 1 + self.tcpsrv.shutdown() self.broker.shutdown() self.up2k.shutdown() @@ -482,6 +509,9 @@ class SvcHub(object): Daemon(self.kill9, a=(1,)) self.smbd.stop() + while time.time() < slp: + time.sleep(0.1) + self.pr("nailed it", end="") ret = self.retcode except: diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py index 7fcc421a..a2fb1b62 100644 --- a/copyparty/tcpsrv.py +++ b/copyparty/tcpsrv.py @@ -25,6 +25,9 @@ if True: if TYPE_CHECKING: from .svchub import SvcHub +if not hasattr(socket, "IPPROTO_IPV6"): + setattr(socket, "IPPROTO_IPV6", 41) + class TcpSrv(object): """ @@ -42,6 +45,7 @@ class TcpSrv(object): self.stopping = False self.srv: list[socket.socket] = [] + self.bound: list[tuple[str, int]] = [] self.nsrv = 0 self.qr = "" pad = False @@ -97,14 +101,22 @@ class TcpSrv(object): if pad: self.log("tcpsrv", "") - ip = "127.0.0.1" - eps = {ip: "local only"} - nonlocals = [x for x in self.args.i if x != ip] + eps = {"127.0.0.1": "local only", "::1": "local only"} + nonlocals = [x for x in self.args.i if x not in [k.split("/")[0] for k in eps]] if nonlocals: - eps = self.detect_interfaces(self.args.i) + try: + self.netdevs = self.detect_interfaces(self.args.i) + except: + t = "failed to discover server IP addresses\n" + self.log("tcpsrv", t + min_ex(), 3) + self.netdevs = {} + + eps.update({k.split("/")[0]: v for k, v in self.netdevs.items()}) if not eps: for x in nonlocals: eps[x] = "external" + else: + self.netdevs = {} qr1: dict[str, list[int]] = {} qr2: dict[str, list[int]] = {} @@ -180,6 +192,12 @@ class TcpSrv(object): srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) srv.settimeout(None) # < does not inherit, ^ does + + try: + srv.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + except: + pass # will create another ipv4 socket instead + try: srv.bind((ip, port)) self.srv.append(srv) @@ -194,8 +212,8 @@ class TcpSrv(object): def run(self) -> None: all_eps = [x.getsockname()[:2] for x in self.srv] - bound = [] - srvs = [] + bound: list[tuple[str, int]] = [] + srvs: list[socket.socket] = [] for srv in self.srv: ip, port = srv.getsockname()[:2] try: @@ -225,6 +243,7 @@ class TcpSrv(object): self.hub.broker.say("listen", srv) self.srv = srvs + self.bound = bound self.nsrv = len(srvs) def shutdown(self) -> None: @@ -370,19 +389,22 @@ class TcpSrv(object): return eps def detect_interfaces(self, listen_ips: list[str]) -> dict[str, str]: - if MACOS: - eps = self.ips_macos() - elif ANYWIN: - eps, off = self.ips_windows_ipconfig() # sees more interfaces + link state - eps.update(self.ips_windows_netsh()) # has better names - for k, v in eps.items(): - if v in off: - eps[k] += ", \033[31mLINK-DOWN" - else: - eps = self.ips_linux() + from .stolen.ifaddr import get_adapters + + nics = get_adapters(True) + eps = {} + for nic in nics: + for nip in nic.ips: + ipa = nip.ip[0] if ":" in str(nip.ip) else nip.ip + sip = "{}/{}".format(ipa, nip.network_prefix) + if sip.startswith("fe80") or sip.startswith("169.254"): + # browsers dont impl linklocal + continue + + eps[sip] = nic.nice_name if "0.0.0.0" not in listen_ips and "::" not in listen_ips: - eps = {k: v for k, v in eps.items() if k in listen_ips} + eps = {k: v for k, v in eps.items() if k.split("/")[0] in listen_ips} try: ext_devs = list(self._extdevs_nix()) @@ -478,7 +500,13 @@ class TcpSrv(object): def _qr(self, t1: dict[str, list[int]], t2: dict[str, list[int]]) -> str: ip = None - for ip in list(t1) + list(t2): + ips = list(t1) + list(t2) + if self.args.zm: + name = self.args.name + ".local" + t1[name] = next(v for v in (t1 or t2).values()) + ips = [name] + ips + + for ip in ips: if ip.startswith(self.args.qri): break ip = "" diff --git a/copyparty/util.py b/copyparty/util.py index 64309165..c889f062 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -24,6 +24,7 @@ import time import traceback from collections import Counter from datetime import datetime +from ipaddress import IPv6Address from queue import Queue @@ -60,11 +61,6 @@ try: except: pass -try: - from ipaddress import IPv6Address -except: - pass - try: HAVE_SQLITE3 = True import sqlite3 # pylint: disable=unused-import # typechk @@ -184,6 +180,9 @@ IMPLICATIONS = [ ["smbw", "smb"], ["smb1", "smb"], ["smb_dbg", "smb"], + ["zmvv", "zmv"], + ["zmv", "zm"], + ["zms", "zm"], ] @@ -536,6 +535,27 @@ class _LUnrecv(object): Unrecv = _Unrecv +class CachedSet(object): + def __init__(self, maxage: float) -> None: + self.c: dict[Any, float] = {} + self.maxage = maxage + self.oldest = 0.0 + + def add(self, v: Any) -> None: + self.c[v] = time.time() + + def cln(self) -> None: + now = time.time() + if now - self.oldest < self.maxage: + return + + c = self.c = {k: v for k, v in self.c.items() if now - v < self.maxage} + try: + self.oldest = c[min(c, key=c.get)] + except: + self.oldest = now + + class FHC(object): class CE(object): def __init__(self, fh: typing.BinaryIO) -> None: @@ -836,7 +856,7 @@ class Garda(object): if not self.lim: return 0, ip - if ":" in ip and not PY2: + if ":" in ip: # assume /64 clients; drop 4 groups ip = IPv6Address(ip).exploded[:-20] @@ -1603,7 +1623,7 @@ def exclude_dotfiles(filepaths: list[str]) -> list[str]: return [x for x in filepaths if not x.split("/")[-1].startswith(".")] -def _ipnorm3(ip: str) -> str: +def ipnorm(ip: str) -> str: if ":" in ip: # assume /64 clients; drop 4 groups return IPv6Address(ip).exploded[:-20] @@ -1611,9 +1631,6 @@ def _ipnorm3(ip: str) -> str: return ip -ipnorm = _ipnorm3 if not PY2 else unicode - - def http_ts(ts: int) -> str: file_dt = datetime.utcfromtimestamp(ts) return file_dt.strftime(HTTP_TS_FMT) diff --git a/docs/devnotes.md b/docs/devnotes.md index 268d1219..8f19e9ff 100644 --- a/docs/devnotes.md +++ b/docs/devnotes.md @@ -48,6 +48,17 @@ hashwasm would solve the streaming issue but reduces hashing speed for sha512 (x * blake2 might be a better choice since xxh is non-cryptographic, but that gets ~15 MiB/s on slower androids +## assumptions + +### mdns + +* outgoing replies will always fit in one packet +* if a client mentions any of our services, assume it's not missing any +* always answer with all services, even if the client only asked for a few +* not-impl: probe tiebreaking (too complicated) +* not-impl: unicast listen (assume avahi took it) + + # sfx repack reduce the size of an sfx by removing features diff --git a/docs/lics.txt b/docs/lics.txt index 24342120..434bd361 100644 --- a/docs/lics.txt +++ b/docs/lics.txt @@ -6,17 +6,21 @@ L: MIT https://github.com/pallets/jinja/ C: 2007 Pallets -L: BSD 3-Clause +L: BSD 3-Clause https://github.com/pallets/markupsafe/ C: 2010 Pallets -L: BSD 3-Clause +L: BSD 3-Clause + +https://github.com/paulc/dnslib/ +C: 2010-2017 Paul Chakravarti +L: BSD 2-Clause https://github.com/giampaolo/pyftpdlib/ -C: 2007 Giampaolo Rodola' +C: 2007 Giampaolo Rodola L: MIT -https://github.com/nayuki/QR-Code-generator +https://github.com/nayuki/QR-Code-generator/ C: Project Nayuki L: MIT diff --git a/scripts/genlic.sh b/scripts/genlic.sh index 6c549371..173506db 100755 --- a/scripts/genlic.sh +++ b/scripts/genlic.sh @@ -18,6 +18,12 @@ f=../build/isc.txt awk '/div>/{o=0}o>2;o{o++}/;OWNER/{o=1}' | awk '{gsub(/<[^>]+>/,"")};/./{b=0}!/./{b++}b>1{next}1' >$f +f=../build/2bsd.txt +[ -e $f ] || + curl https://opensource.org/licenses/BSD-2-Clause | + awk '/div>/{o=0}o>1;o{o++}/HOLDER/{o=1}' | + awk '{gsub(/<[^>]+>/,"")};1' >$f + f=../build/3bsd.txt [ -e $f ] || curl https://opensource.org/licenses/BSD-3-Clause | @@ -33,6 +39,7 @@ f=../build/ofl.txt (sed -r 's/^L: /License: /;s/^C: /Copyright (c) /' <../docs/lics.txt printf '\n\n--- MIT License ---\n\n'; cat ../build/mit.txt printf '\n\n--- ISC License ---\n\n'; cat ../build/isc.txt +printf '\n\n--- BSD 2-Clause License ---\n\n'; cat ../build/2bsd.txt printf '\n\n--- BSD 3-Clause License ---\n\n'; cat ../build/3bsd.txt printf '\n\n--- SIL Open Font License v1.1 ---\n\n'; cat ../build/ofl.txt ) | diff --git a/scripts/make-sfx.sh b/scripts/make-sfx.sh index d1fe5ff7..f771f9b6 100755 --- a/scripts/make-sfx.sh +++ b/scripts/make-sfx.sh @@ -27,6 +27,8 @@ help() { exec cat <<'EOF' # # `no-smb` saves ~3.5k by removing the smb / cifs server # +# `no-zm` saves ~k by removing the zeroconf mDNS server +# # _____________________________________________________________________ # web features: # @@ -101,6 +103,7 @@ while [ ! -z "$1" ]; do gzz) shift;use_gzz=$1;use_gz=1; ;; no-ftp) no_ftp=1 ; ;; no-smb) no_smb=1 ; ;; + no-zm) no_zm=1 ; ;; no-fnt) no_fnt=1 ; ;; no-hl) no_hl=1 ; ;; no-dd) no_dd=1 ; ;; @@ -136,11 +139,22 @@ tmpdir="$( [ $repack ] && { old="$tmpdir/pe-copyparty.$(id -u)" echo "repack of files in $old" - cp -pR "$old/"*{py2,j2,copyparty} . + cp -pR "$old/"*{py2,py37,j2,copyparty} . cp -pR "$old/"*ftp . || true } [ $repack ] || { + echo collecting ipaddress + f="../build/ipaddress-1.0.23.tar.gz" + [ -e "$f" ] || + (url=https://files.pythonhosted.org/packages/b9/9a/3e9da40ea28b8210dd6504d3fe9fe7e013b62bf45902b458d1cdc3c34ed9/ipaddress-1.0.23.tar.gz; + wget -O$f "$url" || curl -L "$url" >$f) + + tar -zxf $f + mkdir py37 + mv ipaddress-*/ipaddress.py py37/ + rm -rf ipaddress-* + echo collecting jinja2 f="../build/Jinja2-2.11.3.tar.gz" [ -e "$f" ] || @@ -237,6 +251,8 @@ tmpdir="$( awk 'NR<4||NR>27;NR==4{print"# license: https://opensource.org/licenses/ISC\n"}' ../build/$n >copyparty/vend/$n done + rm -f copyparty/stolen/*/README.md + # remove type hints before build instead (cd copyparty; "$pybin" ../../scripts/strip_hints/a.py; rm uh) @@ -322,6 +338,10 @@ rm have rm -f copyparty/smbd.py && sed -ri '/add_argument\("--smb/d' copyparty/__main__.py +[ $no_zm ] && + rm -rf copyparty/mdns.py copyparty/stolen/dnslib && + sed -ri '/add_argument\("--zm/d' copyparty/__main__.py + [ $no_cm ] && { rm -rf copyparty/web/mde.* copyparty/web/deps/easymde* echo h > copyparty/web/mde.html @@ -464,7 +484,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l) echo "copying.txt 404 pls rebuild" mv ftp/* j2/* copyparty/vend/* . - rm -rf ftp j2 py2 copyparty/vend + rm -rf ftp j2 py2 py37 copyparty/vend (cd copyparty; tar -cvf z.tar $t; rm -rf $t) cd .. pyoxidizer build --release --target-triple $tgt @@ -481,7 +501,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l) echo gen tarlist -for d in copyparty j2 py2 ftp; do find $d -type f; done | # strip_hints +for d in copyparty j2 py2 py37 ftp; do find $d -type f; done | # strip_hints sed -r 's/(.*)\.(.*)/\2 \1/' | LC_ALL=C sort | sed -r 's/([^ ]*) (.*)/\2.\1/' | grep -vE '/list1?$' > list1 diff --git a/scripts/sfx.ls b/scripts/sfx.ls index aacaa4d3..832ab6a0 100644 --- a/scripts/sfx.ls +++ b/scripts/sfx.ls @@ -18,6 +18,7 @@ copyparty/httpcli.py, copyparty/httpconn.py, copyparty/httpsrv.py, copyparty/ico.py, +copyparty/mdns.py, copyparty/mtag.py, copyparty/res, copyparty/res/COPYING.txt, @@ -26,6 +27,20 @@ copyparty/smbd.py, copyparty/star.py, copyparty/stolen, copyparty/stolen/__init__.py, +copyparty/stolen/dnslib, +copyparty/stolen/dnslib/__init__.py, +copyparty/stolen/dnslib/bimap.py, +copyparty/stolen/dnslib/bit.py, +copyparty/stolen/dnslib/buffer.py, +copyparty/stolen/dnslib/dns.py, +copyparty/stolen/dnslib/label.py, +copyparty/stolen/dnslib/lex.py, +copyparty/stolen/dnslib/ranges.py, +copyparty/stolen/ifaddr, +copyparty/stolen/ifaddr/__init__.py, +copyparty/stolen/ifaddr/_posix.py, +copyparty/stolen/ifaddr/_shared.py, +copyparty/stolen/ifaddr/_win32.py, copyparty/stolen/qrcodegen.py, copyparty/stolen/surrogateescape.py, copyparty/sutil.py, diff --git a/scripts/sfx.py b/scripts/sfx.py index 75f11b6b..4ff57b0c 100644 --- a/scripts/sfx.py +++ b/scripts/sfx.py @@ -28,6 +28,7 @@ CKSUM = None STAMP = None PY2 = sys.version_info < (3,) +PY37 = sys.version_info > (3, 7) WINDOWS = sys.platform in ["win32", "msys"] sys.dont_write_bytecode = True me = os.path.abspath(os.path.realpath(__file__)) @@ -401,7 +402,7 @@ def run(tmp, j2, ftp): t.daemon = True t.start() - ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2")) + ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2"), (PY37, "py37")) ld = [os.path.join(tmp, b) for a, b in ld if not a] # skip 1