mirror of
https://github.com/9001/copyparty.git
synced 2025-08-17 09:02:15 -06:00
556 lines
18 KiB
Python
556 lines
18 KiB
Python
# coding: utf-8
|
|
from __future__ import print_function, unicode_literals
|
|
|
|
import errno
|
|
import random
|
|
import select
|
|
import socket
|
|
import time
|
|
|
|
from ipaddress import IPv4Network, IPv6Network
|
|
|
|
from .__init__ import TYPE_CHECKING
|
|
from .__init__ import unicode as U
|
|
from .multicast import MC_Sck, MCast
|
|
from .stolen.dnslib import AAAA
|
|
from .stolen.dnslib import CLASS as DC
|
|
from .stolen.dnslib import (
|
|
NSEC,
|
|
PTR,
|
|
QTYPE,
|
|
RR,
|
|
SRV,
|
|
TXT,
|
|
A,
|
|
DNSHeader,
|
|
DNSQuestion,
|
|
DNSRecord,
|
|
)
|
|
from .util import CachedSet, Daemon, Netdev, list_ips, min_ex
|
|
|
|
if TYPE_CHECKING:
|
|
from .svchub import SvcHub
|
|
|
|
if True: # pylint: disable=using-constant-test
|
|
from typing import Any, Optional, Union
|
|
|
|
|
|
MDNS4 = "224.0.0.251"
|
|
MDNS6 = "ff02::fb"
|
|
|
|
|
|
class MDNS_Sck(MC_Sck):
|
|
def __init__(
|
|
self,
|
|
sck: socket.socket,
|
|
nd: Netdev,
|
|
grp: str,
|
|
ip: str,
|
|
net: Union[IPv4Network, IPv6Network],
|
|
):
|
|
super(MDNS_Sck, self).__init__(sck, nd, grp, ip, net)
|
|
|
|
self.bp_probe = b""
|
|
self.bp_ip = b""
|
|
self.bp_svc = b""
|
|
self.bp_bye = b""
|
|
|
|
self.last_tx = 0.0
|
|
self.tx_ex = False
|
|
|
|
|
|
class MDNS(MCast):
|
|
def __init__(self, hub: "SvcHub", ngen: int) -> None:
|
|
al = hub.args
|
|
grp4 = "" if al.zm6 else MDNS4
|
|
grp6 = "" if al.zm4 else MDNS6
|
|
super(MDNS, self).__init__(
|
|
hub, MDNS_Sck, al.zm_on, al.zm_off, grp4, grp6, 5353, hub.args.zmv
|
|
)
|
|
self.srv: dict[socket.socket, MDNS_Sck] = {}
|
|
self.logsrc = "mDNS-{}".format(ngen)
|
|
self.ngen = ngen
|
|
self.ttl = 300
|
|
|
|
zs = self.args.name + ".local."
|
|
zs = zs.encode("ascii", "replace").decode("ascii", "replace")
|
|
self.hn = "-".join(x for x in zs.split("?") if x) or (
|
|
"vault-{}".format(random.randint(1, 255))
|
|
)
|
|
self.lhn = self.hn.lower()
|
|
|
|
# requester ip -> (response deadline, srv, body):
|
|
self.q: dict[str, tuple[float, MDNS_Sck, bytes]] = {}
|
|
self.rx4 = CachedSet(0.42) # 3 probes @ 250..500..750 => 500ms span
|
|
self.rx6 = CachedSet(0.42)
|
|
self.svcs, self.sfqdns = self.build_svcs()
|
|
self.lsvcs = {k.lower(): v for k, v in self.svcs.items()}
|
|
self.lsfqdns = set([x.lower() for x in self.sfqdns])
|
|
|
|
self.probing = 0.0
|
|
self.unsolicited: list[float] = [] # scheduled announces on all nics
|
|
self.defend: dict[MDNS_Sck, float] = {} # server -> deadline
|
|
|
|
def log(self, msg: str, c: Union[int, str] = 0) -> None:
|
|
self.log_func(self.logsrc, msg, c)
|
|
|
|
def build_svcs(self) -> tuple[dict[str, dict[str, Any]], set[str]]:
|
|
zms = self.args.zms
|
|
http = {"port": 80 if 80 in self.args.p else self.args.p[0]}
|
|
https = {"port": 443 if 443 in self.args.p else self.args.p[0]}
|
|
webdav = http.copy()
|
|
webdavs = https.copy()
|
|
webdav["u"] = webdavs["u"] = "u" # KDE requires username
|
|
ftp = {"port": (self.args.ftp if "f" in zms else self.args.ftps)}
|
|
smb = {"port": self.args.smb_port}
|
|
|
|
# some gvfs require path
|
|
zs = self.args.zm_ld or "/"
|
|
if zs:
|
|
webdav["path"] = zs
|
|
webdavs["path"] = zs
|
|
|
|
if self.args.zm_lh:
|
|
http["path"] = self.args.zm_lh
|
|
https["path"] = self.args.zm_lh
|
|
|
|
if self.args.zm_lf:
|
|
ftp["path"] = self.args.zm_lf
|
|
|
|
if self.args.zm_ls:
|
|
smb["path"] = self.args.zm_ls
|
|
|
|
svcs: dict[str, dict[str, Any]] = {}
|
|
|
|
if "d" in zms:
|
|
svcs["_webdav._tcp.local."] = webdav
|
|
|
|
if "D" in zms:
|
|
svcs["_webdavs._tcp.local."] = webdavs
|
|
|
|
if "h" in zms:
|
|
svcs["_http._tcp.local."] = http
|
|
|
|
if "H" in zms:
|
|
svcs["_https._tcp.local."] = https
|
|
|
|
if "f" in zms.lower():
|
|
svcs["_ftp._tcp.local."] = ftp
|
|
|
|
if "s" in zms.lower():
|
|
svcs["_smb._tcp.local."] = smb
|
|
|
|
sfqdns: set[str] = set()
|
|
for k, v in svcs.items():
|
|
name = "{}-c-{}".format(self.args.name, k.split(".")[0][1:])
|
|
v["name"] = name
|
|
sfqdns.add("{}.{}".format(name, k))
|
|
|
|
return svcs, sfqdns
|
|
|
|
def build_replies(self) -> None:
|
|
for srv in self.srv.values():
|
|
probe = DNSRecord(DNSHeader(0, 0), q=DNSQuestion(self.hn, QTYPE.ANY))
|
|
areply = DNSRecord(DNSHeader(0, 0x8400))
|
|
sreply = DNSRecord(DNSHeader(0, 0x8400))
|
|
bye = DNSRecord(DNSHeader(0, 0x8400))
|
|
|
|
have4 = have6 = False
|
|
for s2 in self.srv.values():
|
|
if srv.idx != s2.idx:
|
|
continue
|
|
|
|
if s2.v6:
|
|
have6 = True
|
|
else:
|
|
have4 = True
|
|
|
|
for ip in srv.ips:
|
|
if ":" in ip:
|
|
qt = QTYPE.AAAA
|
|
ar = {"rclass": DC.F_IN, "rdata": AAAA(ip)}
|
|
else:
|
|
qt = QTYPE.A
|
|
ar = {"rclass": DC.F_IN, "rdata": A(ip)}
|
|
|
|
r0 = RR(self.hn, qt, ttl=0, **ar)
|
|
r120 = RR(self.hn, qt, ttl=120, **ar)
|
|
# rfc-10:
|
|
# SHOULD rr ttl 120sec for A/AAAA/SRV
|
|
# (and recommend 75min for all others)
|
|
|
|
probe.add_auth(r120)
|
|
areply.add_answer(r120)
|
|
sreply.add_answer(r120)
|
|
bye.add_answer(r0)
|
|
|
|
for sclass, props in self.svcs.items():
|
|
sname = props["name"]
|
|
sport = props["port"]
|
|
sfqdn = sname + "." + sclass
|
|
|
|
k = "_services._dns-sd._udp.local."
|
|
r = RR(k, QTYPE.PTR, DC.IN, 4500, PTR(sclass))
|
|
sreply.add_answer(r)
|
|
|
|
r = RR(sclass, QTYPE.PTR, DC.IN, 4500, PTR(sfqdn))
|
|
sreply.add_answer(r)
|
|
|
|
r = RR(sfqdn, QTYPE.SRV, DC.F_IN, 120, SRV(0, 0, sport, self.hn))
|
|
sreply.add_answer(r)
|
|
areply.add_answer(r)
|
|
|
|
r = RR(sfqdn, QTYPE.SRV, DC.F_IN, 0, SRV(0, 0, sport, self.hn))
|
|
bye.add_answer(r)
|
|
|
|
txts = []
|
|
for k in ("u", "path"):
|
|
if k not in props:
|
|
continue
|
|
|
|
zb = "{}={}".format(k, props[k]).encode("utf-8")
|
|
if len(zb) > 255:
|
|
t = "value too long for mdns: [{}]"
|
|
raise Exception(t.format(props[k]))
|
|
|
|
txts.append(zb)
|
|
|
|
# gvfs really wants txt even if they're empty
|
|
r = RR(sfqdn, QTYPE.TXT, DC.F_IN, 4500, TXT(txts))
|
|
sreply.add_answer(r)
|
|
|
|
if not (have4 and have6) and not self.args.zm_noneg:
|
|
ns = NSEC(self.hn, ["AAAA" if have6 else "A"])
|
|
r = RR(self.hn, QTYPE.NSEC, DC.F_IN, 120, ns)
|
|
areply.add_ar(r)
|
|
if len(sreply.pack()) < 1400:
|
|
sreply.add_ar(r)
|
|
|
|
srv.bp_probe = probe.pack()
|
|
srv.bp_ip = areply.pack()
|
|
srv.bp_svc = sreply.pack()
|
|
srv.bp_bye = bye.pack()
|
|
|
|
# since all replies are small enough to fit in one packet,
|
|
# always send full replies rather than just a/aaaa records
|
|
srv.bp_ip = srv.bp_svc
|
|
|
|
def send_probes(self) -> None:
|
|
slp = random.random() * 0.25
|
|
for _ in range(3):
|
|
time.sleep(slp)
|
|
slp = 0.25
|
|
if not self.running:
|
|
break
|
|
|
|
if self.args.zmv:
|
|
self.log("sending hostname probe...")
|
|
|
|
# ipv4: need to probe each ip (each server)
|
|
# ipv6: only need to probe each set of looped nics
|
|
probed6: set[str] = set()
|
|
for srv in self.srv.values():
|
|
if srv.ip in probed6:
|
|
continue
|
|
|
|
try:
|
|
srv.sck.sendto(srv.bp_probe, (srv.grp, 5353))
|
|
if srv.v6:
|
|
for ip in srv.ips:
|
|
probed6.add(ip)
|
|
except Exception as ex:
|
|
self.log("sendto failed: {} ({})".format(srv.ip, ex), "90")
|
|
|
|
def run(self) -> None:
|
|
try:
|
|
bound = self.create_servers()
|
|
except:
|
|
t = "no server IP matches the mdns config\n{}"
|
|
self.log(t.format(min_ex()), 1)
|
|
bound = []
|
|
|
|
if not bound:
|
|
self.log("failed to announce copyparty services on the network", 3)
|
|
return
|
|
|
|
self.build_replies()
|
|
Daemon(self.send_probes)
|
|
zf = time.time() + 2
|
|
self.probing = zf # cant unicast so give everyone an extra sec
|
|
self.unsolicited = [zf, zf + 1, zf + 3, zf + 7] # rfc-8.3
|
|
|
|
try:
|
|
self.run2()
|
|
except OSError as ex:
|
|
if ex.errno != errno.EBADF:
|
|
raise
|
|
|
|
self.log("stopping due to {}".format(ex), "90")
|
|
|
|
self.log("stopped", 2)
|
|
|
|
def run2(self) -> None:
|
|
last_hop = time.time()
|
|
ihop = self.args.mc_hop
|
|
while self.running:
|
|
timeout = (
|
|
0.02 + random.random() * 0.07
|
|
if self.probing or self.q or self.defend
|
|
else max(0.05, self.unsolicited[0] - time.time())
|
|
if self.unsolicited
|
|
else (last_hop + ihop if ihop else 180)
|
|
)
|
|
rdy = select.select(self.srv, [], [], timeout)
|
|
rx: list[socket.socket] = rdy[0] # type: ignore
|
|
self.rx4.cln()
|
|
self.rx6.cln()
|
|
buf = b""
|
|
addr = ("0", 0)
|
|
for sck in rx:
|
|
try:
|
|
buf, addr = sck.recvfrom(4096)
|
|
self.eat(buf, addr, sck)
|
|
except:
|
|
if not self.running:
|
|
self.log("stopped", 2)
|
|
return
|
|
|
|
t = "{} {} \033[33m|{}| {}\n{}".format(
|
|
self.srv[sck].name, addr, len(buf), repr(buf)[2:-1], min_ex()
|
|
)
|
|
self.log(t, 6)
|
|
|
|
if not self.probing:
|
|
self.process()
|
|
continue
|
|
|
|
if self.probing < time.time():
|
|
t = "probe ok; announcing [{}]"
|
|
self.log(t.format(self.hn[:-1]), 2)
|
|
self.probing = 0
|
|
|
|
def stop(self, panic=False) -> None:
|
|
self.running = False
|
|
for srv in self.srv.values():
|
|
try:
|
|
if panic:
|
|
srv.sck.close()
|
|
else:
|
|
srv.sck.sendto(srv.bp_bye, (srv.grp, 5353))
|
|
except:
|
|
pass
|
|
|
|
self.srv = {}
|
|
|
|
def eat(self, buf: bytes, addr: tuple[str, int], sck: socket.socket) -> None:
|
|
cip = addr[0]
|
|
v6 = ":" in cip
|
|
if (cip.startswith("169.254") and not self.ll_ok) or (
|
|
v6 and not cip.startswith("fe80")
|
|
):
|
|
return
|
|
|
|
cache = self.rx6 if v6 else self.rx4
|
|
if buf in cache.c:
|
|
return
|
|
|
|
srv: Optional[MDNS_Sck] = self.srv[sck] if v6 else self.map_client(cip) # type: ignore
|
|
if not srv:
|
|
return
|
|
|
|
cache.add(buf)
|
|
now = time.time()
|
|
|
|
if self.args.zmv and cip != srv.ip and cip not in srv.ips:
|
|
t = "{} [{}] \033[36m{} \033[0m|{}|"
|
|
self.log(t.format(srv.name, srv.ip, cip, len(buf)), "90")
|
|
|
|
p = DNSRecord.parse(buf)
|
|
if self.args.zmvv:
|
|
self.log(str(p))
|
|
|
|
# check for incoming probes for our hostname
|
|
cips = [U(x.rdata) for x in p.auth if U(x.rname).lower() == self.lhn]
|
|
if cips and self.sips.isdisjoint(cips):
|
|
if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
|
|
# avahi broadcasting 127.0.0.1-only packets
|
|
return
|
|
|
|
self.log("someone trying to steal our hostname: {}".format(cips), 3)
|
|
# immediately unicast
|
|
if not self.probing:
|
|
srv.sck.sendto(srv.bp_ip, (cip, 5353))
|
|
|
|
# and schedule multicast
|
|
self.defend[srv] = self.defend.get(srv, now + 0.1)
|
|
return
|
|
|
|
# check for someone rejecting our probe / hijacking our hostname
|
|
cips = [
|
|
U(x.rdata)
|
|
for x in p.rr
|
|
if U(x.rname).lower() == self.lhn and x.rclass == DC.F_IN
|
|
]
|
|
if cips and self.sips.isdisjoint(cips):
|
|
if not [x for x in cips if x not in ("::1", "127.0.0.1")]:
|
|
# avahi broadcasting 127.0.0.1-only packets
|
|
return
|
|
|
|
# check if we've been given additional IPs
|
|
for ip in list_ips():
|
|
if ip in cips:
|
|
self.sips.add(ip)
|
|
|
|
if not self.sips.isdisjoint(cips):
|
|
return
|
|
|
|
t = "mdns zeroconf: "
|
|
if self.probing:
|
|
t += "Cannot start; hostname '{}' is occupied"
|
|
else:
|
|
t += "Emergency stop; hostname '{}' got stolen"
|
|
|
|
t += " on {}! Use --name to set another hostname.\n\nName taken by {}\n\nYour IPs: {}\n"
|
|
self.log(t.format(self.args.name, srv.name, cips, list(self.sips)), 1)
|
|
self.stop(True)
|
|
return
|
|
|
|
# then rfc-6.7; dns pretending to be mdns (android...)
|
|
if p.header.id or addr[1] != 5353:
|
|
rsp: Optional[DNSRecord] = None
|
|
for r in p.questions:
|
|
try:
|
|
lhn = U(r.qname).lower()
|
|
except:
|
|
self.log("invalid question: {}".format(r))
|
|
continue
|
|
|
|
if lhn != self.lhn:
|
|
continue
|
|
|
|
if p.header.id and r.qtype in (QTYPE.A, QTYPE.AAAA):
|
|
rsp = rsp or DNSRecord(DNSHeader(p.header.id, 0x8400))
|
|
rsp.add_question(r)
|
|
for ip in srv.ips:
|
|
qt = r.qtype
|
|
v6 = ":" in ip
|
|
if v6 == (qt == QTYPE.AAAA):
|
|
rd = AAAA(ip) if v6 else A(ip)
|
|
rr = RR(self.hn, qt, DC.IN, 10, rd)
|
|
rsp.add_answer(rr)
|
|
if rsp:
|
|
srv.sck.sendto(rsp.pack(), addr[:2])
|
|
# but don't return in case it's a differently broken client
|
|
|
|
# then a/aaaa records
|
|
for r in p.questions:
|
|
try:
|
|
lhn = U(r.qname).lower()
|
|
except:
|
|
self.log("invalid question: {}".format(r))
|
|
continue
|
|
|
|
if lhn != self.lhn:
|
|
continue
|
|
|
|
# gvfs keeps repeating itself
|
|
found = False
|
|
unicast = False
|
|
for rr in p.rr:
|
|
try:
|
|
rname = U(rr.rname).lower()
|
|
except:
|
|
self.log("invalid rr: {}".format(rr))
|
|
continue
|
|
|
|
if rname == self.lhn:
|
|
if rr.ttl > 60:
|
|
found = True
|
|
if rr.rclass == DC.F_IN:
|
|
unicast = True
|
|
|
|
if unicast:
|
|
# spec-compliant mDNS-over-unicast
|
|
srv.sck.sendto(srv.bp_ip, (cip, 5353))
|
|
elif addr[1] != 5353:
|
|
# just in case some clients use (and want us to use) invalid ports
|
|
srv.sck.sendto(srv.bp_ip, addr[:2])
|
|
|
|
if not found:
|
|
self.q[cip] = (0, srv, srv.bp_ip)
|
|
return
|
|
|
|
deadline = now + (0.5 if p.header.tc else 0.02) # rfc-7.2
|
|
|
|
# and service queries
|
|
for r in p.questions:
|
|
if not r or not r.qname:
|
|
continue
|
|
|
|
qname = U(r.qname).lower()
|
|
if qname in self.lsvcs or qname == "_services._dns-sd._udp.local.":
|
|
self.q[cip] = (deadline, srv, srv.bp_svc)
|
|
break
|
|
# heed rfc-7.1 if there was an announce in the past 12sec
|
|
# (workaround gvfs race-condition where it occasionally
|
|
# doesn't read/decode the full response...)
|
|
if now < srv.last_tx + 12:
|
|
for rr in p.rr:
|
|
if not rr.rdata:
|
|
continue
|
|
|
|
rdata = U(rr.rdata).lower()
|
|
if rdata in self.lsfqdns:
|
|
if rr.ttl > 2250:
|
|
self.q.pop(cip, None)
|
|
break
|
|
|
|
def process(self) -> None:
|
|
tx = set()
|
|
now = time.time()
|
|
cooldown = 0.9 # rfc-6: 1
|
|
if self.unsolicited and self.unsolicited[0] < now:
|
|
self.unsolicited.pop(0)
|
|
cooldown = 0.1
|
|
for srv in self.srv.values():
|
|
tx.add(srv)
|
|
|
|
if not self.unsolicited and self.args.zm_spam:
|
|
zf = time.time() + self.args.zm_spam + random.random() * 0.07
|
|
self.unsolicited.append(zf)
|
|
|
|
for srv, deadline in list(self.defend.items()):
|
|
if now < deadline:
|
|
continue
|
|
|
|
if self._tx(srv, srv.bp_ip, 0.02): # rfc-6: 0.25
|
|
self.defend.pop(srv)
|
|
|
|
for cip, (deadline, srv, msg) in list(self.q.items()):
|
|
if now < deadline:
|
|
continue
|
|
|
|
self.q.pop(cip)
|
|
self._tx(srv, msg, cooldown)
|
|
|
|
for srv in tx:
|
|
self._tx(srv, srv.bp_svc, cooldown)
|
|
|
|
def _tx(self, srv: MDNS_Sck, msg: bytes, cooldown: float) -> bool:
|
|
now = time.time()
|
|
if now < srv.last_tx + cooldown:
|
|
return False
|
|
|
|
try:
|
|
srv.sck.sendto(msg, (srv.grp, 5353))
|
|
srv.last_tx = now
|
|
except Exception as ex:
|
|
if srv.tx_ex:
|
|
return True
|
|
|
|
srv.tx_ex = True
|
|
t = "tx({},|{}|,{}): {}"
|
|
self.log(t.format(srv.ip, len(msg), cooldown, ex), 3)
|
|
|
|
return True
|