# coding: utf-8 from __future__ import print_function, unicode_literals import socket import time import ipaddress from ipaddress import ( IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_address, ip_network, ) from .__init__ import MACOS, TYPE_CHECKING from .util import Daemon, Netdev, find_prefix, min_ex, spack if TYPE_CHECKING: from .svchub import SvcHub if True: # pylint: disable=using-constant-test from typing import Optional, Union if not hasattr(socket, "IPPROTO_IPV6"): setattr(socket, "IPPROTO_IPV6", 41) class NoIPs(Exception): pass class MC_Sck(object): """there is one socket for each server ip""" def __init__( self, sck: socket.socket, nd: Netdev, grp: str, ip: str, net: Union[IPv4Network, IPv6Network], ): self.sck = sck self.idx = nd.idx self.name = nd.name self.grp = grp self.mreq = b"" self.ip = ip self.net = net self.ips = {ip: net} self.v6 = ":" in ip self.have4 = ":" not in ip self.have6 = ":" in ip class MCast(object): def __init__( self, hub: "SvcHub", Srv: type[MC_Sck], on: list[str], off: list[str], mc_grp_4: str, mc_grp_6: str, port: int, vinit: bool, ) -> None: """disable ipv%d by setting mc_grp_%d empty""" self.hub = hub self.Srv = Srv self.args = hub.args self.asrv = hub.asrv self.log_func = hub.log self.on = on self.off = off self.grp4 = mc_grp_4 self.grp6 = mc_grp_6 self.port = port self.vinit = vinit self.srv: dict[socket.socket, MC_Sck] = {} # listening sockets self.sips: set[str] = set() # all listening ips (including failed attempts) self.ll_ok: set[str] = set() # fallback linklocal IPv4 and IPv6 addresses self.b2srv: dict[bytes, MC_Sck] = {} # binary-ip -> server socket self.b4: list[bytes] = [] # sorted list of binary-ips self.b6: list[bytes] = [] # sorted list of binary-ips self.cscache: dict[str, Optional[MC_Sck]] = {} # client ip -> server cache self.running = True def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("multicast", msg, c) def create_servers(self) -> list[str]: bound: list[str] = [] netdevs = self.hub.tcpsrv.netdevs ips = [x[0] for x in self.hub.tcpsrv.bound] if "::" in ips: ips = [x for x in ips if x != "::"] + list( [x.split("/")[0] for x in netdevs if ":" in x] ) ips.append("0.0.0.0") if "0.0.0.0" in ips: ips = [x for x in ips if x != "0.0.0.0"] + list( [x.split("/")[0] for x in netdevs if ":" not in x] ) ips = [x for x in ips if x not in ("::1", "127.0.0.1")] ips = find_prefix(ips, netdevs) on = self.on[:] off = self.off[:] for lst in (on, off): for av in list(lst): try: arg_net = ip_network(av, False) except: arg_net = None for sk, sv in netdevs.items(): if arg_net: net_ip = ip_address(sk.split("/")[0]) if net_ip in arg_net and sk not in lst: lst.append(sk) if (av == str(sv.idx) or av == sv.name) and sk not in lst: lst.append(sk) if on: ips = [x for x in ips if x in on] elif off: ips = [x for x in ips if x not in off] if not self.grp4: ips = [x for x in ips if ":" in x] if not self.grp6: ips = [x for x in ips if ":" not in x] ips = list(set(ips)) all_selected = ips[:] # discard non-linklocal ipv6 ips = [x for x in ips if ":" not in x or x.startswith("fe80")] if not ips: raise NoIPs() for ip in ips: v6 = ":" in ip netdev = netdevs[ip] if not netdev.idx: t = "using INADDR_ANY for ip [{}], netdev [{}]" if not self.srv and ip not in ["::", "0.0.0.0"]: self.log(t.format(ip, netdev), 3) ipv = socket.AF_INET6 if v6 else socket.AF_INET sck = socket.socket(ipv, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sck.settimeout(None) sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) except: pass # most ipv6 clients expect multicast on linklocal ip only; # add a/aaaa records for the other nic IPs other_ips: set[str] = set() if v6: for nd in netdevs.values(): if nd.idx == netdev.idx and nd.ip in all_selected and ":" in nd.ip: other_ips.add(nd.ip) net = ipaddress.ip_network(ip, False) ip = ip.split("/")[0] srv = self.Srv(sck, netdev, self.grp6 if ":" in ip else self.grp4, ip, net) for oth_ip in other_ips: srv.ips[oth_ip.split("/")[0]] = ipaddress.ip_network(oth_ip, False) # gvfs breaks if a linklocal ip appears in a dns reply ll = { k: v for k, v in srv.ips.items() if k.startswith("169.254") or k.startswith("fe80") } rt = {k: v for k, v in srv.ips.items() if k not in ll} if self.args.ll or not rt: self.ll_ok.update(list(ll)) if not self.args.ll: srv.ips = rt or ll if not srv.ips: self.log("no IPs on {}; skipping [{}]".format(netdev, ip), 3) continue try: self.setup_socket(srv) self.srv[sck] = srv bound.append(ip) except: t = "announce failed on {} [{}]:\n{}" self.log(t.format(netdev, ip, min_ex()), 3) if self.args.zm_msub: for s1 in self.srv.values(): for s2 in self.srv.values(): if s1.idx != s2.idx: continue if s1.ip not in s2.ips: s2.ips[s1.ip] = s1.net if self.args.zm_mnic: for s1 in self.srv.values(): for s2 in self.srv.values(): for ip1, net1 in list(s1.ips.items()): for ip2, net2 in list(s2.ips.items()): if net1 == net2 and ip1 != ip2: s1.ips[ip2] = net2 self.sips = set([x.split("/")[0] for x in all_selected]) for srv in self.srv.values(): assert srv.ip in self.sips Daemon(self.hopper, "mc-hop") return bound def setup_socket(self, srv: MC_Sck) -> None: sck = srv.sck if srv.v6: if self.vinit: zsl = list(srv.ips.keys()) self.log("v6({}) idx({}) {}".format(srv.ip, srv.idx, zsl), 6) for ip in srv.ips: bip = socket.inet_pton(socket.AF_INET6, ip) self.b2srv[bip] = srv self.b6.append(bip) grp = self.grp6 if srv.idx else "" try: if MACOS: raise Exception() sck.bind((grp, self.port, 0, srv.idx)) except: sck.bind(("", self.port, 0, srv.idx)) bgrp = socket.inet_pton(socket.AF_INET6, self.grp6) dev = spack(b"@I", srv.idx) srv.mreq = bgrp + dev if srv.idx != socket.INADDR_ANY: sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, dev) try: sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 1) except: # macos t = "failed to set IPv6 TTL/LOOP; announcements may not survive multiple switches/routers" self.log(t, 3) else: if self.vinit: self.log("v4({}) idx({})".format(srv.ip, srv.idx), 6) bip = socket.inet_aton(srv.ip) self.b2srv[bip] = srv self.b4.append(bip) grp = self.grp4 if srv.idx else "" try: if MACOS: raise Exception() sck.bind((grp, self.port)) except: sck.bind(("", self.port)) bgrp = socket.inet_aton(self.grp4) dev = ( spack(b"=I", socket.INADDR_ANY) if srv.idx == socket.INADDR_ANY else socket.inet_aton(srv.ip) ) srv.mreq = bgrp + dev if srv.idx != socket.INADDR_ANY: sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, dev) try: sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255) sck.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1) except: # probably can't happen but dontcare if it does t = "failed to set IPv4 TTL/LOOP; announcements may not survive multiple switches/routers" self.log(t, 3) if self.hop(srv, False): self.log("igmp was already joined?? chilling for a sec", 3) time.sleep(1.2) self.hop(srv, True) self.b4.sort(reverse=True) self.b6.sort(reverse=True) def hop(self, srv: MC_Sck, on: bool) -> bool: """rejoin to keepalive on routers/switches without igmp-snooping""" sck = srv.sck req = srv.mreq if ":" in srv.ip: if not on: try: sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_LEAVE_GROUP, req) return True except: return False else: sck.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, req) else: if not on: try: sck.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, req) return True except: return False else: # t = "joining {} from ip {} idx {} with mreq {}" # self.log(t.format(srv.grp, srv.ip, srv.idx, repr(srv.mreq)), 6) sck.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, req) return True def hopper(self): while self.args.mc_hop and self.running: time.sleep(self.args.mc_hop) if not self.running: return for srv in self.srv.values(): self.hop(srv, False) # linux does leaves/joins twice with 0.2~1.05s spacing time.sleep(1.2) if not self.running: return for srv in self.srv.values(): self.hop(srv, True) def map_client(self, cip: str) -> Optional[MC_Sck]: try: return self.cscache[cip] except: pass ret: Optional[MC_Sck] = None v6 = ":" in cip ci = IPv6Address(cip) if v6 else IPv4Address(cip) for x in self.b6 if v6 else self.b4: srv = self.b2srv[x] if any([x for x in srv.ips.values() if ci in x]): ret = srv break if not ret and cip in ("127.0.0.1", "::1"): # just give it something ret = list(self.srv.values())[0] if not ret and cip.startswith("169.254"): # idk how to map LL IPv4 msgs to nics; # just pick one and hope for the best lls = ( x for x in self.srv.values() if next((y for y in x.ips if y in self.ll_ok), None) ) ret = next(lls, None) if ret: t = "new client on {} ({}): {}" self.log(t.format(ret.name, ret.net, cip), 6) else: t = "could not map client {} to known subnet; maybe forwarded from another network?" self.log(t.format(cip), 3) if len(self.cscache) > 9000: self.cscache = {} self.cscache[cip] = ret return ret