copyparty/copyparty/httpsrv.py
2025-07-07 12:52:31 +00:00

634 lines
19 KiB
Python

# coding: utf-8
from __future__ import print_function, unicode_literals
import hashlib
import math
import os
import re
import socket
import sys
import threading
import time
import queue
from .__init__ import ANYWIN, CORES, EXE, MACOS, PY2, TYPE_CHECKING, EnvParams, unicode
try:
    # ModuleNotFoundError only exists on python 3.6+; on older interpreters
    # (including python 2) the name lookup fails and a missing module is
    # reported as a plain ImportError instead
    MNFE = ModuleNotFoundError
except NameError:
    MNFE = ImportError
try:
    import jinja2
except (MNFE, SyntaxError) as _j2_err:
    # in a packaged executable (EXE) the advice below does not apply;
    # propagate the original exception instead
    if EXE:
        raise

    if isinstance(_j2_err, SyntaxError):
        # jinja2 is present but cannot be parsed by this interpreter
        _j2_hint = """\033[1;31m
your jinja2 version is incompatible with your python version;\033[33m
please try to replace it with an older version:\033[0m
* {} -m pip install --user jinja2==2.11.3
* (try another python version, if you have one)
* (try copyparty.sfx instead)
"""
    else:
        # jinja2 is not installed at all
        _j2_hint = """\033[1;31m
you do not have jinja2 installed,\033[33m
choose one of these:\033[0m
* apt install python-jinja2
* {} -m pip install --user jinja2
* (try another python version, if you have one)
* (try copyparty.sfx instead)
"""

    print(_j2_hint.format(sys.executable))
    sys.exit(1)
from .httpconn import HttpConn
from .metrics import Metrics
from .u2idx import U2idx
from .util import (
E_SCK,
FHC,
CachedDict,
Daemon,
Garda,
Magician,
Netdev,
NetMap,
build_netmap,
has_resource,
ipnorm,
load_ipu,
load_resource,
min_ex,
shut_socket,
spack,
start_log_thrs,
start_stackmon,
ub64enc,
)
# names used only in type annotations; presumably TYPE_CHECKING mirrors
# typing.TYPE_CHECKING (False at runtime) -- confirm in .__init__
if TYPE_CHECKING:
    from .authsrv import VFS
    from .broker_util import BrokerCli
    from .ssdp import SSDPr

# NOTE(review): the constant-true guard looks deliberate (possibly so build
# tooling can strip typing imports) -- confirm before simplifying it
if True:  # pylint: disable=using-constant-test
    from typing import Any, Optional
if PY2:
    # on python 2, use the lazy xrange wherever this module says range()
    range = xrange  # type: ignore

if not hasattr(socket, "AF_UNIX"):
    # this socket module has no unix-domain support; install a placeholder
    # constant so `family == socket.AF_UNIX` checks do not raise
    socket.AF_UNIX = -9001  # type: ignore
def load_jinja2_resource(E: EnvParams, name: str):
    """Return the text of resource `web/<name>`, opened via load_resource.

    Used as the jinja2 FunctionLoader callback when HttpSrv builds its
    template environment.
    """
    rel_path = "web/" + name
    with load_resource(E, rel_path, "r") as fh:
        text = fh.read()
    return text
class HttpSrv(object):
    """
    handles incoming connections using HttpConn to process http,
    relying on MpSrv for performance (HttpSrv is just plain threads)

    each listening socket gets a thr_listen thread; accepted connections
    are handed to a pool of thr_poolw workers through a LIFO queue
    (tp_q), or to a dedicated thread each when the pool is disabled
    (args.no_htp) or considered dead
    """

    def __init__(self, broker: "BrokerCli", nid: Optional[int]) -> None:
        """
        broker: provides args, log, asrv and the ask/say messaging used here
        nid: worker-process number; falsy when running single-process
        """
        self.broker = broker
        self.nid = nid
        self.args = broker.args
        self.E: EnvParams = self.args.E
        self.log = broker.log
        self.asrv = broker.asrv

        # redefine in case of multiprocessing
        socket.setdefaulttimeout(120)

        self.t0 = time.time()
        nsuf = "-n{}-i{:x}".format(nid, os.getpid()) if nid else ""
        self.magician = Magician()
        self.nm = NetMap([], [])
        self.ssdp: Optional["SSDPr"] = None

        # one Garda per ban_* option
        self.gpwd = Garda(self.args.ban_pw)
        self.g404 = Garda(self.args.ban_404)
        self.g403 = Garda(self.args.ban_403)
        self.g422 = Garda(self.args.ban_422, False)
        self.gmal = Garda(self.args.ban_422)
        self.gurl = Garda(self.args.ban_url)
        self.bans: dict[str, int] = {}  # ip -> unix time when ban expires
        self.aclose: dict[str, int] = {}  # ip -> unix time when aclose expires

        dli: dict[str, tuple[float, int, "VFS", str, str]] = {}  # info
        dls: dict[str, tuple[float, int]] = {}  # state
        self.dli = self.tdli = dli
        self.dls = self.tdls = dls
        self.iiam = '<img src="%s.cpr/iiam.gif?cache=i" />' % (self.args.SRS,)

        self.bound: set[tuple[str, int]] = set()
        self.name = "hsrv" + nsuf
        self.mutex = threading.Lock()
        self.u2mutex = threading.Lock()
        self.stopping = False

        self.tp_nthr = 0  # actual
        self.tp_ncli = 0  # fading
        self.tp_time = 0.0  # latest worker collect
        self.tp_q: Optional[queue.LifoQueue[Any]] = (
            None if self.args.no_htp else queue.LifoQueue()
        )
        self.t_periodic: Optional[threading.Thread] = None

        self.u2fh = FHC()
        self.u2sc: dict[str, tuple[int, "hashlib._Hash"]] = {}
        self.pipes = CachedDict(0.2)
        self.metrics = Metrics(self)
        self.nreq = 0
        self.nsus = 0
        self.nban = 0
        self.srvs: list[socket.socket] = []
        self.ncli = 0  # exact
        self.clients: set[HttpConn] = set()  # laggy
        self.nclimax = 0
        self.cb_ts = 0.0
        self.cb_v = ""

        self.u2idx_free: dict[str, U2idx] = {}
        self.u2idx_n = 0

        assert jinja2  # type: ignore # !rm
        env = jinja2.Environment()
        env.loader = jinja2.FunctionLoader(lambda f: load_jinja2_resource(self.E, f))
        jn = [
            "browser",
            "browser2",
            "cf",
            "idp",
            "md",
            "mde",
            "msg",
            "rups",
            "shares",
            "splash",
            "svcs",
        ]
        self.j2 = {x: env.get_template(x + ".html") for x in jn}
        self.prism = has_resource(self.E, "web/deps/prism.js.gz")

        if self.args.ipu:
            self.ipu_iu, self.ipu_nm = load_ipu(self.log, self.args.ipu)
        else:
            self.ipu_iu = self.ipu_nm = None

        self.ipa_nm = build_netmap(self.args.ipa)
        self.xff_nm = build_netmap(self.args.xff_src)
        self.xff_lan = build_netmap("lan")

        self.mallow = "GET HEAD POST PUT DELETE OPTIONS".split()
        if not self.args.no_dav:
            zs = "PROPFIND PROPPATCH LOCK UNLOCK MKCOL COPY MOVE"
            self.mallow += zs.split()

        if self.args.zs:
            from .ssdp import SSDPr

            self.ssdp = SSDPr(broker)

        if self.tp_q:
            self.start_threads(4)

        if nid:
            # multiprocess worker; dl-info gets replaced by the broker
            # through write_dls
            self.tdli = {}
            self.tdls = {}
            if self.args.stackmon:
                start_stackmon(self.args.stackmon, nid)

            if self.args.log_thrs:
                start_log_thrs(self.log, self.args.log_thrs, nid)

        self.th_cfg: dict[str, set[str]] = {}
        Daemon(self.post_init, "hsrv-init2")

    def post_init(self) -> None:
        """ask the broker for the thumbsrv config; best-effort, failures are ignored"""
        try:
            x = self.broker.ask("thumbsrv.getcfg")
            self.th_cfg = x.get()
        except Exception:
            pass

    def set_netdevs(self, netdevs: dict[str, Netdev]) -> None:
        """rebuild the NetMap from the bound ips and the given network devices"""
        ips = {ip for ip, _ in self.bound}
        self.nm = NetMap(list(ips), list(netdevs))

    def start_threads(self, n: int) -> None:
        """spawn n more threadpool workers"""
        self.tp_nthr += n
        if self.args.log_htp:
            self.log(self.name, "workers += {} = {}".format(n, self.tp_nthr), 6)

        for _ in range(n):
            Daemon(self.thr_poolw, self.name + "-poolw")

    def stop_threads(self, n: int) -> None:
        """retire n threadpool workers by queueing one None sentinel per worker"""
        self.tp_nthr -= n
        if self.args.log_htp:
            self.log(self.name, "workers -= {} = {}".format(n, self.tp_nthr), 6)

        assert self.tp_q  # !rm
        for _ in range(n):
            self.tp_q.put(None)

    def periodic(self) -> None:
        """
        housekeeping loop: cleans u2fh and shrinks the threadpool;
        exits (clearing t_periodic) once the server goes idle
        """
        while True:
            time.sleep(2 if self.tp_ncli or self.ncli else 10)
            with self.u2mutex, self.mutex:
                self.u2fh.clean()
                if self.tp_q:
                    # tp_ncli decays by 2 per tick unless ncli is higher
                    self.tp_ncli = max(self.ncli, self.tp_ncli - 2)
                    if self.tp_nthr > self.tp_ncli + 8:
                        self.stop_threads(4)

                if not self.ncli and not self.u2fh.cache and self.tp_nthr <= 8:
                    self.t_periodic = None
                    return

    def listen(self, sck: socket.socket, nlisteners: int) -> None:
        """register a bound server socket and start a thr_listen thread for it"""
        tcp = sck.family != socket.AF_UNIX

        if self.args.j != 1:
            # lost in the pickle; redefine
            if not ANYWIN or self.args.reuseaddr:
                sck.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

            if tcp:
                sck.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

            sck.settimeout(None)  # < does not inherit, ^ opts above do

        if tcp:
            ip, port = sck.getsockname()[:2]
        else:
            ip = re.sub(r"\.[0-9]+$", "", sck.getsockname().split("/")[-1])
            port = 0

        self.srvs.append(sck)
        self.bound.add((ip, port))
        # per-listener share of args.nc, rounded up
        self.nclimax = math.ceil(self.args.nc * 1.0 / nlisteners)
        Daemon(
            self.thr_listen,
            "httpsrv-n{}-listen-{}-{}".format(self.nid or "0", ip, port),
            (sck,),
        )

    def thr_listen(self, srv_sck: socket.socket) -> None:
        """listens on a shared tcp server"""
        fno = srv_sck.fileno()
        if srv_sck.family == socket.AF_UNIX:
            ip = re.sub(r"\.[0-9]+$", "", srv_sck.getsockname())
            msg = "subscribed @ %s f%d p%d" % (ip, fno, os.getpid())
            ip = ip.split("/")[-1]
            port = 0
            tcp = False
        else:
            tcp = True
            ip, port = srv_sck.getsockname()[:2]
            hip = "[%s]" % (ip,) if ":" in ip else ip
            msg = "subscribed @ %s:%d f%d p%d" % (hip, port, fno, os.getpid())

        self.log(self.name, msg)

        Daemon(self.broker.say, "sig-hsrv-up1", ("cb_httpsrv_up",))

        saddr = ("", 0)  # fwd-decl for `except TypeError as ex:`
        while not self.stopping:
            if self.args.log_conn:
                self.log(self.name, "|%sC-ncli" % ("-" * 1,), c="90")

            spins = 0
            while self.ncli >= self.nclimax:
                if not spins:
                    self.log(self.name, "at connection limit; waiting", 3)

                spins += 1
                time.sleep(0.1)
                if spins == 50 and self.args.aclose:
                    # stuck at the limit for ~5sec (50 * 0.1s); act on the
                    # client ip holding the most connections. Done in a
                    # helper so the listener's own `ip` is not clobbered
                    self._aclose_busiest()

            if self.args.log_conn:
                self.log(self.name, "|%sC-acc1" % ("-" * 2,), c="90")

            try:
                sck, saddr = srv_sck.accept()
                if tcp:
                    cip = unicode(saddr[0])
                    if cip.startswith("::ffff:"):
                        cip = cip[7:]

                    addr = (cip, saddr[1])
                else:
                    addr = ("127.8.3.7", sck.fileno())
            except (OSError, socket.error) as ex:
                if self.stopping:
                    break

                self.log(self.name, "accept({}): {}".format(fno, ex), c=6)
                time.sleep(0.02)
                continue
            except TypeError as ex:
                # on macOS, accept() may return a None saddr if blocked by LittleSnitch;
                # unicode(saddr[0]) ==> TypeError: 'NoneType' object is not subscriptable
                if tcp and not saddr:
                    t = "accept(%s): failed to accept connection from client due to firewall or network issue"
                    self.log(self.name, t % (fno,), c=3)
                    try:
                        sck.close()  # type: ignore
                    except Exception:
                        pass
                    time.sleep(0.02)
                    continue
                raise

            if self.args.log_conn:
                t = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format(
                    "-" * 3, ip, port % 8, port
                )
                self.log("%s %s" % addr, t, c="90")

            self.accept(sck, addr)

    def _aclose_busiest(self) -> None:
        """
        find the client ip with the most connections; if it holds at least
        half of nclimax, downgrade it to connection:close for args.aclose
        minutes and shut down its connections; if most of those were idle
        (no request yet), also ban it for args.loris minutes
        """
        ipfreq: dict[str, int] = {}
        with self.mutex:
            for c in self.clients:
                cip = ipnorm(c.ip)
                ipfreq[cip] = ipfreq.get(cip, 0) + 1

        if not ipfreq:
            # self.clients lags behind self.ncli (queued connections are not
            # in it yet), so it can be empty even at the connection limit
            return

        # max() returns the first of several equal counts, same as a stable
        # descending sort would
        ip, n = max(ipfreq.items(), key=lambda x: x[1])
        if n < self.nclimax / 2:
            return

        self.aclose[ip] = int(time.time() + self.args.aclose * 60)
        nclose = 0
        nloris = 0
        nconn = 0
        with self.mutex:
            for c in self.clients:
                if ipnorm(c.ip) != ip:
                    continue

                nconn += 1
                try:
                    if (
                        c.nreq >= 1
                        or not c.cli
                        or c.cli.in_hdr_recv
                        or c.cli.keepalive
                    ):
                        Daemon(c.shutdown)
                        nclose += 1
                        if c.nreq <= 0 and (not c.cli or c.cli.in_hdr_recv):
                            nloris += 1
                except Exception:
                    # connection state may change under us; skip that one
                    pass

        t = "{} downgraded to connection:close for {} min; dropped {}/{} connections"
        self.log(self.name, t.format(ip, self.args.aclose, nclose, nconn), 1)

        if nloris < nconn / 2:
            return

        t = "slowloris (idle-conn): {} banned for {} min"
        self.log(self.name, t.format(ip, self.args.loris), 1)
        self.bans[ip] = int(time.time() + self.args.loris * 60)

    def accept(self, sck: socket.socket, addr: tuple[str, int]) -> None:
        """takes an incoming tcp connection and creates a thread to handle it"""
        now = time.time()

        # tp_time is set when a task is queued and cleared when a worker
        # picks one up; if nothing was picked up for 5min, give up on the pool
        if now - (self.tp_time or now) > 300:
            t = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}"
            self.log(self.name, t.format(self.tp_time, now, self.tp_nthr, self.ncli), 1)
            self.tp_time = 0
            self.tp_q = None

        with self.mutex:
            self.ncli += 1
            if not self.t_periodic:
                name = "hsrv-pt"
                if self.nid:
                    name += "-%d" % (self.nid,)

                self.t_periodic = Daemon(self.periodic, name)

            if self.tp_q:
                self.tp_time = self.tp_time or now
                self.tp_ncli = max(self.tp_ncli, self.ncli)
                if self.tp_nthr < self.ncli + 4:
                    self.start_threads(8)

                self.tp_q.put((sck, addr))
                return

        if not self.args.no_htp:
            t = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n"
            self.log(self.name, t, 1)

        Daemon(
            self.thr_client,
            "httpconn-%s-%d" % (addr[0].split(".", 2)[-1][-6:], addr[1]),
            (sck, addr),
        )

    def thr_poolw(self) -> None:
        """threadpool worker; runs queued (sck, addr) tasks until it gets a None"""
        assert self.tp_q  # !rm
        # keep a local ref; accept() may set self.tp_q to None when it
        # decides the pool is dead, which must not crash waiting workers
        tp_q = self.tp_q
        while True:
            task = tp_q.get()
            if not task:
                break

            with self.mutex:
                self.tp_time = 0

            me = threading.current_thread()
            try:
                sck, addr = task
                me.name = "httpconn-%s-%d" % (addr[0].split(".", 2)[-1][-6:], addr[1])
                self.thr_client(sck, addr)
            except Exception as ex:
                if str(ex).startswith("client d/c "):
                    self.log(self.name, "thr_client: " + str(ex), 6)
                else:
                    self.log(self.name, "thr_client: " + min_ex(), 3)
            finally:
                # restore the pool name even if thr_client raised
                me.name = self.name + "-poolw"

    def shutdown(self) -> None:
        """close listeners, shut down all clients, and retire the threadpool"""
        self.stopping = True
        for srv in self.srvs:
            try:
                srv.close()
            except Exception:
                pass

        thrs = []
        clients = list(self.clients)
        for cli in clients:
            t = threading.Thread(target=cli.shutdown)
            thrs.append(t)
            t.start()

        if self.tp_q:
            self.stop_threads(self.tp_nthr)
            for _ in range(10):
                time.sleep(0.05)
                if self.tp_q.empty():
                    break

        for t in thrs:
            t.join()

        self.log(self.name, "ok bye")

    def thr_client(self, sck: socket.socket, addr: tuple[str, int]) -> None:
        """thread managing one tcp client"""
        cli = HttpConn(sck, addr, self)
        with self.mutex:
            self.clients.add(cli)

        fno = sck.fileno()
        try:
            if self.args.log_conn:
                self.log("%s %s" % addr, "|%sC-crun" % ("-" * 4,), c="90")

            cli.run()

        except (OSError, socket.error) as ex:
            if ex.errno not in E_SCK:
                self.log(
                    "%s %s" % addr,
                    "run({}): {}".format(fno, ex),
                    c=6,
                )

        finally:
            sck = cli.s
            if self.args.log_conn:
                self.log("%s %s" % addr, "|%sC-cdone" % ("-" * 5,), c="90")

            try:
                fno = sck.fileno()
                shut_socket(cli.log, sck)
            except (OSError, socket.error) as ex:
                if not MACOS:
                    self.log(
                        "%s %s" % addr,
                        "shut({}): {}".format(fno, ex),
                        c="90",
                    )
                if ex.errno not in E_SCK:
                    raise
            finally:
                with self.mutex:
                    self.clients.remove(cli)
                    self.ncli -= 1

                # put_u2idx takes self.mutex itself and threading.Lock is
                # not reentrant, so this must stay outside the block above
                if cli.u2idx:
                    self.put_u2idx(str(addr), cli.u2idx)

    def cachebuster(self) -> str:
        """
        short token derived from the newest mtime among E.t0 and the
        entries of <E.mod>/web; recomputed at most once per second
        """
        if time.time() - self.cb_ts < 1:
            return self.cb_v

        with self.mutex:
            if time.time() - self.cb_ts < 1:
                return self.cb_v

            v = self.E.t0
            try:
                with os.scandir(os.path.join(self.E.mod, "web")) as dh:
                    for fh in dh:
                        inf = fh.stat()
                        v = max(v, inf.st_mtime)
            except Exception:
                pass

            # spack gives 4 lsb, take 3 lsb, get 4 ch
            self.cb_v = ub64enc(spack(b">L", int(v))[1:]).decode("ascii")
            self.cb_ts = time.time()
            return self.cb_v

    def get_u2idx(self, ident: str) -> Optional[U2idx]:
        """
        borrow a U2idx: prefer the one parked under `ident`, else any free
        one, else create a new one while fewer than CORES exist; polls for
        up to 5sec and returns None if none became available
        """
        utab = self.u2idx_free
        for _ in range(100):  # 100 * 0.05 = 5sec
            with self.mutex:
                if utab:
                    if ident in utab:
                        return utab.pop(ident)

                    return utab.pop(next(iter(utab)))

                if self.u2idx_n < CORES:
                    self.u2idx_n += 1
                    return U2idx(self)

            time.sleep(0.05)
            # not using conditional waits, on a hunch that
            # average performance will be faster like this
            # since most servers won't be fully saturated

        return None

    def put_u2idx(self, ident: str, u2idx: U2idx) -> None:
        """return a U2idx to the free pool; appends "a" to ident until the key is unused"""
        with self.mutex:
            while ident in self.u2idx_free:
                ident += "a"

            self.u2idx_free[ident] = u2idx

    def read_dls(
        self,
    ) -> tuple[
        dict[str, tuple[float, int, str, str, str]], dict[str, tuple[float, int]]
    ]:
        """
        mp-broker asking for local dl-info + dl-state;
        reduce overhead by sending just the vfs vpath
        """
        dli = {k: (a, b, c.vpath, d, e) for k, (a, b, c, d, e) in self.dli.items()}
        return (dli, self.dls)

    def write_dls(
        self,
        sdli: dict[str, tuple[float, int, str, str, str]],
        dls: dict[str, tuple[float, int]],
    ) -> None:
        """
        mp-broker pushing total dl-info + dl-state;
        swap out the vfs vpath with the vfs node
        """
        # NOTE(review): raises KeyError if a vpath is no longer in
        # asrv.vfs.all_nodes (e.g. after a volume was removed) -- confirm
        # whether the broker relies on that
        dli: dict[str, tuple[float, int, "VFS", str, str]] = {}
        for k, (a, b, c, d, e) in sdli.items():
            vn = self.asrv.vfs.all_nodes[c]
            dli[k] = (a, b, vn, d, e)

        self.tdli = dli
        self.tdls = dls