# coding: utf-8 from __future__ import print_function, unicode_literals import base64 import math import os import socket import sys import threading import time import queue try: import jinja2 except ImportError: print( """\033[1;31m you do not have jinja2 installed,\033[33m choose one of these:\033[0m * apt install python-jinja2 * {} -m pip install --user jinja2 * (try another python version, if you have one) * (try copyparty.sfx instead) """.format( os.path.basename(sys.executable) ) ) sys.exit(1) from .__init__ import ANYWIN, MACOS, TYPE_CHECKING, EnvParams from .bos import bos from .httpconn import HttpConn from .util import ( E_SCK, FHC, Daemon, Garda, Magician, Netdev, NetMap, ipnorm, min_ex, shut_socket, spack, start_log_thrs, start_stackmon, ) if TYPE_CHECKING: from .broker_util import BrokerCli from .ssdp import SSDPr if True: # pylint: disable=using-constant-test from typing import Any, Optional class HttpSrv(object): """ handles incoming connections using HttpConn to process http, relying on MpSrv for performance (HttpSrv is just plain threads) """ def __init__(self, broker: "BrokerCli", nid: Optional[int]) -> None: self.broker = broker self.nid = nid self.args = broker.args self.E: EnvParams = self.args.E self.log = broker.log self.asrv = broker.asrv # redefine in case of multiprocessing socket.setdefaulttimeout(120) nsuf = "-n{}-i{:x}".format(nid, os.getpid()) if nid else "" self.magician = Magician() self.nm = NetMap([], {}) self.ssdp: Optional["SSDPr"] = None self.gpwd = Garda(self.args.ban_pw) self.g404 = Garda(self.args.ban_404) self.bans: dict[str, int] = {} self.aclose: dict[str, int] = {} self.ip = "" self.port = 0 self.name = "hsrv" + nsuf self.mutex = threading.Lock() self.stopping = False self.tp_nthr = 0 # actual self.tp_ncli = 0 # fading self.tp_time = 0.0 # latest worker collect self.tp_q: Optional[queue.LifoQueue[Any]] = ( None if self.args.no_htp else queue.LifoQueue() ) self.t_periodic: Optional[threading.Thread] = None self.u2fh = FHC() self.srvs: list[socket.socket] = [] self.ncli = 0 # exact self.clients: set[HttpConn] = set() # laggy self.nclimax = 0 self.cb_ts = 0.0 self.cb_v = "" env = jinja2.Environment() env.loader = jinja2.FileSystemLoader(os.path.join(self.E.mod, "web")) jn = ["splash", "svcs", "browser", "browser2", "msg", "md", "mde", "cf"] self.j2 = {x: env.get_template(x + ".html") for x in jn} zs = os.path.join(self.E.mod, "web", "deps", "prism.js.gz") self.prism = os.path.exists(zs) if self.args.zs: from .ssdp import SSDPr self.ssdp = SSDPr(broker) cert_path = os.path.join(self.E.cfg, "cert.pem") if bos.path.exists(cert_path): self.cert_path = cert_path else: self.cert_path = "" if self.tp_q: self.start_threads(4) if nid: if self.args.stackmon: start_stackmon(self.args.stackmon, nid) if self.args.log_thrs: start_log_thrs(self.log, self.args.log_thrs, nid) self.th_cfg: dict[str, Any] = {} Daemon(self.post_init, "hsrv-init2") def post_init(self) -> None: try: x = self.broker.ask("thumbsrv.getcfg") self.th_cfg = x.get() except: pass def set_netdevs(self, netdevs: dict[str, Netdev]) -> None: self.nm = NetMap([self.ip], netdevs) def start_threads(self, n: int) -> None: self.tp_nthr += n if self.args.log_htp: self.log(self.name, "workers += {} = {}".format(n, self.tp_nthr), 6) for _ in range(n): Daemon(self.thr_poolw, self.name + "-poolw") def stop_threads(self, n: int) -> None: self.tp_nthr -= n if self.args.log_htp: self.log(self.name, "workers -= {} = {}".format(n, self.tp_nthr), 6) assert self.tp_q for _ in range(n): self.tp_q.put(None) def periodic(self) -> None: while True: time.sleep(2 if self.tp_ncli or self.ncli else 10) with self.mutex: self.u2fh.clean() if self.tp_q: self.tp_ncli = max(self.ncli, self.tp_ncli - 2) if self.tp_nthr > self.tp_ncli + 8: self.stop_threads(4) if not self.ncli and not self.u2fh.cache and self.tp_nthr <= 8: self.t_periodic = None return def listen(self, sck: socket.socket, nlisteners: int) -> None: if self.args.j != 1: # lost in the pickle; redefine if not ANYWIN or self.args.reuseaddr: sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sck.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sck.settimeout(None) # < does not inherit, ^ opts above do self.ip, self.port = sck.getsockname()[:2] self.srvs.append(sck) self.nclimax = math.ceil(self.args.nc * 1.0 / nlisteners) Daemon( self.thr_listen, "httpsrv-n{}-listen-{}-{}".format(self.nid or "0", self.ip, self.port), (sck,), ) def thr_listen(self, srv_sck: socket.socket) -> None: """listens on a shared tcp server""" ip, port = srv_sck.getsockname()[:2] fno = srv_sck.fileno() hip = "[{}]".format(ip) if ":" in ip else ip msg = "subscribed @ {}:{} f{} p{}".format(hip, port, fno, os.getpid()) self.log(self.name, msg) def fun() -> None: self.broker.say("cb_httpsrv_up") threading.Thread(target=fun, name="sig-hsrv-up1").start() while not self.stopping: if self.args.log_conn: self.log(self.name, "|%sC-ncli" % ("-" * 1,), c="90") spins = 0 while self.ncli >= self.nclimax: if not spins: self.log(self.name, "at connection limit; waiting", 3) spins += 1 time.sleep(0.1) if spins != 50 or not self.args.aclose: continue ipfreq: dict[str, int] = {} with self.mutex: for c in self.clients: ip = ipnorm(c.ip) try: ipfreq[ip] += 1 except: ipfreq[ip] = 1 ip, n = sorted(ipfreq.items(), key=lambda x: x[1], reverse=True)[0] if n < self.nclimax / 2: continue self.aclose[ip] = int(time.time() + self.args.aclose * 60) nclose = 0 nloris = 0 nconn = 0 with self.mutex: for c in self.clients: cip = ipnorm(c.ip) if ip != cip: continue nconn += 1 try: if ( c.nreq >= 1 or not c.cli or c.cli.in_hdr_recv or c.cli.keepalive ): Daemon(c.shutdown) nclose += 1 if c.nreq <= 0 and (not c.cli or c.cli.in_hdr_recv): nloris += 1 except: pass t = "{} downgraded to connection:close for {} min; dropped {}/{} connections" self.log(self.name, t.format(ip, self.args.aclose, nclose, nconn), 1) if nloris < nconn / 2: continue t = "slowloris (idle-conn): {} banned for {} min" self.log(self.name, t.format(ip, self.args.loris, nclose), 1) self.bans[ip] = int(time.time() + self.args.loris * 60) if self.args.log_conn: self.log(self.name, "|%sC-acc1" % ("-" * 2,), c="90") try: sck, saddr = srv_sck.accept() cip, cport = saddr[:2] if cip.startswith("::ffff:"): cip = cip[7:] addr = (cip, cport) except (OSError, socket.error) as ex: if self.stopping: break self.log(self.name, "accept({}): {}".format(fno, ex), c=6) time.sleep(0.02) continue if self.args.log_conn: t = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format( "-" * 3, ip, port % 8, port ) self.log("%s %s" % addr, t, c="90") self.accept(sck, addr) def accept(self, sck: socket.socket, addr: tuple[str, int]) -> None: """takes an incoming tcp connection and creates a thread to handle it""" now = time.time() if now - (self.tp_time or now) > 300: t = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}" self.log(self.name, t.format(self.tp_time, now, self.tp_nthr, self.ncli), 1) self.tp_time = 0 self.tp_q = None with self.mutex: self.ncli += 1 if not self.t_periodic: name = "hsrv-pt" if self.nid: name += "-{}".format(self.nid) self.t_periodic = Daemon(self.periodic, name) if self.tp_q: self.tp_time = self.tp_time or now self.tp_ncli = max(self.tp_ncli, self.ncli) if self.tp_nthr < self.ncli + 4: self.start_threads(8) self.tp_q.put((sck, addr)) return if not self.args.no_htp: t = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n" self.log(self.name, t, 1) Daemon( self.thr_client, "httpconn-{}-{}".format(addr[0].split(".", 2)[-1][-6:], addr[1]), (sck, addr), ) def thr_poolw(self) -> None: assert self.tp_q while True: task = self.tp_q.get() if not task: break with self.mutex: self.tp_time = 0 try: sck, addr = task me = threading.current_thread() me.name = "httpconn-{}-{}".format( addr[0].split(".", 2)[-1][-6:], addr[1] ) self.thr_client(sck, addr) me.name = self.name + "-poolw" except Exception as ex: if str(ex).startswith("client d/c "): self.log(self.name, "thr_client: " + str(ex), 6) else: self.log(self.name, "thr_client: " + min_ex(), 3) def shutdown(self) -> None: self.stopping = True for srv in self.srvs: try: srv.close() except: pass thrs = [] clients = list(self.clients) for cli in clients: t = threading.Thread(target=cli.shutdown) thrs.append(t) t.start() if self.tp_q: self.stop_threads(self.tp_nthr) for _ in range(10): time.sleep(0.05) if self.tp_q.empty(): break for t in thrs: t.join() self.log(self.name, "ok bye") def thr_client(self, sck: socket.socket, addr: tuple[str, int]) -> None: """thread managing one tcp client""" cli = HttpConn(sck, addr, self) with self.mutex: self.clients.add(cli) # print("{}\n".format(len(self.clients)), end="") fno = sck.fileno() try: if self.args.log_conn: self.log("%s %s" % addr, "|%sC-crun" % ("-" * 4,), c="90") cli.run() except (OSError, socket.error) as ex: if ex.errno not in E_SCK: self.log( "%s %s" % addr, "run({}): {}".format(fno, ex), c=6, ) finally: sck = cli.s if self.args.log_conn: self.log("%s %s" % addr, "|%sC-cdone" % ("-" * 5,), c="90") try: fno = sck.fileno() shut_socket(cli.log, sck) except (OSError, socket.error) as ex: if not MACOS: self.log( "%s %s" % addr, "shut({}): {}".format(fno, ex), c="90", ) if ex.errno not in E_SCK: raise finally: with self.mutex: self.clients.remove(cli) self.ncli -= 1 def cachebuster(self) -> str: if time.time() - self.cb_ts < 1: return self.cb_v with self.mutex: if time.time() - self.cb_ts < 1: return self.cb_v v = self.E.t0 try: with os.scandir(os.path.join(self.E.mod, "web")) as dh: for fh in dh: inf = fh.stat() v = max(v, inf.st_mtime) except: pass v = base64.urlsafe_b64encode(spack(b">xxL", int(v))) self.cb_v = v.decode("ascii")[-4:] self.cb_ts = time.time() return self.cb_v