add mdns zeroconf announcer

This commit is contained in:
ed 2022-11-13 20:05:16 +00:00
parent fc0a941508
commit b3eb117e87
29 changed files with 2581 additions and 52 deletions

View file

@ -12,14 +12,14 @@ except:
TYPE_CHECKING = False TYPE_CHECKING = False
if True: if True:
from typing import Any from typing import Any, Callable
PY2 = sys.version_info < (3,) PY2 = sys.version_info < (3,)
if PY2: if not PY2:
unicode: Callable[[str], str] = str
else:
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
unicode = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable unicode = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable
else:
unicode = str
WINDOWS: Any = ( WINDOWS: Any = (
[int(x) for x in platform.version().split(".")] [int(x) for x in platform.version().split(".")]

View file

@ -9,11 +9,13 @@ __license__ = "MIT"
__url__ = "https://github.com/9001/copyparty/" __url__ = "https://github.com/9001/copyparty/"
import argparse import argparse
import base64
import filecmp import filecmp
import locale import locale
import os import os
import re import re
import shutil import shutil
import socket
import sys import sys
import threading import threading
import time import time
@ -209,6 +211,31 @@ def init_E(E: EnvParams) -> None:
raise raise
def get_srvname() -> str:
try:
ret: str = unicode(socket.gethostname()).split(".")[0].lower()
except:
ret = ""
if ret not in ["", "localhost"]:
return ret
fp = os.path.join(E.cfg, "name.txt")
lprint("using hostname from {}\n".format(fp))
try:
with open(fp, "rb") as f:
ret = f.read().decode("utf-8", "replace").strip()
except:
ret = ""
while len(ret) < 7:
ret += base64.b32encode(os.urandom(4))[:7].decode("utf-8").lower()
ret = re.sub("[234567=]", "", ret)[:7]
with open(fp, "wb") as f:
f.write(ret.encode("utf-8") + b"\n")
return ret
def ensure_locale() -> None: def ensure_locale() -> None:
for x in [ for x in [
"en_US.UTF-8", "en_US.UTF-8",
@ -431,6 +458,8 @@ def run_argparse(
tty = os.environ.get("TERM", "").lower() == "linux" tty = os.environ.get("TERM", "").lower() == "linux"
srvname = get_srvname()
sects = [ sects = [
[ [
"accounts", "accounts",
@ -584,6 +613,7 @@ def run_argparse(
ap2.add_argument("-mcr", metavar="SEC", type=int, default=60, help="md-editor mod-chk rate") ap2.add_argument("-mcr", metavar="SEC", type=int, default=60, help="md-editor mod-chk rate")
ap2.add_argument("--urlform", metavar="MODE", type=u, default="print,get", help="how to handle url-form POSTs; see --help-urlform") ap2.add_argument("--urlform", metavar="MODE", type=u, default="print,get", help="how to handle url-form POSTs; see --help-urlform")
ap2.add_argument("--wintitle", metavar="TXT", type=u, default="cpp @ $pub", help="window title, for example [\033[32m$ip-10.1.2.\033[0m] or [\033[32m$ip-]") ap2.add_argument("--wintitle", metavar="TXT", type=u, default="cpp @ $pub", help="window title, for example [\033[32m$ip-10.1.2.\033[0m] or [\033[32m$ip-]")
ap2.add_argument("--name", metavar="TXT", type=str, default=srvname, help="server name (displayed topleft in browser and in mDNS)")
ap2.add_argument("--license", action="store_true", help="show licenses and exit") ap2.add_argument("--license", action="store_true", help="show licenses and exit")
ap2.add_argument("--version", action="store_true", help="show versions and exit") ap2.add_argument("--version", action="store_true", help="show versions and exit")
@ -630,6 +660,21 @@ def run_argparse(
ap2.add_argument("--ssl-dbg", action="store_true", help="dump some tls info") ap2.add_argument("--ssl-dbg", action="store_true", help="dump some tls info")
ap2.add_argument("--ssl-log", metavar="PATH", type=u, help="log master secrets for later decryption in wireshark") ap2.add_argument("--ssl-log", metavar="PATH", type=u, help="log master secrets for later decryption in wireshark")
ap2 = ap.add_argument_group("Zeroconf options")
ap2.add_argument("--zm", action="store_true", help="announce the enabled protocols over mDNS (multicast DNS-SD) -- compatible with KDE, gnome, macOS, ...")
ap2.add_argument("--zm4", action="store_true", help="IPv4 only -- try this if some clients don't work")
ap2.add_argument("--zm6", action="store_true", help="IPv6 only")
ap2.add_argument("--zmv", action="store_true", help="verbose mdns")
ap2.add_argument("--zmvv", action="store_true", help="verboser mdns")
ap2.add_argument("--zms", metavar="dhf", type=str, default="", help="list of services to announce -- d=webdav h=http f=ftp s=smb -- lowercase=plaintext uppercase=TLS -- default: all enabled services except http/https (\033[32mDdfs\033[0m if \033[33m--ftp\033[0m and \033[33m--smb\033[0m is set)")
ap2.add_argument("--zm-ld", metavar="PATH", type=str, default="", help="link a specific folder for webdav shares")
ap2.add_argument("--zm-lh", metavar="PATH", type=str, default="", help="link a specific folder for http shares")
ap2.add_argument("--zm-lf", metavar="PATH", type=str, default="", help="link a specific folder for ftp shares")
ap2.add_argument("--zm-ls", metavar="PATH", type=str, default="", help="link a specific folder for smb shares")
ap2.add_argument("--zm-mnic", action="store_true", help="merge NICs which share subnets; assume that same subnet means same network")
ap2.add_argument("--zm-msub", action="store_true", help="merge subnets on each NIC -- always enabled for ipv6 -- reduces network load, but gnome-gvfs clients may stop working")
ap2.add_argument("--mc-hop", metavar="SEC", type=int, default=0, help="rejoin multicast groups every SEC seconds (workaround for some switches/routers which cause mDNS to suddenly stop working after some time); try [\033[32m300\033[0m] or [\033[32m180\033[0m]")
ap2 = ap.add_argument_group('FTP options') ap2 = ap.add_argument_group('FTP options')
ap2.add_argument("--ftp", metavar="PORT", type=int, help="enable FTP server on PORT, for example \033[32m3921") ap2.add_argument("--ftp", metavar="PORT", type=int, help="enable FTP server on PORT, for example \033[32m3921")
ap2.add_argument("--ftps", metavar="PORT", type=int, help="enable FTPS server on PORT, for example \033[32m3990") ap2.add_argument("--ftps", metavar="PORT", type=int, help="enable FTPS server on PORT, for example \033[32m3990")
@ -898,6 +943,7 @@ def main(argv: Optional[list[str]] = None) -> None:
for fmtr in [RiceFormatter, RiceFormatter, Dodge11874, BasicDodge11874]: for fmtr in [RiceFormatter, RiceFormatter, Dodge11874, BasicDodge11874]:
try: try:
al = run_argparse(argv, fmtr, retry, nc) al = run_argparse(argv, fmtr, retry, nc)
break
except SystemExit: except SystemExit:
raise raise
except: except:

View file

@ -11,7 +11,6 @@ import itertools
import json import json
import os import os
import re import re
import socket
import stat import stat
import string import string
import threading # typechk import threading # typechk
@ -27,11 +26,6 @@ try:
except: except:
pass pass
try:
from ipaddress import IPv6Address
except:
pass
from .__init__ import ANYWIN, PY2, TYPE_CHECKING, EnvParams, unicode from .__init__ import ANYWIN, PY2, TYPE_CHECKING, EnvParams, unicode
from .authsrv import VFS # typechk from .authsrv import VFS # typechk
from .bos import bos from .bos import bos
@ -3030,7 +3024,7 @@ class HttpCli(object):
try: try:
if not self.args.nih: if not self.args.nih:
srv_info.append(unicode(socket.gethostname()).split(".")[0]) srv_info.append(self.args.name)
except: except:
self.log("#wow #whoa") self.log("#wow #whoa")

View file

@ -11,11 +11,6 @@ import time
import queue import queue
try:
from ipaddress import IPv6Address
except:
pass
try: try:
import jinja2 import jinja2
except ImportError: except ImportError:

414
copyparty/mdns.py Normal file
View file

@ -0,0 +1,414 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
import random
import select
import socket
import time
from ipaddress import IPv4Network, IPv6Network
from .__init__ import TYPE_CHECKING
from .__init__ import unicode as U
from .util import CachedSet, Daemon, min_ex
from .multicast import MC_Sck, MCast
from .stolen.dnslib import (
RR,
DNSHeader,
DNSRecord,
DNSQuestion,
QTYPE,
A,
AAAA,
SRV,
PTR,
TXT,
)
from .stolen.dnslib import CLASS as DC
if TYPE_CHECKING:
from .svchub import SvcHub
if True: # pylint: disable=using-constant-test
from typing import Any, Optional, Union
MDNS4 = "224.0.0.251"
MDNS6 = "ff02::fb"
class MDNS_Sck(MC_Sck):
def __init__(
self,
sck: socket.socket,
idx: int,
grp: str,
ip: str,
net: Union[IPv4Network, IPv6Network],
):
super(MDNS_Sck, self).__init__(sck, idx, grp, ip, net)
self.bp_probe = b""
self.bp_ip = b""
self.bp_svc = b""
self.bp_bye = b""
self.last_tx = 0.0
class MDNS(MCast):
def __init__(self, hub: "SvcHub") -> None:
grp4 = "" if hub.args.zm6 else MDNS4
grp6 = "" if hub.args.zm4 else MDNS6
super(MDNS, self).__init__(hub, MDNS_Sck, grp4, grp6, 5353)
self.srv: dict[socket.socket, MDNS_Sck] = {}
self.ttl = 300
self.running = True
zs = self.args.name.lower() + ".local."
zs = zs.encode("ascii", "replace").decode("ascii", "replace")
self.hn = zs.replace("?", "_")
# requester ip -> (response deadline, srv, body):
self.q: dict[str, tuple[float, MDNS_Sck, bytes]] = {}
self.rx4 = CachedSet(0.42) # 3 probes @ 250..500..750 => 500ms span
self.rx6 = CachedSet(0.42)
self.svcs, self.sfqdns = self.build_svcs()
self.probing = 0.0
self.unsolicited: list[float] = [] # scheduled announces on all nics
self.defend: dict[MDNS_Sck, float] = {} # server -> deadline
def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("mDNS", msg, c)
def build_svcs(self) -> tuple[dict[str, dict[str, Any]], set[str]]:
zms = self.args.zms
http = {"port": 80 if 80 in self.args.p else self.args.p[0]}
https = {"port": 443 if 443 in self.args.p else self.args.p[0]}
webdav = http.copy()
webdavs = https.copy()
webdav["u"] = webdavs["u"] = "u" # KDE requires username
ftp = {"port": (self.args.ftp if "f" in zms else self.args.ftps)}
smb = {"port": self.args.smb_port}
# some gvfs require path
zs = self.args.zm_ld or "/"
if zs:
webdav["path"] = zs
webdavs["path"] = zs
if self.args.zm_lh:
http["path"] = self.args.zm_lh
https["path"] = self.args.zm_lh
if self.args.zm_lf:
ftp["path"] = self.args.zm_lf
if self.args.zm_ls:
smb["path"] = self.args.zm_ls
svcs: dict[str, dict[str, Any]] = {}
if "d" in zms:
svcs["_webdav._tcp.local."] = webdav
if "D" in zms:
svcs["_webdavs._tcp.local."] = webdavs
if "h" in zms:
svcs["_http._tcp.local."] = http
if "H" in zms:
svcs["_https._tcp.local."] = https
if "f" in zms.lower():
svcs["_ftp._tcp.local."] = ftp
if "s" in zms.lower():
svcs["_smb._tcp.local."] = smb
sfqdns: set[str] = set()
for k, v in svcs.items():
name = "{}-c-{}".format(self.args.name, k.split(".")[0][1:])
v["name"] = name
sfqdns.add("{}.{}".format(name, k))
return svcs, sfqdns
def build_replies(self) -> None:
for srv in self.srv.values():
probe = DNSRecord(DNSHeader(0, 0), q=DNSQuestion(self.hn, QTYPE.ANY))
areply = DNSRecord(DNSHeader(0, 0x8400))
sreply = DNSRecord(DNSHeader(0, 0x8400))
bye = DNSRecord(DNSHeader(0, 0x8400))
for ip in srv.ips:
if ":" in ip:
qt = QTYPE.AAAA
ar = {"rclass": DC.F_IN, "rdata": AAAA(ip)}
else:
qt = QTYPE.A
ar = {"rclass": DC.F_IN, "rdata": A(ip)}
r0 = RR(self.hn, qt, ttl=0, **ar)
r120 = RR(self.hn, qt, ttl=120, **ar)
# rfc-10:
# SHOULD rr ttl 120sec for A/AAAA/SRV
# (and recommend 75min for all others)
probe.add_auth(r120)
areply.add_answer(r120)
sreply.add_answer(r120)
bye.add_answer(r0)
for sclass, props in self.svcs.items():
sname = props["name"]
sport = props["port"]
sfqdn = sname + "." + sclass
k = "_services._dns-sd._udp.local."
r = RR(k, QTYPE.PTR, DC.IN, 4500, PTR(sclass))
sreply.add_answer(r)
r = RR(sclass, QTYPE.PTR, DC.IN, 4500, PTR(sfqdn))
sreply.add_answer(r)
r = RR(sfqdn, QTYPE.SRV, DC.F_IN, 120, SRV(0, 0, sport, self.hn))
sreply.add_answer(r)
areply.add_answer(r)
r = RR(sfqdn, QTYPE.SRV, DC.F_IN, 0, SRV(0, 0, sport, self.hn))
bye.add_answer(r)
txts = []
for k in ("u", "path"):
if k not in props:
continue
zb = "{}={}".format(k, props[k]).encode("utf-8")
if len(zb) > 255:
t = "value too long for mdns: [{}]"
raise Exception(t.format(props[k]))
txts.append(zb)
# gvfs really wants txt even if they're empty
r = RR(sfqdn, QTYPE.TXT, DC.F_IN, 4500, TXT(txts))
sreply.add_answer(r)
srv.bp_probe = probe.pack()
srv.bp_ip = areply.pack()
srv.bp_svc = sreply.pack()
srv.bp_bye = bye.pack()
# since all replies are small enough to fit in one packet,
# always send full replies rather than just a/aaaa records
srv.bp_ip = srv.bp_svc
def send_probes(self) -> None:
slp = random.random() * 0.25
for _ in range(3):
time.sleep(slp)
slp = 0.25
if not self.running:
break
if self.args.zmv:
self.log("sending hostname probe...")
# ipv4: need to probe each ip (each server)
# ipv6: only need to probe each set of looped nics
probed6: set[str] = set()
for srv in self.srv.values():
if srv.ip in probed6:
continue
try:
srv.sck.sendto(srv.bp_probe, (srv.grp, 5353))
if srv.v6:
for ip in srv.ips:
probed6.add(ip)
except Exception as ex:
self.log("sendto failed: {} ({})".format(srv.ip, ex), "90")
def run(self) -> None:
bound = self.create_servers()
if not bound:
self.log("failed to announce copyparty services on the network", 3)
return
self.build_replies()
Daemon(self.send_probes)
zf = time.time() + 2
self.probing = zf # cant unicast so give everyone an extra sec
self.unsolicited = [zf, zf + 1, zf + 3, zf + 7] # rfc-8.3
last_hop = time.time()
ihop = self.args.mc_hop
while self.running:
timeout = (
0.02 + random.random() * 0.07
if self.probing or self.q or self.defend or self.unsolicited
else (last_hop + ihop if ihop else 180)
)
rdy = select.select(self.srv, [], [], timeout)
rx: list[socket.socket] = rdy[0] # type: ignore
self.rx4.cln()
self.rx6.cln()
for srv in rx:
buf, addr = srv.recvfrom(4096)
try:
self.eat(buf, addr)
except:
t = "{} \033[33m|{}| {}\n{}".format(
addr, len(buf), repr(buf)[2:-1], min_ex()
)
self.log(t, 6)
if not self.probing:
self.process()
continue
if self.probing < time.time():
self.log("probe ok; starting announcements", 2)
self.probing = 0
def stop(self, panic=False) -> None:
self.running = False
if not panic:
for srv in self.srv.values():
srv.sck.sendto(srv.bp_bye, (srv.grp, 5353))
def eat(self, buf: bytes, addr: tuple[str, int]):
cip = addr[0]
if cip.startswith("fe80") or cip.startswith("169.254"):
return
v6 = ":" in cip
cache = self.rx6 if v6 else self.rx4
if buf in cache.c:
return
cache.add(buf)
srv: Optional[MDNS_Sck] = self.map_client(cip) # type: ignore
if not srv:
return
now = time.time()
if self.args.zmv:
self.log("[{}] \033[36m{} \033[0m|{}|".format(srv.ip, cip, len(buf)), "90")
p = DNSRecord.parse(buf)
if self.args.zmvv:
self.log(str(p))
# check for incoming probes for our hostname
cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.hn]
if cips and self.sips.isdisjoint(cips):
if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
# avahi broadcasting 127.0.0.1-only packets
return
self.log("someone trying to steal our hostname: {}".format(cips), 3)
# immediately unicast
if not self.probing:
srv.sck.sendto(srv.bp_ip, (cip, 5353))
# and schedule multicast
self.defend[srv] = self.defend.get(srv, now + 0.1)
return
# check for someone rejecting our probe / hijacking our hostname
cips = [
U(x.rdata)
for x in p.rr
if U(x.rname).lower() == self.hn and x.rclass == DC.F_IN
]
if cips and self.sips.isdisjoint(cips):
if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
# avahi broadcasting 127.0.0.1-only packets
return
t = "mdns zeroconf: "
if self.probing:
t += "Cannot start; hostname '{}' is occupied"
else:
t += "Emergency stop; hostname '{}' got stolen"
t += "! Use --name to set another hostname.\n\nName taken by {}\n\nYour IPs: {}\n"
self.log(t.format(self.args.name, cips, list(self.sips)), 1)
self.stop(True)
return
# then a/aaaa records
for r in p.questions:
if U(r.qname).lower() != self.hn:
continue
# gvfs keeps repeating itself
found = False
for r in p.rr:
rname = U(r.rname).lower()
if rname == self.hn and r.ttl > 60:
found = True
break
if not found:
self.q[cip] = (0, srv, srv.bp_ip)
return
deadline = now + (0.5 if p.header.tc else 0.02) # rfc-7.2
# and service queries
for r in p.questions:
qname = U(r.qname).lower()
if qname in self.svcs or qname == "_services._dns-sd._udp.local.":
self.q[cip] = (deadline, srv, srv.bp_svc)
break
# heed rfc-7.1 if there was an announce in the past 12sec
# (workaround gvfs race-condition where it occasionally
# doesn't read/decode the full response...)
if now < srv.last_tx + 12:
for r in p.rr:
rdata = U(r.rdata).lower()
if rdata in self.sfqdns:
if r.ttl > 2250:
self.q.pop(cip, None)
break
def process(self) -> None:
tx = set()
now = time.time()
cooldown = 0.9 # rfc-6: 1
if self.unsolicited and self.unsolicited[0] < now:
self.unsolicited.pop(0)
cooldown = 0.1
for srv in self.srv.values():
tx.add(srv)
for srv, deadline in list(self.defend.items()):
if now < deadline:
continue
if self._tx(srv, srv.bp_ip, 0.02): # rfc-6: 0.25
self.defend.pop(srv)
for cip, (deadline, srv, msg) in list(self.q.items()):
if now < deadline:
continue
self.q.pop(cip)
self._tx(srv, msg, cooldown)
for srv in tx:
self._tx(srv, srv.bp_svc, cooldown)
def _tx(self, srv: MDNS_Sck, msg: bytes, cooldown: float) -> bool:
now = time.time()
if now < srv.last_tx + cooldown:
return False
srv.sck.sendto(msg, (srv.grp, 5353))
srv.last_tx = now
return True

252
copyparty/multicast.py Normal file
View file

@ -0,0 +1,252 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
import socket
import time
import ipaddress
from ipaddress import IPv4Network, IPv6Network, IPv4Address, IPv6Address
from .__init__ import TYPE_CHECKING
from .util import min_ex, spack
if TYPE_CHECKING:
from .svchub import SvcHub
if True: # pylint: disable=using-constant-test
from typing import Optional, Union
if not hasattr(socket, "IPPROTO_IPV6"):
setattr(socket, "IPPROTO_IPV6", 41)
class MC_Sck(object):
"""there is one socket for each server ip"""
def __init__(
self,
sck: socket.socket,
idx: int,
grp: str,
ip: str,
net: Union[IPv4Network, IPv6Network],
):
self.sck = sck
self.idx = idx
self.grp = grp
self.mreq = b""
self.ip = ip
self.net = net
self.ips = {ip: net}
self.v6 = ":" in ip
class MCast(object):
def __init__(
self, hub: "SvcHub", Srv: type[MC_Sck], mc_grp_4: str, mc_grp_6: str, port: int
) -> None:
"""disable ipv%d by setting mc_grp_%d empty"""
self.hub = hub
self.Srv = Srv
self.args = hub.args
self.asrv = hub.asrv
self.log_func = hub.log
self.grp4 = mc_grp_4
self.grp6 = mc_grp_6
self.port = port
self.srv: dict[socket.socket, MC_Sck] = {} # listening sockets
self.sips: set[str] = set() # all listening ips
self.b2srv: dict[bytes, MC_Sck] = {} # binary-ip -> server socket
self.b4: list[bytes] = [] # sorted list of binary-ips
self.b6: list[bytes] = [] # sorted list of binary-ips
self.cscache: dict[str, Optional[MC_Sck]] = {} # client ip -> server cache
def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("multicast", msg, c)
def create_servers(self) -> list[str]:
bound: list[str] = []
ips = [x[0] for x in self.hub.tcpsrv.bound]
ips = list(set(ips))
if "::" in ips:
ips = [x for x in ips if x != "::"] + list(
[x.split("/")[0] for x in self.hub.tcpsrv.netdevs if ":" in x]
)
ips.append("0.0.0.0")
if "0.0.0.0" in ips:
ips = [x for x in ips if x != "0.0.0.0"] + list(
[x.split("/")[0] for x in self.hub.tcpsrv.netdevs if ":" not in x]
)
ips = [x for x in ips if x not in ("::1", "127.0.0.1")]
ips = [
[x for x in self.hub.tcpsrv.netdevs if x.startswith(y + "/")][0]
for y in ips
]
if not self.grp4:
ips = [x for x in ips if ":" in x]
if not self.grp6:
ips = [x for x in ips if ":" not in x]
if not ips:
raise Exception("no server IP matches the mdns config")
for ip in ips:
v6 = ":" in ip
netdev = "?"
try:
netdev = self.hub.tcpsrv.netdevs[ip].split(",")[0]
idx = socket.if_nametoindex(netdev)
except:
idx = socket.INADDR_ANY
t = "using INADDR_ANY for ip [{}], netdev [{}]"
if not self.srv and ip not in ["::", "0.0.0.0"]:
self.log(t.format(ip, netdev), 3)
ipv = socket.AF_INET6 if v6 else socket.AF_INET
sck = socket.socket(ipv, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sck.settimeout(None)
sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except:
pass
net = ipaddress.ip_network(ip, False)
ip = ip.split("/")[0]
srv = self.Srv(sck, idx, self.grp6 if ":" in ip else self.grp4, ip, net)
try:
self.setup_socket(srv)
self.srv[sck] = srv
bound.append(ip)
except:
self.log("announce failed on [{}]:\n{}".format(ip, min_ex()))
if self.args.zm_msub:
for s1 in self.srv.values():
for s2 in self.srv.values():
if s1.idx != s2.idx:
continue
if s1.ip not in s2.ips:
s2.ips[s1.ip] = s1.net
if self.args.zm_mnic:
for s1 in self.srv.values():
for s2 in self.srv.values():
for ip1, net1 in list(s1.ips.items()):
for ip2, net2 in list(s2.ips.items()):
if net1 == net2 and ip1 != ip2:
s1.ips[ip2] = net2
self.sips = set([x.ip for x in self.srv.values()])
return bound
def setup_socket(self, srv: MC_Sck) -> None:
sck = srv.sck
if srv.v6:
if self.args.zmv:
self.log("v6({}) idx({})".format(srv.ip, srv.idx), 6)
bip = socket.inet_pton(socket.AF_INET6, srv.ip)
self.b2srv[bip] = srv
self.b6.append(bip)
sck.bind((self.grp6 if srv.idx else "", self.port, 0, srv.idx))
bgrp = socket.inet_pton(socket.AF_INET6, self.grp6)
dev = spack(b"@I", srv.idx)
srv.mreq = bgrp + dev
if srv.idx != socket.INADDR_ANY:
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev)
self.hop(srv)
try:
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1)
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
except:
pass # macos
else:
if self.args.zmv:
self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6)
bip = socket.inet_aton(srv.ip)
self.b2srv[bip] = srv
self.b4.append(bip)
sck.bind((self.grp4 if srv.idx else "", self.port))
bgrp = socket.inet_aton(self.grp4)
dev = (
spack(b"=I", socket.INADDR_ANY)
if srv.idx == socket.INADDR_ANY
else socket.inet_aton(srv.ip)
)
srv.mreq = bgrp + dev
if srv.idx != socket.INADDR_ANY:
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev)
self.hop(srv)
try:
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255)
except:
pass
self.b4.sort(reverse=True)
self.b6.sort(reverse=True)
def hop(self, srv: MC_Sck) -> None:
"""rejoin to keepalive on routers/switches without igmp-snooping"""
sck = srv.sck
req = srv.mreq
if ":" in srv.ip:
try:
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_LEAVE_GROUP, req)
# linux does leaves/joins twice with 0.2~1.05s spacing
time.sleep(1.2)
except:
pass
sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, req)
else:
try:
sck.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, req)
time.sleep(1.2)
except:
pass
sck.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, req)
def map_client(self, cip: str) -> Optional[MC_Sck]:
try:
return self.cscache[cip]
except:
pass
ret: Optional[MC_Sck] = None
v6 = ":" in cip
ci = IPv6Address(cip) if v6 else IPv4Address(cip)
for x in self.b6 if v6 else self.b4:
srv = self.b2srv[x]
if any([x for x in srv.ips.values() if ci in x]):
ret = srv
break
if not ret and cip in ("127.0.0.1", "::1"):
# just give it something
ret = list(self.srv.values())[0]
if not ret:
t = "could not map client {} to known subnet; maybe forwarded from another network?"
self.log(t.format(cip), 3)
self.cscache[cip] = ret
if len(self.cscache) > 9000:
self.cscache = {}
return ret

View file

@ -0,0 +1,5 @@
`dnslib` but heavily simplified/feature-stripped
L: MIT
Copyright (c) 2010 - 2017 Paul Chakravarti
https://github.com/paulc/dnslib/

View file

@ -0,0 +1,11 @@
# coding: utf-8
"""
L: MIT
Copyright (c) 2010 - 2017 Paul Chakravarti
https://github.com/paulc/dnslib/tree/0.9.23
"""
from .dns import *
version = "0.9.23"

View file

@ -0,0 +1,41 @@
# coding: utf-8
import types
class BimapError(Exception):
pass
class Bimap(object):
def __init__(self, name, forward, error=AttributeError):
self.name = name
self.error = error
self.forward = forward.copy()
self.reverse = dict([(v, k) for (k, v) in list(forward.items())])
def get(self, k, default=None):
try:
return self.forward[k]
except KeyError:
return default or str(k)
def __getitem__(self, k):
try:
return self.forward[k]
except KeyError:
if isinstance(self.error, types.FunctionType):
return self.error(self.name, k, True)
else:
raise self.error("%s: Invalid forward lookup: [%s]" % (self.name, k))
def __getattr__(self, k):
try:
if k == "__wrapped__":
raise AttributeError()
return self.reverse[k]
except KeyError:
if isinstance(self.error, types.FunctionType):
return self.error(self.name, k, False)
else:
raise self.error("%s: Invalid reverse lookup: [%s]" % (self.name, k))

View file

@ -0,0 +1,15 @@
# coding: utf-8
from __future__ import print_function
def get_bits(data, offset, bits=1):
mask = ((1 << bits) - 1) << offset
return (data & mask) >> offset
def set_bits(data, value, offset, bits=1):
mask = ((1 << bits) - 1) << offset
clear = 0xFFFF ^ mask
data = (data & clear) | ((value << offset) & mask)
return data

View file

@ -0,0 +1,56 @@
# coding: utf-8
import binascii
import struct
class BufferError(Exception):
pass
class Buffer(object):
def __init__(self, data=b""):
self.data = bytearray(data)
self.offset = 0
def remaining(self):
return len(self.data) - self.offset
def get(self, length):
if length > self.remaining():
raise BufferError(
"Not enough bytes [offset=%d,remaining=%d,requested=%d]"
% (self.offset, self.remaining(), length)
)
start = self.offset
end = self.offset + length
self.offset += length
return bytes(self.data[start:end])
def hex(self):
return binascii.hexlify(self.data)
def pack(self, fmt, *args):
self.offset += struct.calcsize(fmt)
self.data += struct.pack(fmt, *args)
def append(self, s):
self.offset += len(s)
self.data += s
def update(self, ptr, fmt, *args):
s = struct.pack(fmt, *args)
self.data[ptr : ptr + len(s)] = s
def unpack(self, fmt):
try:
data = self.get(struct.calcsize(fmt))
return struct.unpack(fmt, data)
except struct.error:
raise BufferError(
"Error unpacking struct '%s' <%s>"
% (fmt, binascii.hexlify(data).decode())
)
def __len__(self):
return len(self.data)

View file

@ -0,0 +1,781 @@
# coding: utf-8
from __future__ import print_function
import binascii
import random
from itertools import chain
from .bit import get_bits, set_bits
from .bimap import Bimap, BimapError
from .buffer import BufferError
from .label import DNSLabel, DNSBuffer
from .ranges import H, I, IP4, IP6, check_bytes
class DNSError(Exception):
pass
def unknown_qtype(name, key, forward):
if forward:
try:
return "TYPE%d" % (key,)
except:
raise DNSError("%s: Invalid forward lookup: [%s]" % (name, key))
else:
if key.startswith("TYPE"):
try:
return int(key[4:])
except:
pass
raise DNSError("%s: Invalid reverse lookup: [%s]" % (name, key))
QTYPE = Bimap(
"QTYPE",
{1: "A", 12: "PTR", 16: "TXT", 28: "AAAA", 33: "SRV", 47: "NSEC", 255: "ANY"},
unknown_qtype,
)
CLASS = Bimap("CLASS", {1: "IN", 254: "None", 255: "*"}, DNSError)
QR = Bimap("QR", {0: "QUERY", 1: "RESPONSE"}, DNSError)
RCODE = Bimap(
"RCODE",
{
0: "NOERROR",
1: "FORMERR",
2: "SERVFAIL",
3: "NXDOMAIN",
4: "NOTIMP",
5: "REFUSED",
6: "YXDOMAIN",
7: "YXRRSET",
8: "NXRRSET",
9: "NOTAUTH",
10: "NOTZONE",
},
DNSError,
)
OPCODE = Bimap(
"OPCODE", {0: "QUERY", 1: "IQUERY", 2: "STATUS", 4: "NOTIFY", 5: "UPDATE"}, DNSError
)
def label(label, origin=None):
if label.endswith("."):
return DNSLabel(label)
else:
return (origin if isinstance(origin, DNSLabel) else DNSLabel(origin)).add(label)
class DNSRecord(object):
@classmethod
def parse(cls, packet):
buffer = DNSBuffer(packet)
try:
header = DNSHeader.parse(buffer)
questions = []
rr = []
auth = []
ar = []
for i in range(header.q):
questions.append(DNSQuestion.parse(buffer))
for i in range(header.a):
rr.append(RR.parse(buffer))
for i in range(header.auth):
auth.append(RR.parse(buffer))
for i in range(header.ar):
ar.append(RR.parse(buffer))
return cls(header, questions, rr, auth=auth, ar=ar)
except (BufferError, BimapError) as e:
raise DNSError(
"Error unpacking DNSRecord [offset=%d]: %s" % (buffer.offset, e)
)
@classmethod
def question(cls, qname, qtype="A", qclass="IN"):
return DNSRecord(
q=DNSQuestion(qname, getattr(QTYPE, qtype), getattr(CLASS, qclass))
)
def __init__(
self, header=None, questions=None, rr=None, q=None, a=None, auth=None, ar=None
):
self.header = header or DNSHeader()
self.questions = questions or []
self.rr = rr or []
self.auth = auth or []
self.ar = ar or []
if q:
self.questions.append(q)
if a:
self.rr.append(a)
self.set_header_qa()
def reply(self, ra=1, aa=1):
return DNSRecord(
DNSHeader(id=self.header.id, bitmap=self.header.bitmap, qr=1, ra=ra, aa=aa),
q=self.q,
)
def add_question(self, *q):
self.questions.extend(q)
self.set_header_qa()
def add_answer(self, *rr):
self.rr.extend(rr)
self.set_header_qa()
def add_auth(self, *auth):
self.auth.extend(auth)
self.set_header_qa()
def add_ar(self, *ar):
self.ar.extend(ar)
self.set_header_qa()
def set_header_qa(self):
self.header.q = len(self.questions)
self.header.a = len(self.rr)
self.header.auth = len(self.auth)
self.header.ar = len(self.ar)
def get_q(self):
return self.questions[0] if self.questions else DNSQuestion()
q = property(get_q)
def get_a(self):
return self.rr[0] if self.rr else RR()
a = property(get_a)
def pack(self):
self.set_header_qa()
buffer = DNSBuffer()
self.header.pack(buffer)
for q in self.questions:
q.pack(buffer)
for rr in self.rr:
rr.pack(buffer)
for auth in self.auth:
auth.pack(buffer)
for ar in self.ar:
ar.pack(buffer)
return buffer.data
def truncate(self):
return DNSRecord(DNSHeader(id=self.header.id, bitmap=self.header.bitmap, tc=1))
def format(self, prefix="", sort=False):
s = sorted if sort else lambda x: x
sections = [repr(self.header)]
sections.extend(s([repr(q) for q in self.questions]))
sections.extend(s([repr(rr) for rr in self.rr]))
sections.extend(s([repr(rr) for rr in self.auth]))
sections.extend(s([repr(rr) for rr in self.ar]))
return prefix + ("\n" + prefix).join(sections)
short = format
def __repr__(self):
return self.format()
__str__ = __repr__
class DNSHeader(object):
id = H("id")
bitmap = H("bitmap")
q = H("q")
a = H("a")
auth = H("auth")
ar = H("ar")
@classmethod
def parse(cls, buffer):
try:
(id, bitmap, q, a, auth, ar) = buffer.unpack("!HHHHHH")
return cls(id, bitmap, q, a, auth, ar)
except (BufferError, BimapError) as e:
raise DNSError(
"Error unpacking DNSHeader [offset=%d]: %s" % (buffer.offset, e)
)
def __init__(self, id=None, bitmap=None, q=0, a=0, auth=0, ar=0, **args):
if id is None:
self.id = random.randint(0, 65535)
else:
self.id = id
if bitmap is None:
self.bitmap = 0
self.rd = 1
else:
self.bitmap = bitmap
self.q = q
self.a = a
self.auth = auth
self.ar = ar
for k, v in args.items():
if k.lower() == "qr":
self.qr = v
elif k.lower() == "opcode":
self.opcode = v
elif k.lower() == "aa":
self.aa = v
elif k.lower() == "tc":
self.tc = v
elif k.lower() == "rd":
self.rd = v
elif k.lower() == "ra":
self.ra = v
elif k.lower() == "z":
self.z = v
elif k.lower() == "ad":
self.ad = v
elif k.lower() == "cd":
self.cd = v
elif k.lower() == "rcode":
self.rcode = v
def get_qr(self):
return get_bits(self.bitmap, 15)
def set_qr(self, val):
self.bitmap = set_bits(self.bitmap, val, 15)
qr = property(get_qr, set_qr)
def get_opcode(self):
return get_bits(self.bitmap, 11, 4)
def set_opcode(self, val):
self.bitmap = set_bits(self.bitmap, val, 11, 4)
opcode = property(get_opcode, set_opcode)
def get_aa(self):
return get_bits(self.bitmap, 10)
def set_aa(self, val):
self.bitmap = set_bits(self.bitmap, val, 10)
aa = property(get_aa, set_aa)
def get_tc(self):
return get_bits(self.bitmap, 9)
def set_tc(self, val):
self.bitmap = set_bits(self.bitmap, val, 9)
tc = property(get_tc, set_tc)
def get_rd(self):
return get_bits(self.bitmap, 8)
def set_rd(self, val):
self.bitmap = set_bits(self.bitmap, val, 8)
rd = property(get_rd, set_rd)
def get_ra(self):
return get_bits(self.bitmap, 7)
def set_ra(self, val):
self.bitmap = set_bits(self.bitmap, val, 7)
ra = property(get_ra, set_ra)
def get_z(self):
return get_bits(self.bitmap, 6)
def set_z(self, val):
self.bitmap = set_bits(self.bitmap, val, 6)
z = property(get_z, set_z)
def get_ad(self):
return get_bits(self.bitmap, 5)
def set_ad(self, val):
self.bitmap = set_bits(self.bitmap, val, 5)
ad = property(get_ad, set_ad)
def get_cd(self):
return get_bits(self.bitmap, 4)
def set_cd(self, val):
self.bitmap = set_bits(self.bitmap, val, 4)
cd = property(get_cd, set_cd)
def get_rcode(self):
return get_bits(self.bitmap, 0, 4)
def set_rcode(self, val):
self.bitmap = set_bits(self.bitmap, val, 0, 4)
rcode = property(get_rcode, set_rcode)
def pack(self, buffer):
buffer.pack("!HHHHHH", self.id, self.bitmap, self.q, self.a, self.auth, self.ar)
def __repr__(self):
f = [
self.aa and "AA",
self.tc and "TC",
self.rd and "RD",
self.ra and "RA",
self.z and "Z",
self.ad and "AD",
self.cd and "CD",
]
if OPCODE.get(self.opcode) == "UPDATE":
f1 = "zo"
f2 = "pr"
f3 = "up"
f4 = "ad"
else:
f1 = "q"
f2 = "a"
f3 = "ns"
f4 = "ar"
return (
"<DNS Header: id=0x%x type=%s opcode=%s flags=%s "
"rcode='%s' %s=%d %s=%d %s=%d %s=%d>"
% (
self.id,
QR.get(self.qr),
OPCODE.get(self.opcode),
",".join(filter(None, f)),
RCODE.get(self.rcode),
f1,
self.q,
f2,
self.a,
f3,
self.auth,
f4,
self.ar,
)
)
__str__ = __repr__
class DNSQuestion(object):
@classmethod
def parse(cls, buffer):
try:
qname = buffer.decode_name()
qtype, qclass = buffer.unpack("!HH")
return cls(qname, qtype, qclass)
except (BufferError, BimapError) as e:
raise DNSError(
"Error unpacking DNSQuestion [offset=%d]: %s" % (buffer.offset, e)
)
def __init__(self, qname=None, qtype=1, qclass=1):
self.qname = qname
self.qtype = qtype
self.qclass = qclass
def set_qname(self, qname):
if isinstance(qname, DNSLabel):
self._qname = qname
else:
self._qname = DNSLabel(qname)
def get_qname(self):
return self._qname
qname = property(get_qname, set_qname)
def pack(self, buffer):
buffer.encode_name(self.qname)
buffer.pack("!HH", self.qtype, self.qclass)
def __repr__(self):
return "<DNS Question: '%s' qtype=%s qclass=%s>" % (
self.qname,
QTYPE.get(self.qtype),
CLASS.get(self.qclass),
)
__str__ = __repr__
class RR(object):
rtype = H("rtype")
rclass = H("rclass")
ttl = I("ttl")
rdlength = H("rdlength")
@classmethod
def parse(cls, buffer):
try:
rname = buffer.decode_name()
rtype, rclass, ttl, rdlength = buffer.unpack("!HHIH")
if rdlength:
rdata = RDMAP.get(QTYPE.get(rtype), RD).parse(buffer, rdlength)
else:
rdata = ""
return cls(rname, rtype, rclass, ttl, rdata)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking RR [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, rname=None, rtype=1, rclass=1, ttl=0, rdata=None):
self.rname = rname
self.rtype = rtype
self.rclass = rclass
self.ttl = ttl
self.rdata = rdata
def set_rname(self, rname):
if isinstance(rname, DNSLabel):
self._rname = rname
else:
self._rname = DNSLabel(rname)
def get_rname(self):
return self._rname
rname = property(get_rname, set_rname)
def pack(self, buffer):
buffer.encode_name(self.rname)
buffer.pack("!HHI", self.rtype, self.rclass, self.ttl)
rdlength_ptr = buffer.offset
buffer.pack("!H", 0)
start = buffer.offset
self.rdata.pack(buffer)
end = buffer.offset
buffer.update(rdlength_ptr, "!H", end - start)
def __repr__(self):
return "<DNS RR: '%s' rtype=%s rclass=%s ttl=%d rdata='%s'>" % (
self.rname,
QTYPE.get(self.rtype),
CLASS.get(self.rclass),
self.ttl,
self.rdata,
)
__str__ = __repr__
class RD(object):
@classmethod
def parse(cls, buffer, length):
try:
data = buffer.get(length)
return cls(data)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking RD [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, data=b""):
check_bytes("data", data)
self.data = bytes(data)
def pack(self, buffer):
buffer.append(self.data)
def __repr__(self):
if len(self.data) > 0:
return "\\# %d %s" % (
len(self.data),
binascii.hexlify(self.data).decode().upper(),
)
else:
return "\\# 0"
attrs = ("data",)
def _force_bytes(x):
if isinstance(x, bytes):
return x
else:
return x.encode()
class TXT(RD):
@classmethod
def parse(cls, buffer, length):
try:
data = list()
start_bo = buffer.offset
now_length = 0
while buffer.offset < start_bo + length:
(txtlength,) = buffer.unpack("!B")
if now_length + txtlength < length:
now_length += txtlength
data.append(buffer.get(txtlength))
else:
raise DNSError(
"Invalid TXT record: len(%d) > RD len(%d)" % (txtlength, length)
)
return cls(data)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking TXT [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, data):
if type(data) in (tuple, list):
self.data = [_force_bytes(x) for x in data]
else:
self.data = [_force_bytes(data)]
if any([len(x) > 255 for x in self.data]):
raise DNSError("TXT record too long: %s" % self.data)
def pack(self, buffer):
for ditem in self.data:
if len(ditem) > 255:
raise DNSError("TXT record too long: %s" % ditem)
buffer.pack("!B", len(ditem))
buffer.append(ditem)
def __repr__(self):
return ",".join([repr(x) for x in self.data])
class A(RD):
data = IP4("data")
@classmethod
def parse(cls, buffer, length):
try:
data = buffer.unpack("!BBBB")
return cls(data)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking A [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, data):
if type(data) in (tuple, list):
self.data = tuple(data)
else:
self.data = tuple(map(int, data.rstrip(".").split(".")))
def pack(self, buffer):
buffer.pack("!BBBB", *self.data)
def __repr__(self):
return "%d.%d.%d.%d" % self.data
def _parse_ipv6(a):
l, _, r = a.partition("::")
l_groups = list(chain(*[divmod(int(x, 16), 256) for x in l.split(":") if x]))
r_groups = list(chain(*[divmod(int(x, 16), 256) for x in r.split(":") if x]))
zeros = [0] * (16 - len(l_groups) - len(r_groups))
return tuple(l_groups + zeros + r_groups)
def _format_ipv6(a):
left = []
right = []
current = "left"
for i in range(0, 16, 2):
group = (a[i] << 8) + a[i + 1]
if current == "left":
if group == 0 and i < 14:
if (a[i + 2] << 8) + a[i + 3] == 0:
current = "right"
else:
left.append("0")
else:
left.append("%x" % group)
else:
if group == 0 and len(right) == 0:
pass
else:
right.append("%x" % group)
if len(left) < 8:
return ":".join(left) + "::" + ":".join(right)
else:
return ":".join(left)
class AAAA(RD):
data = IP6("data")
@classmethod
def parse(cls, buffer, length):
try:
data = buffer.unpack("!16B")
return cls(data)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking AAAA [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, data):
if type(data) in (tuple, list):
self.data = tuple(data)
else:
self.data = _parse_ipv6(data)
def pack(self, buffer):
buffer.pack("!16B", *self.data)
def __repr__(self):
return _format_ipv6(self.data)
class CNAME(RD):
@classmethod
def parse(cls, buffer, length):
try:
label = buffer.decode_name()
return cls(label)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking CNAME [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, label=None):
self.label = label
def set_label(self, label):
if isinstance(label, DNSLabel):
self._label = label
else:
self._label = DNSLabel(label)
def get_label(self):
return self._label
label = property(get_label, set_label)
def pack(self, buffer):
buffer.encode_name(self.label)
def __repr__(self):
return "%s" % (self.label)
attrs = ("label",)
class PTR(CNAME):
pass
class SRV(RD):
priority = H("priority")
weight = H("weight")
port = H("port")
@classmethod
def parse(cls, buffer, length):
try:
priority, weight, port = buffer.unpack("!HHH")
target = buffer.decode_name()
return cls(priority, weight, port, target)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking SRV [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, priority=0, weight=0, port=0, target=None):
self.priority = priority
self.weight = weight
self.port = port
self.target = target
def set_target(self, target):
if isinstance(target, DNSLabel):
self._target = target
else:
self._target = DNSLabel(target)
def get_target(self):
return self._target
target = property(get_target, set_target)
def pack(self, buffer):
buffer.pack("!HHH", self.priority, self.weight, self.port)
buffer.encode_name(self.target)
def __repr__(self):
return "%d %d %d %s" % (self.priority, self.weight, self.port, self.target)
attrs = ("priority", "weight", "port", "target")
def decode_type_bitmap(type_bitmap):
rrlist = []
buf = DNSBuffer(type_bitmap)
while buf.remaining():
winnum, winlen = buf.unpack("BB")
bitmap = bytearray(buf.get(winlen))
for (pos, value) in enumerate(bitmap):
for i in range(8):
if (value << i) & 0x80:
bitpos = (256 * winnum) + (8 * pos) + i
rrlist.append(QTYPE[bitpos])
return rrlist
def encode_type_bitmap(rrlist):
rrlist = sorted([getattr(QTYPE, rr) for rr in rrlist])
buf = DNSBuffer()
curWindow = rrlist[0] // 256
bitmap = bytearray(32)
n = len(rrlist) - 1
for i, rr in enumerate(rrlist):
v = rr - curWindow * 256
bitmap[v // 8] |= 1 << (7 - v % 8)
if i == n or rrlist[i + 1] >= (curWindow + 1) * 256:
while bitmap[-1] == 0:
bitmap = bitmap[:-1]
buf.pack("BB", curWindow, len(bitmap))
buf.append(bitmap)
if i != n:
curWindow = rrlist[i + 1] // 256
bitmap = bytearray(32)
return buf.data
class NSEC(RD):
@classmethod
def parse(cls, buffer, length):
try:
end = buffer.offset + length
name = buffer.decode_name()
rrlist = decode_type_bitmap(buffer.get(end - buffer.offset))
return cls(name, rrlist)
except (BufferError, BimapError) as e:
raise DNSError("Error unpacking NSEC [offset=%d]: %s" % (buffer.offset, e))
def __init__(self, label, rrlist):
self.label = label
self.rrlist = rrlist
def set_label(self, label):
if isinstance(label, DNSLabel):
self._label = label
else:
self._label = DNSLabel(label)
def get_label(self):
return self._label
label = property(get_label, set_label)
def pack(self, buffer):
buffer.encode_name_nocompress(self.label)
buffer.append(encode_type_bitmap(self.rrlist))
def __repr__(self):
return "%s %s" % (self.label, " ".join(self.rrlist))
attrs = ("label", "rrlist")
RDMAP = {"A": A, "AAAA": AAAA, "TXT": TXT, "PTR": PTR, "SRV": SRV, "NSEC": NSEC}

View file

@ -0,0 +1,154 @@
# coding: utf-8
from __future__ import print_function
import fnmatch, re
from .bit import get_bits, set_bits
from .buffer import Buffer, BufferError
LDH = set(range(33, 127))
ESCAPE = re.compile(r"\\([0-9][0-9][0-9])")
class DNSLabelError(Exception):
pass
class DNSLabel(object):
def __init__(self, label):
if type(label) == DNSLabel:
self.label = label.label
elif type(label) in (list, tuple):
self.label = tuple(label)
else:
if not label or label in (b".", "."):
self.label = ()
elif type(label) is not bytes:
if type("") != type(b""):
label = ESCAPE.sub(lambda m: chr(int(m[1])), label)
self.label = tuple(label.encode("idna").rstrip(b".").split(b"."))
else:
if type("") == type(b""):
label = ESCAPE.sub(lambda m: chr(int(m.groups()[0])), label)
self.label = tuple(label.rstrip(b".").split(b"."))
def add(self, name):
new = DNSLabel(name)
if self.label:
new.label += self.label
return new
def idna(self):
return ".".join([s.decode("idna") for s in self.label]) + "."
def _decode(self, s):
if set(s).issubset(LDH):
return s.decode()
else:
return "".join([(chr(c) if (c in LDH) else "\\%03d" % c) for c in s])
def __str__(self):
return ".".join([self._decode(bytearray(s)) for s in self.label]) + "."
def __repr__(self):
return "<DNSLabel: '%s'>" % str(self)
def __hash__(self):
return hash(tuple(map(lambda x: x.lower(), self.label)))
def __ne__(self, other):
return not self == other
def __eq__(self, other):
if type(other) != DNSLabel:
return self.__eq__(DNSLabel(other))
else:
return [l.lower() for l in self.label] == [l.lower() for l in other.label]
def __len__(self):
return len(b".".join(self.label))
class DNSBuffer(Buffer):
def __init__(self, data=b""):
super(DNSBuffer, self).__init__(data)
self.names = {}
def decode_name(self, last=-1):
label = []
done = False
while not done:
(length,) = self.unpack("!B")
if get_bits(length, 6, 2) == 3:
self.offset -= 1
pointer = get_bits(self.unpack("!H")[0], 0, 14)
save = self.offset
if last == save:
raise BufferError(
"Recursive pointer in DNSLabel [offset=%d,pointer=%d,length=%d]"
% (self.offset, pointer, len(self.data))
)
if pointer < self.offset:
self.offset = pointer
else:
raise BufferError(
"Invalid pointer in DNSLabel [offset=%d,pointer=%d,length=%d]"
% (self.offset, pointer, len(self.data))
)
label.extend(self.decode_name(save).label)
self.offset = save
done = True
else:
if length > 0:
l = self.get(length)
try:
l.decode()
except UnicodeDecodeError:
raise BufferError("Invalid label <%s>" % l)
label.append(l)
else:
done = True
return DNSLabel(label)
def encode_name(self, name):
if not isinstance(name, DNSLabel):
name = DNSLabel(name)
if len(name) > 253:
raise DNSLabelError("Domain label too long: %r" % name)
name = list(name.label)
while name:
if tuple(name) in self.names:
pointer = self.names[tuple(name)]
pointer = set_bits(pointer, 3, 14, 2)
self.pack("!H", pointer)
return
else:
self.names[tuple(name)] = self.offset
element = name.pop(0)
if len(element) > 63:
raise DNSLabelError("Label component too long: %r" % element)
self.pack("!B", len(element))
self.append(element)
self.append(b"\x00")
def encode_name_nocompress(self, name):
if not isinstance(name, DNSLabel):
name = DNSLabel(name)
if len(name) > 253:
raise DNSLabelError("Domain label too long: %r" % name)
name = list(name.label)
while name:
element = name.pop(0)
if len(element) > 63:
raise DNSLabelError("Label component too long: %r" % element)
self.pack("!B", len(element))
self.append(element)
self.append(b"\x00")

View file

@ -0,0 +1,105 @@
# coding: utf-8
from __future__ import print_function
import collections
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
class Lexer(object):
escape_chars = "\\"
escape = {"n": "\n", "t": "\t", "r": "\r"}
def __init__(self, f, debug=False):
if hasattr(f, "read"):
self.f = f
elif type(f) == str:
self.f = StringIO(f)
elif type(f) == bytes:
self.f = StringIO(f.decode())
else:
raise ValueError("Invalid input")
self.debug = debug
self.q = collections.deque()
self.state = self.lexStart
self.escaped = False
self.eof = False
def __iter__(self):
return self.parse()
def next_token(self):
if self.debug:
print("STATE", self.state)
(tok, self.state) = self.state()
return tok
def parse(self):
while self.state is not None and not self.eof:
tok = self.next_token()
if tok:
yield tok
def read(self, n=1):
s = ""
while self.q and n > 0:
s += self.q.popleft()
n -= 1
s += self.f.read(n)
if s == "":
self.eof = True
if self.debug:
print("Read: >%s<" % repr(s))
return s
def peek(self, n=1):
s = ""
i = 0
while len(self.q) > i and n > 0:
s += self.q[i]
i += 1
n -= 1
r = self.f.read(n)
if n > 0 and r == "":
self.eof = True
self.q.extend(r)
if self.debug:
print("Peek : >%s<" % repr(s + r))
return s + r
def pushback(self, s):
p = collections.deque(s)
p.extend(self.q)
self.q = p
def readescaped(self):
c = self.read(1)
if c in self.escape_chars:
self.escaped = True
n = self.peek(3)
if n.isdigit():
n = self.read(3)
if self.debug:
print("Escape: >%s<" % n)
return chr(int(n, 8))
elif n[0] in "x":
x = self.read(3)
if self.debug:
print("Escape: >%s<" % x)
return chr(int(x[1:], 16))
else:
c = self.read(1)
if self.debug:
print("Escape: >%s<" % c)
return self.escape.get(c, c)
else:
self.escaped = False
return c
def lexStart(self):
return (None, None)

View file

@ -0,0 +1,81 @@
# coding: utf-8
import sys
if sys.version < "3":
int_types = (
int,
long,
)
byte_types = (str, bytearray)
else:
int_types = (int,)
byte_types = (bytes, bytearray)
def check_instance(name, val, types):
if not isinstance(val, types):
raise ValueError(
"Attribute '%s' must be instance of %s [%s]" % (name, types, type(val))
)
def check_bytes(name, val):
return check_instance(name, val, byte_types)
def range_property(attr, min, max):
def getter(obj):
return getattr(obj, "_%s" % attr)
def setter(obj, val):
if isinstance(val, int_types) and min <= val <= max:
setattr(obj, "_%s" % attr, val)
else:
raise ValueError(
"Attribute '%s' must be between %d-%d [%s]" % (attr, min, max, val)
)
return property(getter, setter)
def B(attr):
return range_property(attr, 0, 255)
def H(attr):
return range_property(attr, 0, 65535)
def I(attr):
return range_property(attr, 0, 4294967295)
def ntuple_range(attr, n, min, max):
f = lambda x: isinstance(x, int_types) and min <= x <= max
def getter(obj):
return getattr(obj, "_%s" % attr)
def setter(obj, val):
if len(val) != n:
raise ValueError(
"Attribute '%s' must be tuple with %d elements [%s]" % (attr, n, val)
)
if all(map(f, val)):
setattr(obj, "_%s" % attr, val)
else:
raise ValueError(
"Attribute '%s' elements must be between %d-%d [%s]"
% (attr, min, max, val)
)
return property(getter, setter)
def IP4(attr):
return ntuple_range(attr, 4, 0, 255)
def IP6(attr):
return ntuple_range(attr, 16, 0, 255)

View file

@ -0,0 +1,5 @@
`ifaddr` with py2.7 support enabled by make-sfx.sh which strips py3 hints using strip_hints and removes the `^if True:` blocks
L: BSD-2-Clause
Copyright (c) 2014 Stefan C. Mueller
https://github.com/pydron/ifaddr/

View file

@ -0,0 +1,21 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
"""
L: BSD-2-Clause
Copyright (c) 2014 Stefan C. Mueller
https://github.com/pydron/ifaddr/tree/0.2.0
"""
import os
from ._shared import Adapter, IP
if os.name == "nt":
from ._win32 import get_adapters
elif os.name == "posix":
from ._posix import get_adapters
else:
raise RuntimeError("Unsupported Operating System: %s" % os.name)
__all__ = ["Adapter", "IP", "get_adapters"]

View file

@ -0,0 +1,83 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
import os
import ctypes.util
import ipaddress
import collections
import socket
if True: # pylint: disable=using-constant-test
from typing import Iterable, Optional
from . import _shared as shared
from ._shared import U
class ifaddrs(ctypes.Structure):
pass
ifaddrs._fields_ = [
("ifa_next", ctypes.POINTER(ifaddrs)),
("ifa_name", ctypes.c_char_p),
("ifa_flags", ctypes.c_uint),
("ifa_addr", ctypes.POINTER(shared.sockaddr)),
("ifa_netmask", ctypes.POINTER(shared.sockaddr)),
]
libc = ctypes.CDLL(ctypes.util.find_library("socket" if os.uname()[0] == "SunOS" else "c"), use_errno=True) # type: ignore
def get_adapters(include_unconfigured: bool = False) -> Iterable[shared.Adapter]:
addr0 = addr = ctypes.POINTER(ifaddrs)()
retval = libc.getifaddrs(ctypes.byref(addr))
if retval != 0:
eno = ctypes.get_errno()
raise OSError(eno, os.strerror(eno))
ips = collections.OrderedDict()
def add_ip(adapter_name: str, ip: Optional[shared.IP]) -> None:
if adapter_name not in ips:
index = None # type: Optional[int]
try:
# Mypy errors on this when the Windows CI runs:
# error: Module has no attribute "if_nametoindex"
index = socket.if_nametoindex(adapter_name) # type: ignore
except (OSError, AttributeError):
pass
ips[adapter_name] = shared.Adapter(
adapter_name, adapter_name, [], index=index
)
if ip is not None:
ips[adapter_name].ips.append(ip)
while addr:
name = addr[0].ifa_name.decode(encoding="UTF-8")
ip_addr = shared.sockaddr_to_ip(addr[0].ifa_addr)
if ip_addr:
if addr[0].ifa_netmask and not addr[0].ifa_netmask[0].sa_familiy:
addr[0].ifa_netmask[0].sa_familiy = addr[0].ifa_addr[0].sa_familiy
netmask = shared.sockaddr_to_ip(addr[0].ifa_netmask)
if isinstance(netmask, tuple):
netmaskStr = U(netmask[0])
prefixlen = shared.ipv6_prefixlength(ipaddress.IPv6Address(netmaskStr))
else:
if netmask is None:
t = "sockaddr_to_ip({}) returned None"
raise Exception(t.format(addr[0].ifa_netmask))
netmaskStr = U("0.0.0.0/" + netmask)
prefixlen = ipaddress.IPv4Network(netmaskStr).prefixlen
ip = shared.IP(ip_addr, prefixlen, name)
add_ip(name, ip)
else:
if include_unconfigured:
add_ip(name, None)
addr = addr[0].ifa_next
libc.freeifaddrs(addr0)
return ips.values()

View file

@ -0,0 +1,202 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
import sys
import ctypes
import socket
import ipaddress
import platform
if True: # pylint: disable=using-constant-test
from typing import List, Optional, Tuple, Union, Callable
PY2 = sys.version_info < (3,)
if not PY2:
U: Callable[[str], str] = str
else:
U = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable
class Adapter(object):
"""
Represents a network interface device controller (NIC), such as a
network card. An adapter can have multiple IPs.
On Linux aliasing (multiple IPs per physical NIC) is implemented
by creating 'virtual' adapters, each represented by an instance
of this class. Each of those 'virtual' adapters can have both
a IPv4 and an IPv6 IP address.
"""
def __init__(
self, name: str, nice_name: str, ips: List["IP"], index: Optional[int] = None
) -> None:
#: Unique name that identifies the adapter in the system.
#: On Linux this is of the form of `eth0` or `eth0:1`, on
#: Windows it is a UUID in string representation, such as
#: `{846EE342-7039-11DE-9D20-806E6F6E6963}`.
self.name = name
#: Human readable name of the adpater. On Linux this
#: is currently the same as :attr:`name`. On Windows
#: this is the name of the device.
self.nice_name = nice_name
#: List of :class:`ifaddr.IP` instances in the order they were
#: reported by the system.
self.ips = ips
#: Adapter index as used by some API (e.g. IPv6 multicast group join).
self.index = index
def __repr__(self) -> str:
return "Adapter(name={name}, nice_name={nice_name}, ips={ips}, index={index})".format(
name=repr(self.name),
nice_name=repr(self.nice_name),
ips=repr(self.ips),
index=repr(self.index),
)
if True:
# Type of an IPv4 address (a string in "xxx.xxx.xxx.xxx" format)
_IPv4Address = str
# Type of an IPv6 address (a three-tuple `(ip, flowinfo, scope_id)`)
_IPv6Address = tuple[str, int, int]
class IP(object):
"""
Represents an IP address of an adapter.
"""
def __init__(
self, ip: Union[_IPv4Address, _IPv6Address], network_prefix: int, nice_name: str
) -> None:
#: IP address. For IPv4 addresses this is a string in
#: "xxx.xxx.xxx.xxx" format. For IPv6 addresses this
#: is a three-tuple `(ip, flowinfo, scope_id)`, where
#: `ip` is a string in the usual collon separated
#: hex format.
self.ip = ip
#: Number of bits of the IP that represent the
#: network. For a `255.255.255.0` netmask, this
#: number would be `24`.
self.network_prefix = network_prefix
#: Human readable name for this IP.
#: On Linux is this currently the same as the adapter name.
#: On Windows this is the name of the network connection
#: as configured in the system control panel.
self.nice_name = nice_name
@property
def is_IPv4(self) -> bool:
"""
Returns `True` if this IP is an IPv4 address and `False`
if it is an IPv6 address.
"""
return not isinstance(self.ip, tuple)
@property
def is_IPv6(self) -> bool:
"""
Returns `True` if this IP is an IPv6 address and `False`
if it is an IPv4 address.
"""
return isinstance(self.ip, tuple)
def __repr__(self) -> str:
return "IP(ip={ip}, network_prefix={network_prefix}, nice_name={nice_name})".format(
ip=repr(self.ip),
network_prefix=repr(self.network_prefix),
nice_name=repr(self.nice_name),
)
if platform.system() == "Darwin" or "BSD" in platform.system():
# BSD derived systems use marginally different structures
# than either Linux or Windows.
# I still keep it in `shared` since we can use
# both structures equally.
class sockaddr(ctypes.Structure):
_fields_ = [
("sa_len", ctypes.c_uint8),
("sa_familiy", ctypes.c_uint8),
("sa_data", ctypes.c_uint8 * 14),
]
class sockaddr_in(ctypes.Structure):
_fields_ = [
("sa_len", ctypes.c_uint8),
("sa_familiy", ctypes.c_uint8),
("sin_port", ctypes.c_uint16),
("sin_addr", ctypes.c_uint8 * 4),
("sin_zero", ctypes.c_uint8 * 8),
]
class sockaddr_in6(ctypes.Structure):
_fields_ = [
("sa_len", ctypes.c_uint8),
("sa_familiy", ctypes.c_uint8),
("sin6_port", ctypes.c_uint16),
("sin6_flowinfo", ctypes.c_uint32),
("sin6_addr", ctypes.c_uint8 * 16),
("sin6_scope_id", ctypes.c_uint32),
]
else:
class sockaddr(ctypes.Structure): # type: ignore
_fields_ = [("sa_familiy", ctypes.c_uint16), ("sa_data", ctypes.c_uint8 * 14)]
class sockaddr_in(ctypes.Structure): # type: ignore
_fields_ = [
("sin_familiy", ctypes.c_uint16),
("sin_port", ctypes.c_uint16),
("sin_addr", ctypes.c_uint8 * 4),
("sin_zero", ctypes.c_uint8 * 8),
]
class sockaddr_in6(ctypes.Structure): # type: ignore
_fields_ = [
("sin6_familiy", ctypes.c_uint16),
("sin6_port", ctypes.c_uint16),
("sin6_flowinfo", ctypes.c_uint32),
("sin6_addr", ctypes.c_uint8 * 16),
("sin6_scope_id", ctypes.c_uint32),
]
def sockaddr_to_ip(
sockaddr_ptr: "ctypes.pointer[sockaddr]",
) -> Optional[Union[_IPv4Address, _IPv6Address]]:
if sockaddr_ptr:
if sockaddr_ptr[0].sa_familiy == socket.AF_INET:
ipv4 = ctypes.cast(sockaddr_ptr, ctypes.POINTER(sockaddr_in))
ippacked = bytes(bytearray(ipv4[0].sin_addr))
ip = U(ipaddress.ip_address(ippacked))
return ip
elif sockaddr_ptr[0].sa_familiy == socket.AF_INET6:
ipv6 = ctypes.cast(sockaddr_ptr, ctypes.POINTER(sockaddr_in6))
flowinfo = ipv6[0].sin6_flowinfo
ippacked = bytes(bytearray(ipv6[0].sin6_addr))
ip = U(ipaddress.ip_address(ippacked))
scope_id = ipv6[0].sin6_scope_id
return (ip, flowinfo, scope_id)
return None
def ipv6_prefixlength(address: ipaddress.IPv6Address) -> int:
prefix_length = 0
for i in range(address.max_prefixlen):
if int(address) >> i & 1:
prefix_length = prefix_length + 1
return prefix_length

View file

@ -0,0 +1,135 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
import ctypes
from ctypes import wintypes
if True: # pylint: disable=using-constant-test
from typing import Iterable, List
from . import _shared as shared
NO_ERROR = 0
ERROR_BUFFER_OVERFLOW = 111
MAX_ADAPTER_NAME_LENGTH = 256
MAX_ADAPTER_DESCRIPTION_LENGTH = 128
MAX_ADAPTER_ADDRESS_LENGTH = 8
AF_UNSPEC = 0
class SOCKET_ADDRESS(ctypes.Structure):
_fields_ = [
("lpSockaddr", ctypes.POINTER(shared.sockaddr)),
("iSockaddrLength", wintypes.INT),
]
class IP_ADAPTER_UNICAST_ADDRESS(ctypes.Structure):
pass
IP_ADAPTER_UNICAST_ADDRESS._fields_ = [
("Length", wintypes.ULONG),
("Flags", wintypes.DWORD),
("Next", ctypes.POINTER(IP_ADAPTER_UNICAST_ADDRESS)),
("Address", SOCKET_ADDRESS),
("PrefixOrigin", ctypes.c_uint),
("SuffixOrigin", ctypes.c_uint),
("DadState", ctypes.c_uint),
("ValidLifetime", wintypes.ULONG),
("PreferredLifetime", wintypes.ULONG),
("LeaseLifetime", wintypes.ULONG),
("OnLinkPrefixLength", ctypes.c_uint8),
]
class IP_ADAPTER_ADDRESSES(ctypes.Structure):
pass
IP_ADAPTER_ADDRESSES._fields_ = [
("Length", wintypes.ULONG),
("IfIndex", wintypes.DWORD),
("Next", ctypes.POINTER(IP_ADAPTER_ADDRESSES)),
("AdapterName", ctypes.c_char_p),
("FirstUnicastAddress", ctypes.POINTER(IP_ADAPTER_UNICAST_ADDRESS)),
("FirstAnycastAddress", ctypes.c_void_p),
("FirstMulticastAddress", ctypes.c_void_p),
("FirstDnsServerAddress", ctypes.c_void_p),
("DnsSuffix", ctypes.c_wchar_p),
("Description", ctypes.c_wchar_p),
("FriendlyName", ctypes.c_wchar_p),
]
iphlpapi = ctypes.windll.LoadLibrary("Iphlpapi") # type: ignore
def enumerate_interfaces_of_adapter(
nice_name: str, address: IP_ADAPTER_UNICAST_ADDRESS
) -> Iterable[shared.IP]:
# Iterate through linked list and fill list
addresses = [] # type: List[IP_ADAPTER_UNICAST_ADDRESS]
while True:
addresses.append(address)
if not address.Next:
break
address = address.Next[0]
for address in addresses:
ip = shared.sockaddr_to_ip(address.Address.lpSockaddr)
if ip is None:
t = "sockaddr_to_ip({}) returned None"
raise Exception(t.format(address.Address.lpSockaddr))
network_prefix = address.OnLinkPrefixLength
yield shared.IP(ip, network_prefix, nice_name)
def get_adapters(include_unconfigured: bool = False) -> Iterable[shared.Adapter]:
# Call GetAdaptersAddresses() with error and buffer size handling
addressbuffersize = wintypes.ULONG(15 * 1024)
retval = ERROR_BUFFER_OVERFLOW
while retval == ERROR_BUFFER_OVERFLOW:
addressbuffer = ctypes.create_string_buffer(addressbuffersize.value)
retval = iphlpapi.GetAdaptersAddresses(
wintypes.ULONG(AF_UNSPEC),
wintypes.ULONG(0),
None,
ctypes.byref(addressbuffer),
ctypes.byref(addressbuffersize),
)
if retval != NO_ERROR:
raise ctypes.WinError() # type: ignore
# Iterate through adapters fill array
address_infos = [] # type: List[IP_ADAPTER_ADDRESSES]
address_info = IP_ADAPTER_ADDRESSES.from_buffer(addressbuffer)
while True:
address_infos.append(address_info)
if not address_info.Next:
break
address_info = address_info.Next[0]
# Iterate through unicast addresses
result = [] # type: List[shared.Adapter]
for adapter_info in address_infos:
# We don't expect non-ascii characters here, so encoding shouldn't matter
name = adapter_info.AdapterName.decode()
nice_name = adapter_info.Description
index = adapter_info.IfIndex
if adapter_info.FirstUnicastAddress:
ips = enumerate_interfaces_of_adapter(
adapter_info.FriendlyName, adapter_info.FirstUnicastAddress[0]
)
ips = list(ips)
result.append(shared.Adapter(name, nice_name, ips, index=index))
elif include_unconfigured:
result.append(shared.Adapter(name, nice_name, [], index=index))
return result

View file

@ -195,10 +195,17 @@ class SvcHub(object):
args.th_poke = min(args.th_poke, args.th_maxage, args.ac_maxage) args.th_poke = min(args.th_poke, args.th_maxage, args.ac_maxage)
zms = ""
if not args.https_only:
zms += "d"
if not args.http_only:
zms += "D"
if args.ftp or args.ftps: if args.ftp or args.ftps:
from .ftpd import Ftpd from .ftpd import Ftpd
self.ftpd = Ftpd(self) self.ftpd = Ftpd(self)
zms += "f" if args.ftp else "F"
if args.smb: if args.smb:
# impacket.dcerpc is noisy about listen timeouts # impacket.dcerpc is noisy about listen timeouts
@ -210,6 +217,12 @@ class SvcHub(object):
self.smbd = SMB(self) self.smbd = SMB(self)
socket.setdefaulttimeout(sto) socket.setdefaulttimeout(sto)
self.smbd.start() self.smbd.start()
zms += "s"
if not args.zms:
args.zms = zms
self.mdns: Any = None
# decide which worker impl to use # decide which worker impl to use
if self.check_mp_enable(): if self.check_mp_enable():
@ -359,6 +372,15 @@ class SvcHub(object):
def run(self) -> None: def run(self) -> None:
self.tcpsrv.run() self.tcpsrv.run()
if getattr(self.args, "zm", False):
try:
from .mdns import MDNS
self.mdns = MDNS(self)
Daemon(self.mdns.run, "mdns")
except:
self.log("root", "mdns startup failed;\n" + min_ex(), 3)
Daemon(self.thr_httpsrv_up, "sig-hsrv-up2") Daemon(self.thr_httpsrv_up, "sig-hsrv-up2")
sigs = [signal.SIGINT, signal.SIGTERM] sigs = [signal.SIGINT, signal.SIGTERM]
@ -464,6 +486,11 @@ class SvcHub(object):
ret = 1 ret = 1
try: try:
self.pr("OPYTHAT") self.pr("OPYTHAT")
slp = 0.0
if self.mdns:
Daemon(self.mdns.stop)
slp = time.time() + 1
self.tcpsrv.shutdown() self.tcpsrv.shutdown()
self.broker.shutdown() self.broker.shutdown()
self.up2k.shutdown() self.up2k.shutdown()
@ -482,6 +509,9 @@ class SvcHub(object):
Daemon(self.kill9, a=(1,)) Daemon(self.kill9, a=(1,))
self.smbd.stop() self.smbd.stop()
while time.time() < slp:
time.sleep(0.1)
self.pr("nailed it", end="") self.pr("nailed it", end="")
ret = self.retcode ret = self.retcode
except: except:

View file

@ -25,6 +25,9 @@ if True:
if TYPE_CHECKING: if TYPE_CHECKING:
from .svchub import SvcHub from .svchub import SvcHub
if not hasattr(socket, "IPPROTO_IPV6"):
setattr(socket, "IPPROTO_IPV6", 41)
class TcpSrv(object): class TcpSrv(object):
""" """
@ -42,6 +45,7 @@ class TcpSrv(object):
self.stopping = False self.stopping = False
self.srv: list[socket.socket] = [] self.srv: list[socket.socket] = []
self.bound: list[tuple[str, int]] = []
self.nsrv = 0 self.nsrv = 0
self.qr = "" self.qr = ""
pad = False pad = False
@ -97,14 +101,22 @@ class TcpSrv(object):
if pad: if pad:
self.log("tcpsrv", "") self.log("tcpsrv", "")
ip = "127.0.0.1" eps = {"127.0.0.1": "local only", "::1": "local only"}
eps = {ip: "local only"} nonlocals = [x for x in self.args.i if x not in [k.split("/")[0] for k in eps]]
nonlocals = [x for x in self.args.i if x != ip]
if nonlocals: if nonlocals:
eps = self.detect_interfaces(self.args.i) try:
self.netdevs = self.detect_interfaces(self.args.i)
except:
t = "failed to discover server IP addresses\n"
self.log("tcpsrv", t + min_ex(), 3)
self.netdevs = {}
eps.update({k.split("/")[0]: v for k, v in self.netdevs.items()})
if not eps: if not eps:
for x in nonlocals: for x in nonlocals:
eps[x] = "external" eps[x] = "external"
else:
self.netdevs = {}
qr1: dict[str, list[int]] = {} qr1: dict[str, list[int]] = {}
qr2: dict[str, list[int]] = {} qr2: dict[str, list[int]] = {}
@ -180,6 +192,12 @@ class TcpSrv(object):
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
srv.settimeout(None) # < does not inherit, ^ does srv.settimeout(None) # < does not inherit, ^ does
try:
srv.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
except:
pass # will create another ipv4 socket instead
try: try:
srv.bind((ip, port)) srv.bind((ip, port))
self.srv.append(srv) self.srv.append(srv)
@ -194,8 +212,8 @@ class TcpSrv(object):
def run(self) -> None: def run(self) -> None:
all_eps = [x.getsockname()[:2] for x in self.srv] all_eps = [x.getsockname()[:2] for x in self.srv]
bound = [] bound: list[tuple[str, int]] = []
srvs = [] srvs: list[socket.socket] = []
for srv in self.srv: for srv in self.srv:
ip, port = srv.getsockname()[:2] ip, port = srv.getsockname()[:2]
try: try:
@ -225,6 +243,7 @@ class TcpSrv(object):
self.hub.broker.say("listen", srv) self.hub.broker.say("listen", srv)
self.srv = srvs self.srv = srvs
self.bound = bound
self.nsrv = len(srvs) self.nsrv = len(srvs)
def shutdown(self) -> None: def shutdown(self) -> None:
@ -370,19 +389,22 @@ class TcpSrv(object):
return eps return eps
def detect_interfaces(self, listen_ips: list[str]) -> dict[str, str]: def detect_interfaces(self, listen_ips: list[str]) -> dict[str, str]:
if MACOS: from .stolen.ifaddr import get_adapters
eps = self.ips_macos()
elif ANYWIN: nics = get_adapters(True)
eps, off = self.ips_windows_ipconfig() # sees more interfaces + link state eps = {}
eps.update(self.ips_windows_netsh()) # has better names for nic in nics:
for k, v in eps.items(): for nip in nic.ips:
if v in off: ipa = nip.ip[0] if ":" in str(nip.ip) else nip.ip
eps[k] += ", \033[31mLINK-DOWN" sip = "{}/{}".format(ipa, nip.network_prefix)
else: if sip.startswith("fe80") or sip.startswith("169.254"):
eps = self.ips_linux() # browsers dont impl linklocal
continue
eps[sip] = nic.nice_name
if "0.0.0.0" not in listen_ips and "::" not in listen_ips: if "0.0.0.0" not in listen_ips and "::" not in listen_ips:
eps = {k: v for k, v in eps.items() if k in listen_ips} eps = {k: v for k, v in eps.items() if k.split("/")[0] in listen_ips}
try: try:
ext_devs = list(self._extdevs_nix()) ext_devs = list(self._extdevs_nix())
@ -478,7 +500,13 @@ class TcpSrv(object):
def _qr(self, t1: dict[str, list[int]], t2: dict[str, list[int]]) -> str: def _qr(self, t1: dict[str, list[int]], t2: dict[str, list[int]]) -> str:
ip = None ip = None
for ip in list(t1) + list(t2): ips = list(t1) + list(t2)
if self.args.zm:
name = self.args.name + ".local"
t1[name] = next(v for v in (t1 or t2).values())
ips = [name] + ips
for ip in ips:
if ip.startswith(self.args.qri): if ip.startswith(self.args.qri):
break break
ip = "" ip = ""

View file

@ -24,6 +24,7 @@ import time
import traceback import traceback
from collections import Counter from collections import Counter
from datetime import datetime from datetime import datetime
from ipaddress import IPv6Address
from queue import Queue from queue import Queue
@ -60,11 +61,6 @@ try:
except: except:
pass pass
try:
from ipaddress import IPv6Address
except:
pass
try: try:
HAVE_SQLITE3 = True HAVE_SQLITE3 = True
import sqlite3 # pylint: disable=unused-import # typechk import sqlite3 # pylint: disable=unused-import # typechk
@ -184,6 +180,9 @@ IMPLICATIONS = [
["smbw", "smb"], ["smbw", "smb"],
["smb1", "smb"], ["smb1", "smb"],
["smb_dbg", "smb"], ["smb_dbg", "smb"],
["zmvv", "zmv"],
["zmv", "zm"],
["zms", "zm"],
] ]
@ -536,6 +535,27 @@ class _LUnrecv(object):
Unrecv = _Unrecv Unrecv = _Unrecv
class CachedSet(object):
def __init__(self, maxage: float) -> None:
self.c: dict[Any, float] = {}
self.maxage = maxage
self.oldest = 0.0
def add(self, v: Any) -> None:
self.c[v] = time.time()
def cln(self) -> None:
now = time.time()
if now - self.oldest < self.maxage:
return
c = self.c = {k: v for k, v in self.c.items() if now - v < self.maxage}
try:
self.oldest = c[min(c, key=c.get)]
except:
self.oldest = now
class FHC(object): class FHC(object):
class CE(object): class CE(object):
def __init__(self, fh: typing.BinaryIO) -> None: def __init__(self, fh: typing.BinaryIO) -> None:
@ -836,7 +856,7 @@ class Garda(object):
if not self.lim: if not self.lim:
return 0, ip return 0, ip
if ":" in ip and not PY2: if ":" in ip:
# assume /64 clients; drop 4 groups # assume /64 clients; drop 4 groups
ip = IPv6Address(ip).exploded[:-20] ip = IPv6Address(ip).exploded[:-20]
@ -1603,7 +1623,7 @@ def exclude_dotfiles(filepaths: list[str]) -> list[str]:
return [x for x in filepaths if not x.split("/")[-1].startswith(".")] return [x for x in filepaths if not x.split("/")[-1].startswith(".")]
def _ipnorm3(ip: str) -> str: def ipnorm(ip: str) -> str:
if ":" in ip: if ":" in ip:
# assume /64 clients; drop 4 groups # assume /64 clients; drop 4 groups
return IPv6Address(ip).exploded[:-20] return IPv6Address(ip).exploded[:-20]
@ -1611,9 +1631,6 @@ def _ipnorm3(ip: str) -> str:
return ip return ip
ipnorm = _ipnorm3 if not PY2 else unicode
def http_ts(ts: int) -> str: def http_ts(ts: int) -> str:
file_dt = datetime.utcfromtimestamp(ts) file_dt = datetime.utcfromtimestamp(ts)
return file_dt.strftime(HTTP_TS_FMT) return file_dt.strftime(HTTP_TS_FMT)

View file

@ -48,6 +48,17 @@ hashwasm would solve the streaming issue but reduces hashing speed for sha512 (x
* blake2 might be a better choice since xxh is non-cryptographic, but that gets ~15 MiB/s on slower androids * blake2 might be a better choice since xxh is non-cryptographic, but that gets ~15 MiB/s on slower androids
## assumptions
### mdns
* outgoing replies will always fit in one packet
* if a client mentions any of our services, assume it's not missing any
* always answer with all services, even if the client only asked for a few
* not-impl: probe tiebreaking (too complicated)
* not-impl: unicast listen (assume avahi took it)
# sfx repack # sfx repack
reduce the size of an sfx by removing features reduce the size of an sfx by removing features

View file

@ -6,17 +6,21 @@ L: MIT
https://github.com/pallets/jinja/ https://github.com/pallets/jinja/
C: 2007 Pallets C: 2007 Pallets
L: BSD 3-Clause L: BSD 3-Clause
https://github.com/pallets/markupsafe/ https://github.com/pallets/markupsafe/
C: 2010 Pallets C: 2010 Pallets
L: BSD 3-Clause L: BSD 3-Clause
https://github.com/paulc/dnslib/
C: 2010-2017 Paul Chakravarti
L: BSD 2-Clause
https://github.com/giampaolo/pyftpdlib/ https://github.com/giampaolo/pyftpdlib/
C: 2007 Giampaolo Rodola' C: 2007 Giampaolo Rodola
L: MIT L: MIT
https://github.com/nayuki/QR-Code-generator https://github.com/nayuki/QR-Code-generator/
C: Project Nayuki C: Project Nayuki
L: MIT L: MIT

View file

@ -18,6 +18,12 @@ f=../build/isc.txt
awk '/div>/{o=0}o>2;o{o++}/;OWNER/{o=1}' | awk '/div>/{o=0}o>2;o{o++}/;OWNER/{o=1}' |
awk '{gsub(/<[^>]+>/,"")};/./{b=0}!/./{b++}b>1{next}1' >$f awk '{gsub(/<[^>]+>/,"")};/./{b=0}!/./{b++}b>1{next}1' >$f
f=../build/2bsd.txt
[ -e $f ] ||
curl https://opensource.org/licenses/BSD-2-Clause |
awk '/div>/{o=0}o>1;o{o++}/HOLDER/{o=1}' |
awk '{gsub(/<[^>]+>/,"")};1' >$f
f=../build/3bsd.txt f=../build/3bsd.txt
[ -e $f ] || [ -e $f ] ||
curl https://opensource.org/licenses/BSD-3-Clause | curl https://opensource.org/licenses/BSD-3-Clause |
@ -33,6 +39,7 @@ f=../build/ofl.txt
(sed -r 's/^L: /License: /;s/^C: /Copyright (c) /' <../docs/lics.txt (sed -r 's/^L: /License: /;s/^C: /Copyright (c) /' <../docs/lics.txt
printf '\n\n--- MIT License ---\n\n'; cat ../build/mit.txt printf '\n\n--- MIT License ---\n\n'; cat ../build/mit.txt
printf '\n\n--- ISC License ---\n\n'; cat ../build/isc.txt printf '\n\n--- ISC License ---\n\n'; cat ../build/isc.txt
printf '\n\n--- BSD 2-Clause License ---\n\n'; cat ../build/2bsd.txt
printf '\n\n--- BSD 3-Clause License ---\n\n'; cat ../build/3bsd.txt printf '\n\n--- BSD 3-Clause License ---\n\n'; cat ../build/3bsd.txt
printf '\n\n--- SIL Open Font License v1.1 ---\n\n'; cat ../build/ofl.txt printf '\n\n--- SIL Open Font License v1.1 ---\n\n'; cat ../build/ofl.txt
) | ) |

View file

@ -27,6 +27,8 @@ help() { exec cat <<'EOF'
# #
# `no-smb` saves ~3.5k by removing the smb / cifs server # `no-smb` saves ~3.5k by removing the smb / cifs server
# #
# `no-zm` saves ~k by removing the zeroconf mDNS server
#
# _____________________________________________________________________ # _____________________________________________________________________
# web features: # web features:
# #
@ -101,6 +103,7 @@ while [ ! -z "$1" ]; do
gzz) shift;use_gzz=$1;use_gz=1; ;; gzz) shift;use_gzz=$1;use_gz=1; ;;
no-ftp) no_ftp=1 ; ;; no-ftp) no_ftp=1 ; ;;
no-smb) no_smb=1 ; ;; no-smb) no_smb=1 ; ;;
no-zm) no_zm=1 ; ;;
no-fnt) no_fnt=1 ; ;; no-fnt) no_fnt=1 ; ;;
no-hl) no_hl=1 ; ;; no-hl) no_hl=1 ; ;;
no-dd) no_dd=1 ; ;; no-dd) no_dd=1 ; ;;
@ -136,11 +139,22 @@ tmpdir="$(
[ $repack ] && { [ $repack ] && {
old="$tmpdir/pe-copyparty.$(id -u)" old="$tmpdir/pe-copyparty.$(id -u)"
echo "repack of files in $old" echo "repack of files in $old"
cp -pR "$old/"*{py2,j2,copyparty} . cp -pR "$old/"*{py2,py37,j2,copyparty} .
cp -pR "$old/"*ftp . || true cp -pR "$old/"*ftp . || true
} }
[ $repack ] || { [ $repack ] || {
echo collecting ipaddress
f="../build/ipaddress-1.0.23.tar.gz"
[ -e "$f" ] ||
(url=https://files.pythonhosted.org/packages/b9/9a/3e9da40ea28b8210dd6504d3fe9fe7e013b62bf45902b458d1cdc3c34ed9/ipaddress-1.0.23.tar.gz;
wget -O$f "$url" || curl -L "$url" >$f)
tar -zxf $f
mkdir py37
mv ipaddress-*/ipaddress.py py37/
rm -rf ipaddress-*
echo collecting jinja2 echo collecting jinja2
f="../build/Jinja2-2.11.3.tar.gz" f="../build/Jinja2-2.11.3.tar.gz"
[ -e "$f" ] || [ -e "$f" ] ||
@ -237,6 +251,8 @@ tmpdir="$(
awk 'NR<4||NR>27;NR==4{print"# license: https://opensource.org/licenses/ISC\n"}' ../build/$n >copyparty/vend/$n awk 'NR<4||NR>27;NR==4{print"# license: https://opensource.org/licenses/ISC\n"}' ../build/$n >copyparty/vend/$n
done done
rm -f copyparty/stolen/*/README.md
# remove type hints before build instead # remove type hints before build instead
(cd copyparty; "$pybin" ../../scripts/strip_hints/a.py; rm uh) (cd copyparty; "$pybin" ../../scripts/strip_hints/a.py; rm uh)
@ -322,6 +338,10 @@ rm have
rm -f copyparty/smbd.py && rm -f copyparty/smbd.py &&
sed -ri '/add_argument\("--smb/d' copyparty/__main__.py sed -ri '/add_argument\("--smb/d' copyparty/__main__.py
[ $no_zm ] &&
rm -rf copyparty/mdns.py copyparty/stolen/dnslib &&
sed -ri '/add_argument\("--zm/d' copyparty/__main__.py
[ $no_cm ] && { [ $no_cm ] && {
rm -rf copyparty/web/mde.* copyparty/web/deps/easymde* rm -rf copyparty/web/mde.* copyparty/web/deps/easymde*
echo h > copyparty/web/mde.html echo h > copyparty/web/mde.html
@ -464,7 +484,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l)
echo "copying.txt 404 pls rebuild" echo "copying.txt 404 pls rebuild"
mv ftp/* j2/* copyparty/vend/* . mv ftp/* j2/* copyparty/vend/* .
rm -rf ftp j2 py2 copyparty/vend rm -rf ftp j2 py2 py37 copyparty/vend
(cd copyparty; tar -cvf z.tar $t; rm -rf $t) (cd copyparty; tar -cvf z.tar $t; rm -rf $t)
cd .. cd ..
pyoxidizer build --release --target-triple $tgt pyoxidizer build --release --target-triple $tgt
@ -481,7 +501,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l)
echo gen tarlist echo gen tarlist
for d in copyparty j2 py2 ftp; do find $d -type f; done | # strip_hints for d in copyparty j2 py2 py37 ftp; do find $d -type f; done | # strip_hints
sed -r 's/(.*)\.(.*)/\2 \1/' | LC_ALL=C sort | sed -r 's/(.*)\.(.*)/\2 \1/' | LC_ALL=C sort |
sed -r 's/([^ ]*) (.*)/\2.\1/' | grep -vE '/list1?$' > list1 sed -r 's/([^ ]*) (.*)/\2.\1/' | grep -vE '/list1?$' > list1

View file

@ -18,6 +18,7 @@ copyparty/httpcli.py,
copyparty/httpconn.py, copyparty/httpconn.py,
copyparty/httpsrv.py, copyparty/httpsrv.py,
copyparty/ico.py, copyparty/ico.py,
copyparty/mdns.py,
copyparty/mtag.py, copyparty/mtag.py,
copyparty/res, copyparty/res,
copyparty/res/COPYING.txt, copyparty/res/COPYING.txt,
@ -26,6 +27,20 @@ copyparty/smbd.py,
copyparty/star.py, copyparty/star.py,
copyparty/stolen, copyparty/stolen,
copyparty/stolen/__init__.py, copyparty/stolen/__init__.py,
copyparty/stolen/dnslib,
copyparty/stolen/dnslib/__init__.py,
copyparty/stolen/dnslib/bimap.py,
copyparty/stolen/dnslib/bit.py,
copyparty/stolen/dnslib/buffer.py,
copyparty/stolen/dnslib/dns.py,
copyparty/stolen/dnslib/label.py,
copyparty/stolen/dnslib/lex.py,
copyparty/stolen/dnslib/ranges.py,
copyparty/stolen/ifaddr,
copyparty/stolen/ifaddr/__init__.py,
copyparty/stolen/ifaddr/_posix.py,
copyparty/stolen/ifaddr/_shared.py,
copyparty/stolen/ifaddr/_win32.py,
copyparty/stolen/qrcodegen.py, copyparty/stolen/qrcodegen.py,
copyparty/stolen/surrogateescape.py, copyparty/stolen/surrogateescape.py,
copyparty/sutil.py, copyparty/sutil.py,

View file

@ -28,6 +28,7 @@ CKSUM = None
STAMP = None STAMP = None
PY2 = sys.version_info < (3,) PY2 = sys.version_info < (3,)
PY37 = sys.version_info > (3, 7)
WINDOWS = sys.platform in ["win32", "msys"] WINDOWS = sys.platform in ["win32", "msys"]
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
me = os.path.abspath(os.path.realpath(__file__)) me = os.path.abspath(os.path.realpath(__file__))
@ -401,7 +402,7 @@ def run(tmp, j2, ftp):
t.daemon = True t.daemon = True
t.start() t.start()
ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2")) ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2"), (PY37, "py37"))
ld = [os.path.join(tmp, b) for a, b in ld if not a] ld = [os.path.join(tmp, b) for a, b in ld if not a]
# skip 1 # skip 1