mirror of
https://github.com/9001/copyparty.git
synced 2025-08-17 17:12:13 -06:00
398 lines
12 KiB
Python
398 lines
12 KiB
Python
# coding: utf-8
|
|
from __future__ import print_function, unicode_literals
|
|
|
|
import socket
|
|
import time
|
|
|
|
import ipaddress
|
|
from ipaddress import (
|
|
IPv4Address,
|
|
IPv4Network,
|
|
IPv6Address,
|
|
IPv6Network,
|
|
ip_address,
|
|
ip_network,
|
|
)
|
|
|
|
from .__init__ import MACOS, TYPE_CHECKING
|
|
from .util import Daemon, Netdev, find_prefix, min_ex, spack
|
|
|
|
if TYPE_CHECKING:
|
|
from .svchub import SvcHub
|
|
|
|
if True: # pylint: disable=using-constant-test
|
|
from typing import Optional, Union
|
|
|
|
# some python builds do not expose this constant on the socket module;
# define it so the setsockopt calls further down can rely on it
# (41 is presumably the standard protocol number for IPv6 -- verify if changing)
try:
    socket.IPPROTO_IPV6
except AttributeError:
    socket.IPPROTO_IPV6 = 41
|
|
|
|
|
|
class NoIPs(Exception):
    """raised by MCast.create_servers when no usable IPs remain after filtering"""
|
|
|
|
|
|
class MC_Sck(object):
    """there is one socket for each server ip"""

    def __init__(
        self,
        sck: socket.socket,
        nd: Netdev,
        grp: str,
        ip: str,
        net: Union[IPv4Network, IPv6Network],
    ):
        """
        sck: the UDP socket owned by this server entry
        nd:  network device; only its idx and name are copied
        grp: multicast group address for this socket
        ip:  the server ip this socket represents (no prefix length)
        net: the network which ip belongs to
        """
        # a colon can only appear in an IPv6 address
        is_v6 = ":" in ip

        self.sck = sck
        self.grp = grp
        self.ip = ip
        self.net = net

        # interface index and name, taken from the netdev
        self.idx = nd.idx
        self.name = nd.name

        # membership request blob; empty until the socket has been set up
        self.mreq = b""

        # every ip (and its network) associated with this socket;
        # starts out with just the primary one
        self.ips = {ip: net}

        self.v6 = is_v6
        self.have4 = not is_v6
        self.have6 = is_v6
|
|
|
|
|
|
class MCast(object):
    """manages the multicast sockets; one UDP socket per selected server ip

    create_servers builds the sockets, setup_socket binds them and joins
    the multicast group, hopper periodically leaves/rejoins the group,
    and map_client picks the server socket matching a client ip
    """

    def __init__(
        self,
        hub: "SvcHub",
        Srv: type[MC_Sck],
        on: list[str],
        off: list[str],
        mc_grp_4: str,
        mc_grp_6: str,
        port: int,
        vinit: bool,
    ) -> None:
        """disable ipv4 / ipv6 by setting mc_grp_4 / mc_grp_6 empty

        hub:      service hub; provides args, asrv, log and tcpsrv
        Srv:      class instantiated once per socket,
                  called as Srv(sck, netdev, grp, ip, net)
        on:       if non-empty, only these are used; each entry is a
                  network/ip, a netdev index, or a netdev name
        off:      same format as `on`; excluded (ignored when `on` is set)
        mc_grp_4: IPv4 multicast group address
        mc_grp_6: IPv6 multicast group address
        port:     UDP port to bind
        vinit:    log details about each socket during setup
        """
        self.hub = hub
        self.Srv = Srv
        self.args = hub.args
        self.asrv = hub.asrv
        self.log_func = hub.log
        self.on = on
        self.off = off
        self.grp4 = mc_grp_4
        self.grp6 = mc_grp_6
        self.port = port
        self.vinit = vinit

        self.srv: dict[socket.socket, MC_Sck] = {}  # listening sockets
        self.sips: set[str] = set()  # all listening ips (including failed attempts)
        self.ll_ok: set[str] = set()  # fallback linklocal IPv4 and IPv6 addresses
        self.b2srv: dict[bytes, MC_Sck] = {}  # binary-ip -> server socket
        self.b4: list[bytes] = []  # sorted list of binary-ips
        self.b6: list[bytes] = []  # sorted list of binary-ips
        self.cscache: dict[str, Optional[MC_Sck]] = {}  # client ip -> server cache

        self.running = True

    def log(self, msg: str, c: Union[int, str] = 0) -> None:
        """send msg to the hub logger, tagged as "multicast"; c is passed through as-is"""
        self.log_func("multicast", msg, c)

    def create_servers(self) -> list[str]:
        """create, bind and join one multicast socket per selected server ip

        candidate ips come from the tcp server's bound addresses, with the
        wildcards expanded to the netdev addresses; they are then filtered
        by self.on / self.off and by which multicast groups are enabled

        starts the "mc-hop" daemon before returning

        returns the ips (without prefix length) which were set up successfully;
        raises NoIPs if no candidate ips remain after filtering
        """
        bound: list[str] = []
        netdevs = self.hub.tcpsrv.netdevs
        ips = [x[0] for x in self.hub.tcpsrv.bound]

        # expand the wildcard addresses into the actual netdev ips;
        # the v6 wildcard also implies the v4 wildcard
        if "::" in ips:
            ips = [x for x in ips if x != "::"]
            ips += [x.split("/")[0] for x in netdevs if ":" in x]
            ips.append("0.0.0.0")

        if "0.0.0.0" in ips:
            ips = [x for x in ips if x != "0.0.0.0"]
            ips += [x.split("/")[0] for x in netdevs if ":" not in x]

        ips = [x for x in ips if x not in ("::1", "127.0.0.1")]
        # presumably maps each bare ip to its "ip/prefix" netdev key,
        # since the result is used to index netdevs below -- TODO confirm
        ips = find_prefix(ips, list(netdevs))

        # resolve each on/off selector (network, netdev index, or netdev name)
        # into the matching netdev keys, appending them to the same list
        on = self.on[:]
        off = self.off[:]
        for lst in (on, off):
            for av in list(lst):
                try:
                    arg_net = ip_network(av, False)
                except (TypeError, ValueError):
                    # not an ip/network; may still be a netdev index or name
                    arg_net = None

                for sk, sv in netdevs.items():
                    if arg_net:
                        net_ip = ip_address(sk.split("/")[0])
                        if net_ip in arg_net and sk not in lst:
                            lst.append(sk)

                    if (av == str(sv.idx) or av == sv.name) and sk not in lst:
                        lst.append(sk)

        if on:
            ips = [x for x in ips if x in on]
        elif off:
            ips = [x for x in ips if x not in off]

        if not self.grp4:
            ips = [x for x in ips if ":" in x]

        if not self.grp6:
            ips = [x for x in ips if ":" not in x]

        ips = list(set(ips))
        all_selected = ips[:]

        # discard non-linklocal ipv6
        ips = [x for x in ips if ":" not in x or x.startswith("fe80")]

        if not ips:
            raise NoIPs()

        for ip in ips:
            v6 = ":" in ip
            netdev = netdevs[ip]
            if not netdev.idx:
                t = "using INADDR_ANY for ip [{}], netdev [{}]"
                if not self.srv and ip not in ["::", "0.0.0.0"]:
                    self.log(t.format(ip, netdev), 3)

            ipv = socket.AF_INET6 if v6 else socket.AF_INET
            sck = socket.socket(ipv, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
            sck.settimeout(None)
            sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                # safe for this purpose; https://lwn.net/Articles/853637/
                sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            except Exception:
                # best-effort; the option is missing or refused on some platforms
                pass

            # most ipv6 clients expect multicast on linklocal ip only;
            # add a/aaaa records for the other nic IPs
            other_ips: set[str] = set()
            if v6:
                for nd in netdevs.values():
                    if nd.idx == netdev.idx and nd.ip in all_selected and ":" in nd.ip:
                        other_ips.add(nd.ip)

            net = ipaddress.ip_network(ip, False)
            ip = ip.split("/")[0]
            srv = self.Srv(sck, netdev, self.grp6 if ":" in ip else self.grp4, ip, net)
            for oth_ip in other_ips:
                srv.ips[oth_ip.split("/")[0]] = ipaddress.ip_network(oth_ip, False)

            # gvfs breaks if a linklocal ip appears in a dns reply
            ll = {
                k: v
                for k, v in srv.ips.items()
                if k.startswith("169.254") or k.startswith("fe80")
            }
            rt = {k: v for k, v in srv.ips.items() if k not in ll}

            if self.args.ll or not rt:
                self.ll_ok.update(ll)

            if not self.args.ll:
                # prefer routable ips; only fall back to linklocal if there are none
                srv.ips = rt or ll

            if not srv.ips:
                self.log("no IPs on {}; skipping [{}]".format(netdev, ip), 3)
                # the socket is never registered, so release it here
                sck.close()
                continue

            try:
                self.setup_socket(srv)
                self.srv[sck] = srv
                bound.append(ip)
            except Exception:
                # NOTE(review): setup_socket may already have added this srv to
                # b2srv/b4/b6 before failing; those entries are not rolled back
                t = "announce failed on {} [{}]:\n{}"
                self.log(t.format(netdev, ip, min_ex()), 3)
                sck.close()

        if self.args.zm_msub:
            # sockets on the same interface learn each other's primary ip
            for s1 in self.srv.values():
                for s2 in self.srv.values():
                    if s1.idx != s2.idx:
                        continue

                    if s1.ip not in s2.ips:
                        s2.ips[s1.ip] = s1.net

        if self.args.zm_mnic:
            # sockets which share a subnet learn each other's ips in that subnet
            for s1 in self.srv.values():
                for s2 in self.srv.values():
                    for ip1, net1 in list(s1.ips.items()):
                        for ip2, net2 in list(s2.ips.items()):
                            if net1 == net2 and ip1 != ip2:
                                s1.ips[ip2] = net2

        self.sips = {x.split("/")[0] for x in all_selected}
        for srv in self.srv.values():
            assert srv.ip in self.sips

        Daemon(self.hopper, "mc-hop")
        return bound

    def setup_socket(self, srv: MC_Sck) -> None:
        """bind the socket of srv, configure it for multicast, and join the group

        registers the binary form of the server ips in b2srv and b4/b6,
        stores the membership request in srv.mreq, and keeps b4/b6 sorted
        in descending order; exceptions from bind / join propagate to the caller
        """
        sck = srv.sck
        if srv.v6:
            if self.vinit:
                zsl = list(srv.ips.keys())
                self.log("v6({}) idx({}) {}".format(srv.ip, srv.idx, zsl), 6)

            for ip in srv.ips:
                bip = socket.inet_pton(socket.AF_INET6, ip)
                self.b2srv[bip] = srv
                self.b6.append(bip)

            # bind to the group address if possible, otherwise to the wildcard;
            # macos always goes straight for the wildcard
            grp = self.grp6 if srv.idx else ""
            if MACOS:
                sck.bind(("", self.port, 0, srv.idx))
            else:
                try:
                    sck.bind((grp, self.port, 0, srv.idx))
                except Exception:
                    sck.bind(("", self.port, 0, srv.idx))

            # membership request: group address followed by the interface index
            bgrp = socket.inet_pton(socket.AF_INET6, self.grp6)
            dev = spack(b"@I", srv.idx)
            srv.mreq = bgrp + dev
            if srv.idx != socket.INADDR_ANY:
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev)

            try:
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1)
            except Exception:
                # macos
                t = "failed to set IPv6 TTL/LOOP; announcements may not survive multiple switches/routers"
                self.log(t, 3)
        else:
            if self.vinit:
                self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6)

            bip = socket.inet_aton(srv.ip)
            self.b2srv[bip] = srv
            self.b4.append(bip)

            # same bind strategy as for ipv6
            grp = self.grp4 if srv.idx else ""
            if MACOS:
                sck.bind(("", self.port))
            else:
                try:
                    sck.bind((grp, self.port))
                except Exception:
                    sck.bind(("", self.port))

            # membership request: group address followed by the local
            # interface address (or INADDR_ANY if the netdev has no index)
            bgrp = socket.inet_aton(self.grp4)
            if srv.idx == socket.INADDR_ANY:
                dev = spack(b"=I", socket.INADDR_ANY)
            else:
                dev = socket.inet_aton(srv.ip)

            srv.mreq = bgrp + dev
            if srv.idx != socket.INADDR_ANY:
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev)

            try:
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255)
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
            except Exception:
                # probably can't happen but dontcare if it does
                t = "failed to set IPv4 TTL/LOOP; announcements may not survive multiple switches/routers"
                self.log(t, 3)

        # leave first; if that succeeds then we were unexpectedly a member already
        if self.hop(srv, False):
            self.log("igmp was already joined?? chilling for a sec", 3)
            time.sleep(1.2)

        self.hop(srv, True)
        self.b4.sort(reverse=True)
        self.b6.sort(reverse=True)

    def hop(self, srv: MC_Sck, on: bool) -> bool:
        """rejoin to keepalive on routers/switches without igmp-snooping

        on=True joins the multicast group using srv.mreq; errors propagate
        and the return value is always True

        on=False leaves the group; returns True if that succeeded and
        False if the leave request failed (for example when not a member)
        """
        sck = srv.sck
        req = srv.mreq
        v6 = ":" in srv.ip

        if on:
            if v6:
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, req)
            else:
                # t = "joining {} from ip {} idx {} with mreq {}"
                # self.log(t.format(srv.grp, srv.ip, srv.idx, repr(srv.mreq)), 6)
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, req)

            return True

        try:
            if v6:
                sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_LEAVE_GROUP, req)
            else:
                sck.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, req)
        except Exception:
            return False

        return True

    def hopper(self) -> None:
        """thread body; leaves and rejoins the group on every socket each args.mc_hop seconds

        returns once self.running is false, or immediately if args.mc_hop is falsy
        """
        while self.args.mc_hop and self.running:
            time.sleep(self.args.mc_hop)
            if not self.running:
                return

            for srv in self.srv.values():
                self.hop(srv, False)

            # linux does leaves/joins twice with 0.2~1.05s spacing
            time.sleep(1.2)
            if not self.running:
                return

            for srv in self.srv.values():
                try:
                    self.hop(srv, True)
                except Exception as ex:
                    # one socket failing to rejoin must not end
                    # the keepalive loop for all the other sockets
                    t = "failed to rejoin {} on {} [{}]: {!r}"
                    self.log(t.format(srv.grp, srv.name, srv.ip, ex), 3)

    def map_client(self, cip: str) -> Optional[MC_Sck]:
        """find the server socket which shares a subnet with client ip `cip`

        the result (including None for "no match") is remembered in cscache;
        loopback clients get the first server socket if there is one, and
        linklocal IPv4 clients get any server which has an ip in ll_ok

        raises ValueError if cip is not a valid ip address
        """
        try:
            return self.cscache[cip]
        except KeyError:
            pass

        ret: Optional[MC_Sck] = None
        v6 = ":" in cip
        ci = IPv6Address(cip) if v6 else IPv4Address(cip)
        for bip in self.b6 if v6 else self.b4:
            srv = self.b2srv[bip]
            if any(ci in net for net in srv.ips.values()):
                ret = srv
                break

        if ret is None and cip in ("127.0.0.1", "::1"):
            # just give it something (or nothing, if no sockets are up)
            ret = next(iter(self.srv.values()), None)

        if ret is None and cip.startswith("169.254"):
            # idk how to map LL IPv4 msgs to nics;
            # just pick one and hope for the best
            lls = (
                x
                for x in self.srv.values()
                if any(y in self.ll_ok for y in x.ips)
            )
            ret = next(lls, None)

        if ret is not None:
            t = "new client on {} ({}): {}"
            self.log(t.format(ret.name, ret.net, cip), 6)
        else:
            t = "could not map client {} to known subnet; maybe forwarded from another network?"
            self.log(t.format(cip), 3)

        # crude bound on memory usage; start over when the cache gets big
        if len(self.cscache) > 9000:
            self.cscache = {}

        self.cscache[cip] = ret
        return ret
|