diff --git a/README.md b/README.md index 38c60727..6493f0d2 100644 --- a/README.md +++ b/README.md @@ -1200,15 +1200,18 @@ journalctl -aS '48 hour ago' -u copyparty | grep -C10 FILENAME | tee bug.log ## dev env setup -mostly optional; if you need a working env for vscode or similar +you need python 3.9 or newer due to type hints + +the rest is mostly optional; if you need a working env for vscode or similar ```sh python3 -m venv .venv . .venv/bin/activate -pip install jinja2 # mandatory +pip install jinja2 strip_hints # MANDATORY pip install mutagen # audio metadata +pip install pyftpdlib # ftp server pip install Pillow pyheif-pillow-opener pillow-avif-plugin # thumbnails -pip install black==21.12b0 bandit pylint flake8 # vscode tooling +pip install black==21.12b0 click==8.0.2 bandit pylint flake8 isort mypy # vscode tooling ``` diff --git a/bin/mtag/image-noexif.py b/bin/mtag/image-noexif.py index d5392b00..bc009c17 100644 --- a/bin/mtag/image-noexif.py +++ b/bin/mtag/image-noexif.py @@ -43,7 +43,6 @@ PS: this requires e2ts to be functional, import os import sys -import time import filecmp import subprocess as sp diff --git a/bin/up2k.py b/bin/up2k.py index 8cf904c4..23d35ea3 100755 --- a/bin/up2k.py +++ b/bin/up2k.py @@ -77,15 +77,15 @@ class File(object): self.up_b = 0 # type: int self.up_c = 0 # type: int - # m = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n" - # eprint(m.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name)) + # t = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n" + # eprint(t.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name)) class FileSlice(object): """file-like object providing a fixed window into a file""" def __init__(self, file, cid): - # type: (File, str) -> FileSlice + # type: (File, str) -> None self.car, self.len = file.kchunks[cid] self.cdr = self.car + self.len @@ -216,8 +216,8 @@ class CTermsize(object): eprint("\033[s\033[r\033[u") else: self.g = 1 + self.h - margin - m = "{0}\033[{1}A".format("\n" * margin, margin) - eprint("{0}\033[s\033[1;{1}r\033[u".format(m, self.g - 1)) + t = "{0}\033[{1}A".format("\n" * margin, margin) + eprint("{0}\033[s\033[1;{1}r\033[u".format(t, self.g - 1)) ss = CTermsize() @@ -597,8 +597,8 @@ class Ctl(object): if "/" in name: name = "\033[36m{0}\033[0m/{1}".format(*name.rsplit("/", 1)) - m = "{0:6.1f}% {1} {2}\033[K" - txt += m.format(p, self.nfiles - f, name) + t = "{0:6.1f}% {1} {2}\033[K" + txt += t.format(p, self.nfiles - f, name) txt += "\033[{0}H ".format(ss.g + 2) else: @@ -618,8 +618,8 @@ class Ctl(object): nleft = self.nfiles - self.up_f tail = "\033[K\033[u" if VT100 else "\r" - m = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft) - eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(m, tail)) + t = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft) + eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(t, tail)) def cleanup_vt100(self): ss.scroll_region(None) @@ -721,8 +721,8 @@ class Ctl(object): if search: if hs: for hit in hs: - m = "found: {0}\n {1}{2}\n" - print(m.format(upath, burl, hit["rp"]), end="") + t = "found: {0}\n {1}{2}\n" + print(t.format(upath, burl, hit["rp"]), end="") else: print("NOT found: {0}\n".format(upath), end="") diff --git a/contrib/systemd/copyparty.service b/contrib/systemd/copyparty.service index e8e02f90..fd3f9efc 100644 --- a/contrib/systemd/copyparty.service +++ b/contrib/systemd/copyparty.service @@ -4,7 +4,7 @@ # installation: # cp -pv copyparty.service /etc/systemd/system # restorecon -vr /etc/systemd/system/copyparty.service -# firewall-cmd --permanent --add-port={80,443,3923}/tcp +# firewall-cmd --permanent --add-port={80,443,3923}/tcp # --zone=libvirt # firewall-cmd --reload # systemctl daemon-reload && systemctl enable --now copyparty # diff --git a/copyparty/__init__.py b/copyparty/__init__.py index 17513708..594363eb 100644 --- a/copyparty/__init__.py +++ b/copyparty/__init__.py @@ -1,21 +1,30 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import platform -import time -import sys import os +import platform +import sys +import time + +try: + from collections.abc import Callable + + from typing import TYPE_CHECKING, Any +except: + TYPE_CHECKING = False PY2 = sys.version_info[0] == 2 if PY2: sys.dont_write_bytecode = True - unicode = unicode + unicode = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable else: unicode = str -WINDOWS = False -if platform.system() == "Windows": - WINDOWS = [int(x) for x in platform.version().split(".")] +WINDOWS: Any = ( + [int(x) for x in platform.version().split(".")] + if platform.system() == "Windows" + else False +) VT100 = not WINDOWS or WINDOWS >= [10, 0, 14393] # introduced in anniversary update @@ -25,8 +34,8 @@ ANYWIN = WINDOWS or sys.platform in ["msys"] MACOS = platform.system() == "Darwin" -def get_unixdir(): - paths = [ +def get_unixdir() -> str: + paths: list[tuple[Callable[..., str], str]] = [ (os.environ.get, "XDG_CONFIG_HOME"), (os.path.expanduser, "~/.config"), (os.environ.get, "TMPDIR"), @@ -43,7 +52,7 @@ def get_unixdir(): continue p = os.path.normpath(p) - chk(p) + chk(p) # type: ignore p = os.path.join(p, "copyparty") if not os.path.isdir(p): os.mkdir(p) @@ -56,7 +65,7 @@ def get_unixdir(): class EnvParams(object): - def __init__(self): + def __init__(self) -> None: self.t0 = time.time() self.mod = os.path.dirname(os.path.realpath(__file__)) if self.mod.endswith("__init__"): diff --git a/copyparty/__main__.py b/copyparty/__main__.py index 87203f87..5412a6c2 100644 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -8,35 +8,42 @@ __copyright__ = 2019 __license__ = "MIT" __url__ = "https://github.com/9001/copyparty/" -import re -import os -import sys -import time -import shutil +import argparse import filecmp import locale -import argparse +import os +import re +import shutil +import sys import threading +import time import traceback from textwrap import dedent -from .__init__ import E, WINDOWS, ANYWIN, VT100, PY2, unicode -from .__version__ import S_VERSION, S_BUILD_DT, CODENAME -from .svchub import SvcHub -from .util import py_desc, align_tab, IMPLICATIONS, ansi_re, min_ex +from .__init__ import ANYWIN, PY2, VT100, WINDOWS, E, unicode +from .__version__ import CODENAME, S_BUILD_DT, S_VERSION from .authsrv import re_vol +from .svchub import SvcHub +from .util import IMPLICATIONS, align_tab, ansi_re, min_ex, py_desc -HAVE_SSL = True try: + from types import FrameType + + from typing import Any, Optional +except: + pass + +try: + HAVE_SSL = True import ssl except: HAVE_SSL = False -printed = [] +printed: list[str] = [] class RiceFormatter(argparse.HelpFormatter): - def _get_help_string(self, action): + def _get_help_string(self, action: argparse.Action) -> str: """ same as ArgumentDefaultsHelpFormatter(HelpFormatter) except the help += [...] line now has colors @@ -45,27 +52,27 @@ class RiceFormatter(argparse.HelpFormatter): if not VT100: fmt = " (default: %(default)s)" - ret = action.help - if "%(default)" not in action.help: + ret = str(action.help) + if "%(default)" not in ret: if action.default is not argparse.SUPPRESS: defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE] if action.option_strings or action.nargs in defaulting_nargs: ret += fmt return ret - def _fill_text(self, text, width, indent): + def _fill_text(self, text: str, width: int, indent: str) -> str: """same as RawDescriptionHelpFormatter(HelpFormatter)""" return "".join(indent + line + "\n" for line in text.splitlines()) class Dodge11874(RiceFormatter): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["width"] = 9003 super(Dodge11874, self).__init__(*args, **kwargs) -def lprint(*a, **ka): - txt = " ".join(unicode(x) for x in a) + ka.get("end", "\n") +def lprint(*a: Any, **ka: Any) -> None: + txt: str = " ".join(unicode(x) for x in a) + ka.get("end", "\n") printed.append(txt) if not VT100: txt = ansi_re.sub("", txt) @@ -73,11 +80,11 @@ def lprint(*a, **ka): print(txt, **ka) -def warn(msg): +def warn(msg: str) -> None: lprint("\033[1mwarning:\033[0;33m {}\033[0m\n".format(msg)) -def ensure_locale(): +def ensure_locale() -> None: for x in [ "en_US.UTF-8", "English_United States.UTF8", @@ -91,7 +98,7 @@ def ensure_locale(): continue -def ensure_cert(): +def ensure_cert() -> None: """ the default cert (and the entire TLS support) is only here to enable the crypto.subtle javascript API, which is necessary due to the webkit guys @@ -117,8 +124,8 @@ def ensure_cert(): # printf 'NO\n.\n.\n.\n.\ncopyparty-insecure\n.\n' | faketime '2000-01-01 00:00:00' openssl req -x509 -sha256 -newkey rsa:2048 -keyout insecure.pem -out insecure.pem -days $((($(printf %d 0x7fffffff)-$(date +%s --date=2000-01-01T00:00:00Z))/(60*60*24))) -nodes && ls -al insecure.pem && openssl x509 -in insecure.pem -text -noout -def configure_ssl_ver(al): - def terse_sslver(txt): +def configure_ssl_ver(al: argparse.Namespace) -> None: + def terse_sslver(txt: str) -> str: txt = txt.lower() for c in ["_", "v", "."]: txt = txt.replace(c, "") @@ -133,8 +140,8 @@ def configure_ssl_ver(al): flags = [k for k in ssl.__dict__ if ptn.match(k)] # SSLv2 SSLv3 TLSv1 TLSv1_1 TLSv1_2 TLSv1_3 if "help" in sslver: - avail = [terse_sslver(x[6:]) for x in flags] - avail = " ".join(sorted(avail) + ["all"]) + avail1 = [terse_sslver(x[6:]) for x in flags] + avail = " ".join(sorted(avail1) + ["all"]) lprint("\navailable ssl/tls versions:\n " + avail) sys.exit(0) @@ -160,7 +167,7 @@ def configure_ssl_ver(al): # think i need that beer now -def configure_ssl_ciphers(al): +def configure_ssl_ciphers(al: argparse.Namespace) -> None: ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) if al.ssl_ver: ctx.options &= ~al.ssl_flags_en @@ -184,8 +191,8 @@ def configure_ssl_ciphers(al): sys.exit(0) -def args_from_cfg(cfg_path): - ret = [] +def args_from_cfg(cfg_path: str) -> list[str]: + ret: list[str] = [] skip = False with open(cfg_path, "rb") as f: for ln in [x.decode("utf-8").strip() for x in f]: @@ -210,29 +217,30 @@ def args_from_cfg(cfg_path): return ret -def sighandler(sig=None, frame=None): +def sighandler(sig: Optional[int] = None, frame: Optional[FrameType] = None) -> None: msg = [""] * 5 for th in threading.enumerate(): + stk = sys._current_frames()[th.ident] # type: ignore msg.append(str(th)) - msg.extend(traceback.format_stack(sys._current_frames()[th.ident])) + msg.extend(traceback.format_stack(stk)) msg.append("\n") print("\n".join(msg)) -def disable_quickedit(): - import ctypes +def disable_quickedit() -> None: import atexit + import ctypes from ctypes import wintypes - def ecb(ok, fun, args): + def ecb(ok: bool, fun: Any, args: list[Any]) -> list[Any]: if not ok: - err = ctypes.get_last_error() + err: int = ctypes.get_last_error() # type: ignore if err: - raise ctypes.WinError(err) + raise ctypes.WinError(err) # type: ignore return args - k32 = ctypes.WinDLL("kernel32", use_last_error=True) + k32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore if PY2: wintypes.LPDWORD = ctypes.POINTER(wintypes.DWORD) @@ -242,14 +250,14 @@ def disable_quickedit(): k32.GetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.LPDWORD) k32.SetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.DWORD) - def cmode(out, mode=None): + def cmode(out: bool, mode: Optional[int] = None) -> int: h = k32.GetStdHandle(-11 if out else -10) if mode: - return k32.SetConsoleMode(h, mode) + return k32.SetConsoleMode(h, mode) # type: ignore - mode = wintypes.DWORD() - k32.GetConsoleMode(h, ctypes.byref(mode)) - return mode.value + cmode = wintypes.DWORD() + k32.GetConsoleMode(h, ctypes.byref(cmode)) + return cmode.value # disable quickedit mode = orig_in = cmode(False) @@ -268,7 +276,7 @@ def disable_quickedit(): cmode(True, mode | 4) -def run_argparse(argv, formatter): +def run_argparse(argv: list[str], formatter: Any) -> argparse.Namespace: ap = argparse.ArgumentParser( formatter_class=formatter, prog="copyparty", @@ -596,7 +604,7 @@ def run_argparse(argv, formatter): return ret -def main(argv=None): +def main(argv: Optional[list[str]] = None) -> None: time.strptime("19970815", "%Y%m%d") # python#7980 if WINDOWS: os.system("rem") # enables colors @@ -618,7 +626,7 @@ def main(argv=None): supp = args_from_cfg(v) argv.extend(supp) - deprecated = [] + deprecated: list[tuple[str, str]] = [] for dk, nk in deprecated: try: idx = argv.index(dk) @@ -650,7 +658,7 @@ def main(argv=None): if not VT100: al.wintitle = "" - nstrs = [] + nstrs: list[str] = [] anymod = False for ostr in al.v or []: m = re_vol.match(ostr) diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py index 4ca406ee..92612529 100644 --- a/copyparty/authsrv.py +++ b/copyparty/authsrv.py @@ -1,44 +1,68 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import sys -import stat -import time +import argparse import base64 import hashlib +import os +import re +import stat +import sys import threading +import time from datetime import datetime -from .__init__ import ANYWIN, WINDOWS +from .__init__ import ANYWIN, TYPE_CHECKING, WINDOWS +from .bos import bos from .util import ( IMPLICATIONS, META_NOBOTS, + Pebkac, + absreal, + fsenc, + relchk, + statdir, uncyg, undot, - relchk, unhumanize, - absreal, - Pebkac, - fsenc, - statdir, ) -from .bos import bos + +try: + from collections.abc import Iterable + + import typing + from typing import Any, Generator, Optional, Union + + from .util import RootLogger +except: + pass + +if TYPE_CHECKING: + pass + # Vflags: TypeAlias = dict[str, str | bool | float | list[str]] + # Vflags: TypeAlias = dict[str, Any] + # Mflags: TypeAlias = dict[str, Vflags] LEELOO_DALLAS = "leeloo_dallas" class AXS(object): - def __init__(self, uread=None, uwrite=None, umove=None, udel=None, uget=None): - self.uread = {} if uread is None else {k: 1 for k in uread} - self.uwrite = {} if uwrite is None else {k: 1 for k in uwrite} - self.umove = {} if umove is None else {k: 1 for k in umove} - self.udel = {} if udel is None else {k: 1 for k in udel} - self.uget = {} if uget is None else {k: 1 for k in uget} + def __init__( + self, + uread: Optional[Union[list[str], set[str]]] = None, + uwrite: Optional[Union[list[str], set[str]]] = None, + umove: Optional[Union[list[str], set[str]]] = None, + udel: Optional[Union[list[str], set[str]]] = None, + uget: Optional[Union[list[str], set[str]]] = None, + ) -> None: + self.uread: set[str] = set(uread or []) + self.uwrite: set[str] = set(uwrite or []) + self.umove: set[str] = set(umove or []) + self.udel: set[str] = set(udel or []) + self.uget: set[str] = set(uget or []) - def __repr__(self): + def __repr__(self) -> str: return "AXS({})".format( ", ".join( "{}={!r}".format(k, self.__dict__[k]) @@ -48,33 +72,33 @@ class AXS(object): class Lim(object): - def __init__(self): - self.nups = {} # num tracker - self.bups = {} # byte tracker list - self.bupc = {} # byte tracker cache + def __init__(self) -> None: + self.nups: dict[str, list[float]] = {} # num tracker + self.bups: dict[str, list[tuple[float, int]]] = {} # byte tracker list + self.bupc: dict[str, int] = {} # byte tracker cache self.nosub = False # disallow subdirectories - self.smin = None # filesize min - self.smax = None # filesize max + self.smin = -1 # filesize min + self.smax = -1 # filesize max - self.bwin = None # bytes window - self.bmax = None # bytes max - self.nwin = None # num window - self.nmax = None # num max + self.bwin = 0 # bytes window + self.bmax = 0 # bytes max + self.nwin = 0 # num window + self.nmax = 0 # num max - self.rotn = None # rot num files - self.rotl = None # rot depth - self.rotf = None # rot datefmt - self.rot_re = None # rotf check + self.rotn = 0 # rot num files + self.rotl = 0 # rot depth + self.rotf = "" # rot datefmt + self.rot_re = re.compile("") # rotf check - def set_rotf(self, fmt): + def set_rotf(self, fmt: str) -> None: self.rotf = fmt r = re.escape(fmt).replace("%Y", "[0-9]{4}").replace("%j", "[0-9]{3}") r = re.sub("%[mdHMSWU]", "[0-9]{2}", r) self.rot_re = re.compile("(^|/)" + r + "$") - def all(self, ip, rem, sz, abspath): + def all(self, ip: str, rem: str, sz: float, abspath: str) -> tuple[str, str]: self.chk_nup(ip) self.chk_bup(ip) self.chk_rem(rem) @@ -87,18 +111,18 @@ class Lim(object): return ap2, ("{}/{}".format(rem, vp2) if rem else vp2) - def chk_sz(self, sz): - if self.smin is not None and sz < self.smin: + def chk_sz(self, sz: float) -> None: + if self.smin != -1 and sz < self.smin: raise Pebkac(400, "file too small") - if self.smax is not None and sz > self.smax: + if self.smax != -1 and sz > self.smax: raise Pebkac(400, "file too big") - def chk_rem(self, rem): + def chk_rem(self, rem: str) -> None: if self.nosub and rem: raise Pebkac(500, "no subdirectories allowed") - def rot(self, path): + def rot(self, path: str) -> tuple[str, str]: if not self.rotf and not self.rotn: return path, "" @@ -120,7 +144,7 @@ class Lim(object): d = ret[len(path) :].strip("/\\").replace("\\", "/") return ret, d - def dive(self, path, lvs): + def dive(self, path: str, lvs: int) -> Optional[str]: items = bos.listdir(path) if not lvs: @@ -155,14 +179,14 @@ class Lim(object): return os.path.join(sub, ret) - def nup(self, ip): + def nup(self, ip: str) -> None: try: self.nups[ip].append(time.time()) except: self.nups[ip] = [time.time()] - def bup(self, ip, nbytes): - v = [time.time(), nbytes] + def bup(self, ip: str, nbytes: int) -> None: + v = (time.time(), nbytes) try: self.bups[ip].append(v) self.bupc[ip] += nbytes @@ -170,7 +194,7 @@ class Lim(object): self.bups[ip] = [v] self.bupc[ip] = nbytes - def chk_nup(self, ip): + def chk_nup(self, ip: str) -> None: if not self.nmax or ip not in self.nups: return @@ -182,7 +206,7 @@ class Lim(object): if len(nups) >= self.nmax: raise Pebkac(429, "too many uploads") - def chk_bup(self, ip): + def chk_bup(self, ip: str) -> None: if not self.bmax or ip not in self.bups: return @@ -200,35 +224,37 @@ class Lim(object): class VFS(object): """single level in the virtual fs""" - def __init__(self, log, realpath, vpath, axs, flags): + def __init__( + self, + log: Optional[RootLogger], + realpath: str, + vpath: str, + axs: AXS, + flags: dict[str, Any], + ) -> None: self.log = log self.realpath = realpath # absolute path on host filesystem self.vpath = vpath # absolute path in the virtual filesystem - self.axs = axs # type: AXS + self.axs = axs self.flags = flags # config options - self.nodes = {} # child nodes - self.histtab = None # all realpath->histpath - self.dbv = None # closest full/non-jump parent - self.lim = None # type: Lim # upload limits; only set for dbv + self.nodes: dict[str, VFS] = {} # child nodes + self.histtab: dict[str, str] = {} # all realpath->histpath + self.dbv: Optional[VFS] = None # closest full/non-jump parent + self.lim: Optional[Lim] = None # upload limits; only set for dbv + self.aread: dict[str, list[str]] = {} + self.awrite: dict[str, list[str]] = {} + self.amove: dict[str, list[str]] = {} + self.adel: dict[str, list[str]] = {} + self.aget: dict[str, list[str]] = {} if realpath: self.histpath = os.path.join(realpath, ".hist") # db / thumbcache self.all_vols = {vpath: self} # flattened recursive - self.aread = {} - self.awrite = {} - self.amove = {} - self.adel = {} - self.aget = {} else: - self.histpath = None - self.all_vols = None - self.aread = None - self.awrite = None - self.amove = None - self.adel = None - self.aget = None + self.histpath = "" + self.all_vols = {} - def __repr__(self): + def __repr__(self) -> str: return "VFS({})".format( ", ".join( "{}={!r}".format(k, self.__dict__[k]) @@ -236,14 +262,14 @@ class VFS(object): ) ) - def get_all_vols(self, outdict): + def get_all_vols(self, outdict: dict[str, "VFS"]) -> None: if self.realpath: outdict[self.vpath] = self for v in self.nodes.values(): v.get_all_vols(outdict) - def add(self, src, dst): + def add(self, src: str, dst: str) -> "VFS": """get existing, or add new path to the vfs""" assert not src.endswith("/") # nosec assert not dst.endswith("/") # nosec @@ -257,7 +283,7 @@ class VFS(object): vn = VFS( self.log, - os.path.join(self.realpath, name) if self.realpath else None, + os.path.join(self.realpath, name) if self.realpath else "", "{}/{}".format(self.vpath, name).lstrip("/"), self.axs, self._copy_flags(name), @@ -277,7 +303,7 @@ class VFS(object): self.nodes[dst] = vn return vn - def _copy_flags(self, name): + def _copy_flags(self, name: str) -> dict[str, Any]: flags = {k: v for k, v in self.flags.items()} hist = flags.get("hist") if hist and hist != "-": @@ -285,20 +311,20 @@ class VFS(object): return flags - def bubble_flags(self): + def bubble_flags(self) -> None: if self.dbv: for k, v in self.dbv.flags.items(): if k not in ["hist"]: self.flags[k] = v - for v in self.nodes.values(): - v.bubble_flags() + for n in self.nodes.values(): + n.bubble_flags() - def _find(self, vpath): + def _find(self, vpath: str) -> tuple["VFS", str]: """return [vfs,remainder]""" vpath = undot(vpath) if vpath == "": - return [self, ""] + return self, "" if "/" in vpath: name, rem = vpath.split("/", 1) @@ -309,66 +335,64 @@ class VFS(object): if name in self.nodes: return self.nodes[name]._find(rem) - return [self, vpath] + return self, vpath - def can_access(self, vpath, uname): - # type: (str, str) -> tuple[bool, bool, bool, bool] + def can_access(self, vpath: str, uname: str) -> tuple[bool, bool, bool, bool, bool]: """can Read,Write,Move,Delete,Get""" vn, _ = self._find(vpath) c = vn.axs - return [ + return ( uname in c.uread or "*" in c.uread, uname in c.uwrite or "*" in c.uwrite, uname in c.umove or "*" in c.umove, uname in c.udel or "*" in c.udel, uname in c.uget or "*" in c.uget, - ] + ) def get( self, - vpath, - uname, - will_read, - will_write, - will_move=False, - will_del=False, - will_get=False, - ): - # type: (str, str, bool, bool, bool, bool, bool) -> tuple[VFS, str] + vpath: str, + uname: str, + will_read: bool, + will_write: bool, + will_move: bool = False, + will_del: bool = False, + will_get: bool = False, + ) -> tuple["VFS", str]: """returns [vfsnode,fs_remainder] if user has the requested permissions""" if ANYWIN: mod = relchk(vpath) if mod: - self.log("vfs", "invalid relpath [{}]".format(vpath)) + if self.log: + self.log("vfs", "invalid relpath [{}]".format(vpath)) raise Pebkac(404) vn, rem = self._find(vpath) - c = vn.axs + c: AXS = vn.axs for req, d, msg in [ - [will_read, c.uread, "read"], - [will_write, c.uwrite, "write"], - [will_move, c.umove, "move"], - [will_del, c.udel, "delete"], - [will_get, c.uget, "get"], + (will_read, c.uread, "read"), + (will_write, c.uwrite, "write"), + (will_move, c.umove, "move"), + (will_del, c.udel, "delete"), + (will_get, c.uget, "get"), ]: if req and (uname not in d and "*" not in d) and uname != LEELOO_DALLAS: - m = "you don't have {}-access for this location" - raise Pebkac(403, m.format(msg)) + t = "you don't have {}-access for this location" + raise Pebkac(403, t.format(msg)) return vn, rem - def get_dbv(self, vrem): - # type: (str) -> tuple[VFS, str] + def get_dbv(self, vrem: str) -> tuple["VFS", str]: dbv = self.dbv if not dbv: return self, vrem - vrem = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem] - vrem = "/".join([x for x in vrem if x]) + tv = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem] + vrem = "/".join([x for x in tv if x]) return dbv, vrem - def canonical(self, rem, resolve=True): + def canonical(self, rem: str, resolve: bool = True) -> str: """returns the canonical path (fully-resolved absolute fs path)""" rp = self.realpath if rem: @@ -376,8 +400,14 @@ class VFS(object): return absreal(rp) if resolve else rp - def ls(self, rem, uname, scandir, permsets, lstat=False): - # type: (str, str, bool, list[list[bool]], bool) -> tuple[str, str, dict[str, VFS]] + def ls( + self, + rem: str, + uname: str, + scandir: bool, + permsets: list[list[bool]], + lstat: bool = False, + ) -> tuple[str, list[tuple[str, os.stat_result]], dict[str, "VFS"]]: """return user-readable [fsdir,real,virt] items at vpath""" virt_vis = {} # nodes readable by user abspath = self.canonical(rem) @@ -389,8 +419,8 @@ class VFS(object): for name, vn2 in sorted(self.nodes.items()): ok = False - axs = vn2.axs - axs = [axs.uread, axs.uwrite, axs.umove, axs.udel, axs.uget] + zx = vn2.axs + axs = [zx.uread, zx.uwrite, zx.umove, zx.udel, zx.uget] for pset in permsets: ok = True for req, lst in zip(pset, axs): @@ -409,9 +439,32 @@ class VFS(object): elif "/.hist/th/" in p: real = [x for x in real if not x[0].endswith("dir.txt")] - return [abspath, real, virt_vis] + return abspath, real, virt_vis - def walk(self, rel, rem, seen, uname, permsets, dots, scandir, lstat, subvols=True): + def walk( + self, + rel: str, + rem: str, + seen: list[str], + uname: str, + permsets: list[list[bool]], + dots: bool, + scandir: bool, + lstat: bool, + subvols: bool = True, + ) -> Generator[ + tuple[ + "VFS", + str, + str, + str, + list[tuple[str, os.stat_result]], + list[tuple[str, os.stat_result]], + dict[str, "VFS"], + ], + None, + None, + ]: """ recursively yields from ./rem; rel is a unix-style user-defined vpath (not vfs-related) @@ -425,8 +478,9 @@ class VFS(object): and (not fsroot.startswith(seen[-1]) or fsroot == seen[-1]) and fsroot in seen ): - m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}" - self.log("vfs.walk", m.format(seen[-1], fsroot, self.vpath, rem), 3) + if self.log: + t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}" + self.log("vfs.walk", t.format(seen[-1], fsroot, self.vpath, rem), 3) return seen = seen[:] + [fsroot] @@ -460,9 +514,9 @@ class VFS(object): for x in vfs.walk(wrel, "", seen, uname, permsets, dots, scandir, lstat): yield x - def zipgen(self, vrem, flt, uname, dots, scandir): - if flt: - flt = {k: True for k in flt} + def zipgen( + self, vrem: str, flt: set[str], uname: str, dots: bool, scandir: bool + ) -> Generator[dict[str, Any], None, None]: # if multiselect: add all items to archive root # if single folder: the folder itself is the top-level item @@ -473,33 +527,33 @@ class VFS(object): if flt: files = [x for x in files if x[0] in flt] - rm = [x for x in rd if x[0] not in flt] - [rd.remove(x) for x in rm] + rm1 = [x for x in rd if x[0] not in flt] + _ = [rd.remove(x) for x in rm1] # type: ignore - rm = [x for x in vd.keys() if x not in flt] - [vd.pop(x) for x in rm] + rm2 = [x for x in vd.keys() if x not in flt] + _ = [vd.pop(x) for x in rm2] - flt = None + flt = set() # print(repr([vpath, apath, [x[0] for x in files]])) fnames = [n[0] for n in files] vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames apaths = [os.path.join(apath, n) for n in fnames] - files = list(zip(vpaths, apaths, files)) + ret = list(zip(vpaths, apaths, files)) if not dots: # dotfile filtering based on vpath (intended visibility) - files = [x for x in files if "/." not in "/" + x[0]] + ret = [x for x in ret if "/." not in "/" + x[0]] - rm = [x for x in rd if x[0].startswith(".")] - for x in rm: - rd.remove(x) + zel = [ze for ze in rd if ze[0].startswith(".")] + for ze in zel: + rd.remove(ze) - rm = [k for k in vd.keys() if k.startswith(".")] - for x in rm: - del vd[x] + zsl = [zs for zs in vd.keys() if zs.startswith(".")] + for zs in zsl: + del vd[zs] - for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in files]: + for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in ret]: yield f @@ -512,7 +566,12 @@ else: class AuthSrv(object): """verifies users against given paths""" - def __init__(self, args, log_func, warn_anonwrite=True): + def __init__( + self, + args: argparse.Namespace, + log_func: Optional[RootLogger], + warn_anonwrite: bool = True, + ) -> None: self.args = args self.log_func = log_func self.warn_anonwrite = warn_anonwrite @@ -521,11 +580,11 @@ class AuthSrv(object): self.mutex = threading.Lock() self.reload() - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: if self.log_func: self.log_func("auth", msg, c) - def laggy_iter(self, iterable): + def laggy_iter(self, iterable: Iterable[Any]) -> Generator[Any, None, None]: """returns [value,isFinalValue]""" it = iter(iterable) prev = next(it) @@ -535,26 +594,39 @@ class AuthSrv(object): yield prev, True - def _map_volume(self, src, dst, mount, daxs, mflags): + def _map_volume( + self, + src: str, + dst: str, + mount: dict[str, str], + daxs: dict[str, AXS], + mflags: dict[str, dict[str, Any]], + ) -> None: if dst in mount: - m = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]" - self.log(m.format(dst, mount[dst], src), c=1) + t = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]" + self.log(t.format(dst, mount[dst], src), c=1) raise Exception("invalid config") if src in mount.values(): - m = "warning: filesystem-path [{}] mounted in multiple locations:" - m = m.format(src) + t = "warning: filesystem-path [{}] mounted in multiple locations:" + t = t.format(src) for v in [k for k, v in mount.items() if v == src] + [dst]: - m += "\n /{}".format(v) + t += "\n /{}".format(v) - self.log(m, c=3) + self.log(t, c=3) mount[dst] = src daxs[dst] = AXS() mflags[dst] = {} - def _parse_config_file(self, fd, acct, daxs, mflags, mount): - # type: (any, str, dict[str, AXS], any, str) -> None + def _parse_config_file( + self, + fd: typing.BinaryIO, + acct: dict[str, str], + daxs: dict[str, AXS], + mflags: dict[str, dict[str, Any]], + mount: dict[str, str], + ) -> None: skip = False vol_src = None vol_dst = None @@ -601,23 +673,25 @@ class AuthSrv(object): uname = "*" if lvl == "a": - m = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead" - self.log(m, 1) + t = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead" + self.log(t, 1) self._read_vol_str(lvl, uname, daxs[vol_dst], mflags[vol_dst]) - def _read_vol_str(self, lvl, uname, axs, flags): - # type: (str, str, AXS, any) -> None + def _read_vol_str( + self, lvl: str, uname: str, axs: AXS, flags: dict[str, Any] + ) -> None: if lvl.strip("crwmdg"): raise Exception("invalid volume flag: {},{}".format(lvl, uname)) if lvl == "c": + cval: Union[bool, str] = True try: # volume flag with arguments, possibly with a preceding list of bools uname, cval = uname.split("=", 1) except: # just one or more bools - cval = True + pass while "," in uname: # one or more bools before the final flag; eat them @@ -631,34 +705,38 @@ class AuthSrv(object): uname = "*" for un in uname.replace(",", " ").strip().split(): - if "r" in lvl: - axs.uread[un] = 1 + for ch, al in [ + ("r", axs.uread), + ("w", axs.uwrite), + ("m", axs.umove), + ("d", axs.udel), + ("g", axs.uget), + ]: + if ch in lvl: + al.add(un) - if "w" in lvl: - axs.uwrite[un] = 1 - - if "m" in lvl: - axs.umove[un] = 1 - - if "d" in lvl: - axs.udel[un] = 1 - - if "g" in lvl: - axs.uget[un] = 1 - - def _read_volflag(self, flags, name, value, is_list): + def _read_volflag( + self, + flags: dict[str, Any], + name: str, + value: Union[str, bool, list[str]], + is_list: bool, + ) -> None: if name not in ["mtp"]: flags[name] = value return - if not is_list: - value = [value] - elif not value: + vals = flags.get(name, []) + if not value: return + elif is_list: + vals += value + else: + vals += [value] - flags[name] = flags.get(name, []) + value + flags[name] = vals - def reload(self): + def reload(self) -> None: """ construct a flat list of mountpoints and usernames first from the commandline arguments @@ -666,10 +744,10 @@ class AuthSrv(object): before finally building the VFS """ - acct = {} # username:password - daxs = {} # type: dict[str, AXS] - mflags = {} # mountpoint:[flag] - mount = {} # dst:src (mountpoint:realpath) + acct: dict[str, str] = {} # username:password + daxs: dict[str, AXS] = {} + mflags: dict[str, dict[str, Any]] = {} # moutpoint:flags + mount: dict[str, str] = {} # dst:src (mountpoint:realpath) if self.args.a: # list of username:password @@ -678,8 +756,8 @@ class AuthSrv(object): u, p = x.split(":", 1) acct[u] = p except: - m = '\n invalid value "{}" for argument -a, must be username:password' - raise Exception(m.format(x)) + t = '\n invalid value "{}" for argument -a, must be username:password' + raise Exception(t.format(x)) if self.args.v: # list of src:dst:permset:permset:... @@ -708,8 +786,8 @@ class AuthSrv(object): try: self._parse_config_file(f, acct, daxs, mflags, mount) except: - m = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m" - self.log(m.format(cfg_fn, self.line_ctr), 1) + t = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m" + self.log(t.format(cfg_fn, self.line_ctr), 1) raise # case-insensitive; normalize @@ -726,7 +804,7 @@ class AuthSrv(object): vfs = VFS(self.log_func, bos.path.abspath("."), "", axs, {}) elif "" not in mount: # there's volumes but no root; make root inaccessible - vfs = VFS(self.log_func, None, "", AXS(), {}) + vfs = VFS(self.log_func, "", "", AXS(), {}) vfs.flags["d2d"] = True maxdepth = 0 @@ -740,10 +818,10 @@ class AuthSrv(object): vfs = VFS(self.log_func, mount[dst], dst, daxs[dst], mflags[dst]) continue - v = vfs.add(mount[dst], dst) - v.axs = daxs[dst] - v.flags = mflags[dst] - v.dbv = None + zv = vfs.add(mount[dst], dst) + zv.axs = daxs[dst] + zv.flags = mflags[dst] + zv.dbv = None vfs.all_vols = {} vfs.get_all_vols(vfs.all_vols) @@ -751,11 +829,11 @@ class AuthSrv(object): for perm in "read write move del get".split(): axs_key = "u" + perm unames = ["*"] + list(acct.keys()) - umap = {x: [] for x in unames} + umap: dict[str, list[str]] = {x: [] for x in unames} for usr in unames: for vp, vol in vfs.all_vols.items(): - axs = getattr(vol.axs, axs_key) - if usr in axs or "*" in axs: + zx = getattr(vol.axs, axs_key) + if usr in zx or "*" in zx: umap[usr].append(vp) umap[usr].sort() setattr(vfs, "a" + perm, umap) @@ -764,7 +842,7 @@ class AuthSrv(object): missing_users = {} for axs in daxs.values(): for d in [axs.uread, axs.uwrite, axs.umove, axs.udel, axs.uget]: - for usr in d.keys(): + for usr in d: all_users[usr] = 1 if usr != "*" and usr not in acct: missing_users[usr] = 1 @@ -783,8 +861,8 @@ class AuthSrv(object): promote = [] demote = [] for vol in vfs.all_vols.values(): - hid = hashlib.sha512(fsenc(vol.realpath)).digest() - hid = base64.b32encode(hid).decode("ascii").lower() + zb = hashlib.sha512(fsenc(vol.realpath)).digest() + hid = base64.b32encode(zb).decode("ascii").lower() vflag = vol.flags.get("hist") if vflag == "-": pass @@ -822,21 +900,21 @@ class AuthSrv(object): demote.append(vol) # discard jump-vols - for v in demote: - vfs.all_vols.pop(v.vpath) + for zv in demote: + vfs.all_vols.pop(zv.vpath) if promote: - msg = [ + ta = [ "\n the following jump-volumes were generated to assist the vfs.\n As they contain a database (probably from v0.11.11 or older),\n they are promoted to full volumes:" ] for vol in promote: - msg.append( + ta.append( " /{} ({}) ({})".format(vol.vpath, vol.realpath, vol.histpath) ) - self.log("\n\n".join(msg) + "\n", c=3) + self.log("\n\n".join(ta) + "\n", c=3) - vfs.histtab = {v.realpath: v.histpath for v in vfs.all_vols.values()} + vfs.histtab = {zv.realpath: zv.histpath for zv in vfs.all_vols.values()} for vol in vfs.all_vols.values(): lim = Lim() @@ -846,30 +924,30 @@ class AuthSrv(object): use = True lim.nosub = True - v = vol.flags.get("sz") - if v: + zs = vol.flags.get("sz") + if zs: use = True - lim.smin, lim.smax = [unhumanize(x) for x in v.split("-")] + lim.smin, lim.smax = [unhumanize(x) for x in zs.split("-")] - v = vol.flags.get("rotn") - if v: + zs = vol.flags.get("rotn") + if zs: use = True - lim.rotn, lim.rotl = [int(x) for x in v.split(",")] + lim.rotn, lim.rotl = [int(x) for x in zs.split(",")] - v = vol.flags.get("rotf") - if v: + zs = vol.flags.get("rotf") + if zs: use = True - lim.set_rotf(v) + lim.set_rotf(zs) - v = vol.flags.get("maxn") - if v: + zs = vol.flags.get("maxn") + if zs: use = True - lim.nmax, lim.nwin = [int(x) for x in v.split(",")] + lim.nmax, lim.nwin = [int(x) for x in zs.split(",")] - v = vol.flags.get("maxb") - if v: + zs = vol.flags.get("maxb") + if zs: use = True - lim.bmax, lim.bwin = [unhumanize(x) for x in v.split(",")] + lim.bmax, lim.bwin = [unhumanize(x) for x in zs.split(",")] if use: vol.lim = lim @@ -1005,8 +1083,8 @@ class AuthSrv(object): for mtp in local_only_mtp: if mtp not in local_mte: - m = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)' - self.log(m.format(vol.vpath, mtp), 1) + t = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)' + self.log(t.format(vol.vpath, mtp), 1) errors = True tags = self.args.mtp or [] @@ -1014,8 +1092,8 @@ class AuthSrv(object): tags = [y for x in tags for y in x.split(",")] for mtp in tags: if mtp not in all_mte: - m = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)' - self.log(m.format(mtp), 1) + t = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)' + self.log(t.format(mtp), 1) errors = True if errors: @@ -1023,12 +1101,12 @@ class AuthSrv(object): vfs.bubble_flags() - m = "volumes and permissions:\n" - for v in vfs.all_vols.values(): + t = "volumes and permissions:\n" + for zv in vfs.all_vols.values(): if not self.warn_anonwrite: break - m += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(v.vpath, v.realpath) + t += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(zv.vpath, zv.realpath) for txt, attr in [ [" read", "uread"], [" write", "uwrite"], @@ -1036,21 +1114,21 @@ class AuthSrv(object): ["delete", "udel"], [" get", "uget"], ]: - u = list(sorted(getattr(v.axs, attr).keys())) + u = list(sorted(getattr(zv.axs, attr))) u = ", ".join("\033[35meverybody\033[0m" if x == "*" else x for x in u) u = u if u else "\033[36m--none--\033[0m" - m += "\n| {}: {}".format(txt, u) - m += "\n" + t += "\n| {}: {}".format(txt, u) + t += "\n" if self.warn_anonwrite and not self.args.no_voldump: - self.log(m) + self.log(t) try: - v, _ = vfs.get("/", "*", False, True) - if self.warn_anonwrite and os.getcwd() == v.realpath: + zv, _ = vfs.get("/", "*", False, True) + if self.warn_anonwrite and os.getcwd() == zv.realpath: self.warn_anonwrite = False - msg = "anyone can read/write the current directory: {}\n" - self.log(msg.format(v.realpath), c=1) + t = "anyone can read/write the current directory: {}\n" + self.log(t.format(zv.realpath), c=1) except Pebkac: self.warn_anonwrite = True @@ -1064,19 +1142,19 @@ class AuthSrv(object): if pwds: self.re_pwd = re.compile("=(" + "|".join(pwds) + ")([]&; ]|$)") - def dbg_ls(self): + def dbg_ls(self) -> None: users = self.args.ls - vols = "*" - flags = [] + vol = "*" + flags: list[str] = [] try: - users, vols = users.split(",", 1) + users, vol = users.split(",", 1) except: pass try: - vols, flags = vols.split(",", 1) - flags = flags.split(",") + vol, zf = vol.split(",", 1) + flags = zf.split(",") except: pass @@ -1089,23 +1167,23 @@ class AuthSrv(object): if u not in self.acct and u != "*": raise Exception("user not found: " + u) - if vols == "*": + if vol == "*": vols = ["/" + x for x in self.vfs.all_vols] else: - vols = [vols] + vols = [vol] - for v in vols: - if not v.startswith("/"): + for zs in vols: + if not zs.startswith("/"): raise Exception("volumes must start with /") - if v[1:] not in self.vfs.all_vols: - raise Exception("volume not found: " + v) + if zs[1:] not in self.vfs.all_vols: + raise Exception("volume not found: " + zs) - self.log({"users": users, "vols": vols, "flags": flags}) - m = "/{}: read({}) write({}) move({}) del({}) get({})" - for k, v in self.vfs.all_vols.items(): - vc = v.axs - self.log(m.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget)) + self.log(str({"users": users, "vols": vols, "flags": flags})) + t = "/{}: read({}) write({}) move({}) del({}) get({})" + for k, zv in self.vfs.all_vols.items(): + vc = zv.axs + self.log(t.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget)) flag_v = "v" in flags flag_ln = "ln" in flags @@ -1136,12 +1214,14 @@ class AuthSrv(object): False, False, ) - for _, _, vpath, apath, files, dirs, _ in g: - fnames = [n[0] for n in files] - vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames - vpaths = [vtop + x for x in vpaths] + for _, _, vpath, apath, files1, dirs, _ in g: + fnames = [n[0] for n in files1] + zsl = [vpath + "/" + n for n in fnames] if vpath else fnames + vpaths = [vtop + x for x in zsl] apaths = [os.path.join(apath, n) for n in fnames] - files = [[vpath + "/", apath + os.sep]] + list(zip(vpaths, apaths)) + files = [(vpath + "/", apath + os.sep)] + list( + [(zs1, zs2) for zs1, zs2 in zip(vpaths, apaths)] + ) if flag_ln: files = [x for x in files if not x[1].startswith(safeabs)] @@ -1152,21 +1232,23 @@ class AuthSrv(object): if not files: continue elif flag_v: - msg = [""] + [ + ta = [""] + [ '# user "{}", vpath "{}"\n{}'.format(u, vp, ap) for vp, ap in files ] else: - msg = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])] - msg += [x[1] for x in files] + ta = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])] + ta += [x[1] for x in files] - self.log("\n".join(msg)) + self.log("\n".join(ta)) if bads: self.log("\n ".join(["found symlinks leaving volume:"] + bads)) if bads and flag_p: - raise Exception("found symlink leaving volume, and strict is set") + raise Exception( + "\033[31m\n [--ls] found a safety issue and prevented startup:\n found symlinks leaving volume, and strict is set\n\033[0m" + ) if not flag_r: sys.exit(0) diff --git a/copyparty/bos/bos.py b/copyparty/bos/bos.py index d5e003cf..617545af 100644 --- a/copyparty/bos/bos.py +++ b/copyparty/bos/bos.py @@ -2,23 +2,30 @@ from __future__ import print_function, unicode_literals import os -from ..util import fsenc, fsdec, SYMTIME + +from ..util import SYMTIME, fsdec, fsenc from . import path +try: + from typing import Optional +except: + pass + +_ = (path,) # grep -hRiE '(^|[^a-zA-Z_\.-])os\.' . | gsed -r 's/ /\n/g;s/\(/(\n/g' | grep -hRiE '(^|[^a-zA-Z_\.-])os\.' | sort | uniq -c # printf 'os\.(%s)' "$(grep ^def bos/__init__.py | gsed -r 's/^def //;s/\(.*//' | tr '\n' '|' | gsed -r 's/.$//')" -def chmod(p, mode): +def chmod(p: str, mode: int) -> None: return os.chmod(fsenc(p), mode) -def listdir(p="."): +def listdir(p: str = ".") -> list[str]: return [fsdec(x) for x in os.listdir(fsenc(p))] -def makedirs(name, mode=0o755, exist_ok=True): +def makedirs(name: str, mode: int = 0o755, exist_ok: bool = True) -> None: bname = fsenc(name) try: os.makedirs(bname, mode) @@ -27,31 +34,33 @@ def makedirs(name, mode=0o755, exist_ok=True): raise -def mkdir(p, mode=0o755): +def mkdir(p: str, mode: int = 0o755) -> None: return os.mkdir(fsenc(p), mode) -def rename(src, dst): +def rename(src: str, dst: str) -> None: return os.rename(fsenc(src), fsenc(dst)) -def replace(src, dst): +def replace(src: str, dst: str) -> None: return os.replace(fsenc(src), fsenc(dst)) -def rmdir(p): +def rmdir(p: str) -> None: return os.rmdir(fsenc(p)) -def stat(p): +def stat(p: str) -> os.stat_result: return os.stat(fsenc(p)) -def unlink(p): +def unlink(p: str) -> None: return os.unlink(fsenc(p)) -def utime(p, times=None, follow_symlinks=True): +def utime( + p: str, times: Optional[tuple[float, float]] = None, follow_symlinks: bool = True +) -> None: if SYMTIME: return os.utime(fsenc(p), times, follow_symlinks=follow_symlinks) else: @@ -60,7 +69,7 @@ def utime(p, times=None, follow_symlinks=True): if hasattr(os, "lstat"): - def lstat(p): + def lstat(p: str) -> os.stat_result: return os.lstat(fsenc(p)) else: diff --git a/copyparty/bos/path.py b/copyparty/bos/path.py index 066453b0..c5769d84 100644 --- a/copyparty/bos/path.py +++ b/copyparty/bos/path.py @@ -2,43 +2,44 @@ from __future__ import print_function, unicode_literals import os -from ..util import fsenc, fsdec, SYMTIME + +from ..util import SYMTIME, fsdec, fsenc -def abspath(p): +def abspath(p: str) -> str: return fsdec(os.path.abspath(fsenc(p))) -def exists(p): +def exists(p: str) -> bool: return os.path.exists(fsenc(p)) -def getmtime(p, follow_symlinks=True): +def getmtime(p: str, follow_symlinks: bool = True) -> float: if not follow_symlinks and SYMTIME: return os.lstat(fsenc(p)).st_mtime else: return os.path.getmtime(fsenc(p)) -def getsize(p): +def getsize(p: str) -> int: return os.path.getsize(fsenc(p)) -def isfile(p): +def isfile(p: str) -> bool: return os.path.isfile(fsenc(p)) -def isdir(p): +def isdir(p: str) -> bool: return os.path.isdir(fsenc(p)) -def islink(p): +def islink(p: str) -> bool: return os.path.islink(fsenc(p)) -def lexists(p): +def lexists(p: str) -> bool: return os.path.lexists(fsenc(p)) -def realpath(p): +def realpath(p: str) -> str: return fsdec(os.path.realpath(fsenc(p))) diff --git a/copyparty/broker_mp.py b/copyparty/broker_mp.py index df73658d..c7bfed70 100644 --- a/copyparty/broker_mp.py +++ b/copyparty/broker_mp.py @@ -1,37 +1,56 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import time import threading +import time -from .broker_util import try_exec +import queue + +from .__init__ import TYPE_CHECKING from .broker_mpw import MpWorker +from .broker_util import try_exec from .util import mp +if TYPE_CHECKING: + from .svchub import SvcHub + +try: + from typing import Any +except: + pass + + +class MProcess(mp.Process): + def __init__( + self, + q_pend: queue.Queue[tuple[int, str, list[Any]]], + q_yield: queue.Queue[tuple[int, str, list[Any]]], + target: Any, + args: Any, + ) -> None: + super(MProcess, self).__init__(target=target, args=args) + self.q_pend = q_pend + self.q_yield = q_yield + class BrokerMp(object): """external api; manages MpWorkers""" - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub self.log = hub.log self.args = hub.args self.procs = [] - self.retpend = {} - self.retpend_mutex = threading.Lock() self.mutex = threading.Lock() self.num_workers = self.args.j or mp.cpu_count() self.log("broker", "booting {} subprocesses".format(self.num_workers)) for n in range(1, self.num_workers + 1): - q_pend = mp.Queue(1) - q_yield = mp.Queue(64) + q_pend: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(1) + q_yield: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(64) - proc = mp.Process(target=MpWorker, args=(q_pend, q_yield, self.args, n)) - proc.q_pend = q_pend - proc.q_yield = q_yield - proc.clients = {} + proc = MProcess(q_pend, q_yield, MpWorker, (q_pend, q_yield, self.args, n)) thr = threading.Thread( target=self.collector, args=(proc,), name="mp-sink-{}".format(n) @@ -42,11 +61,11 @@ class BrokerMp(object): self.procs.append(proc) proc.start() - def shutdown(self): + def shutdown(self) -> None: self.log("broker", "shutting down") for n, proc in enumerate(self.procs): thr = threading.Thread( - target=proc.q_pend.put([0, "shutdown", []]), + target=proc.q_pend.put((0, "shutdown", [])), name="mp-shutdown-{}-{}".format(n, len(self.procs)), ) thr.start() @@ -62,12 +81,12 @@ class BrokerMp(object): procs.pop() - def reload(self): + def reload(self) -> None: self.log("broker", "reloading") for _, proc in enumerate(self.procs): - proc.q_pend.put([0, "reload", []]) + proc.q_pend.put((0, "reload", [])) - def collector(self, proc): + def collector(self, proc: MProcess) -> None: """receive message from hub in other process""" while True: msg = proc.q_yield.get() @@ -78,10 +97,7 @@ class BrokerMp(object): elif dest == "retq": # response from previous ipc call - with self.retpend_mutex: - retq = self.retpend.pop(retq_id) - - retq.put(args) + raise Exception("invalid broker_mp usage") else: # new ipc invoking managed service in hub @@ -93,9 +109,9 @@ class BrokerMp(object): rv = try_exec(retq_id, obj, *args) if retq_id: - proc.q_pend.put([retq_id, "retq", rv]) + proc.q_pend.put((retq_id, "retq", rv)) - def put(self, want_retval, dest, *args): + def say(self, dest: str, *args: Any) -> None: """ send message to non-hub component in other process, returns a Queue object which eventually contains the response if want_retval @@ -103,7 +119,7 @@ class BrokerMp(object): """ if dest == "listen": for p in self.procs: - p.q_pend.put([0, dest, [args[0], len(self.procs)]]) + p.q_pend.put((0, dest, [args[0], len(self.procs)])) elif dest == "cb_httpsrv_up": self.hub.cb_httpsrv_up() diff --git a/copyparty/broker_mpw.py b/copyparty/broker_mpw.py index c4a1054c..4dbec5b1 100644 --- a/copyparty/broker_mpw.py +++ b/copyparty/broker_mpw.py @@ -1,20 +1,38 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import sys +import argparse import signal +import sys import threading -from .broker_util import ExceptionalQueue +import queue + +from .authsrv import AuthSrv +from .broker_util import BrokerCli, ExceptionalQueue from .httpsrv import HttpSrv from .util import FAKE_MP -from .authsrv import AuthSrv + +try: + from types import FrameType + + from typing import Any, Optional, Union +except: + pass -class MpWorker(object): +class MpWorker(BrokerCli): """one single mp instance""" - def __init__(self, q_pend, q_yield, args, n): + def __init__( + self, + q_pend: queue.Queue[tuple[int, str, list[Any]]], + q_yield: queue.Queue[tuple[int, str, list[Any]]], + args: argparse.Namespace, + n: int, + ) -> None: + super(MpWorker, self).__init__() + self.q_pend = q_pend self.q_yield = q_yield self.args = args @@ -22,7 +40,7 @@ class MpWorker(object): self.log = self._log_disabled if args.q and not args.lo else self._log_enabled - self.retpend = {} + self.retpend: dict[int, Any] = {} self.retpend_mutex = threading.Lock() self.mutex = threading.Lock() @@ -45,20 +63,20 @@ class MpWorker(object): thr.start() thr.join() - def signal_handler(self, sig, frame): + def signal_handler(self, sig: Optional[int], frame: Optional[FrameType]) -> None: # print('k') pass - def _log_enabled(self, src, msg, c=0): - self.q_yield.put([0, "log", [src, msg, c]]) + def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: + self.q_yield.put((0, "log", [src, msg, c])) - def _log_disabled(self, src, msg, c=0): + def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: pass - def logw(self, msg, c=0): + def logw(self, msg: str, c: Union[int, str] = 0) -> None: self.log("mp{}".format(self.n), msg, c) - def main(self): + def main(self) -> None: while True: retq_id, dest, args = self.q_pend.get() @@ -87,15 +105,14 @@ class MpWorker(object): else: raise Exception("what is " + str(dest)) - def put(self, want_retval, dest, *args): - if want_retval: - retq = ExceptionalQueue(1) - retq_id = id(retq) - with self.retpend_mutex: - self.retpend[retq_id] = retq - else: - retq = None - retq_id = 0 + def ask(self, dest: str, *args: Any) -> ExceptionalQueue: + retq = ExceptionalQueue(1) + retq_id = id(retq) + with self.retpend_mutex: + self.retpend[retq_id] = retq - self.q_yield.put([retq_id, dest, args]) + self.q_yield.put((retq_id, dest, list(args))) return retq + + def say(self, dest: str, *args: Any) -> None: + self.q_yield.put((0, dest, list(args))) diff --git a/copyparty/broker_thr.py b/copyparty/broker_thr.py index 1c7a1abf..51c25d41 100644 --- a/copyparty/broker_thr.py +++ b/copyparty/broker_thr.py @@ -3,14 +3,25 @@ from __future__ import print_function, unicode_literals import threading +from .__init__ import TYPE_CHECKING +from .broker_util import BrokerCli, ExceptionalQueue, try_exec from .httpsrv import HttpSrv -from .broker_util import ExceptionalQueue, try_exec + +if TYPE_CHECKING: + from .svchub import SvcHub + +try: + from typing import Any +except: + pass -class BrokerThr(object): +class BrokerThr(BrokerCli): """external api; behaves like BrokerMP but using plain threads""" - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: + super(BrokerThr, self).__init__() + self.hub = hub self.log = hub.log self.args = hub.args @@ -23,29 +34,35 @@ class BrokerThr(object): self.httpsrv = HttpSrv(self, None) self.reload = self.noop - def shutdown(self): + def shutdown(self) -> None: # self.log("broker", "shutting down") self.httpsrv.shutdown() - def noop(self): + def noop(self) -> None: pass - def put(self, want_retval, dest, *args): + def ask(self, dest: str, *args: Any) -> ExceptionalQueue: + + # new ipc invoking managed service in hub + obj = self.hub + for node in dest.split("."): + obj = getattr(obj, node) + + rv = try_exec(True, obj, *args) + + # pretend we're broker_mp + retq = ExceptionalQueue(1) + retq.put(rv) + return retq + + def say(self, dest: str, *args: Any) -> None: if dest == "listen": self.httpsrv.listen(args[0], 1) + return - else: - # new ipc invoking managed service in hub - obj = self.hub - for node in dest.split("."): - obj = getattr(obj, node) + # new ipc invoking managed service in hub + obj = self.hub + for node in dest.split("."): + obj = getattr(obj, node) - # TODO will deadlock if dest performs another ipc - rv = try_exec(want_retval, obj, *args) - if not want_retval: - return - - # pretend we're broker_mp - retq = ExceptionalQueue(1) - retq.put(rv) - return retq + try_exec(False, obj, *args) diff --git a/copyparty/broker_util.py b/copyparty/broker_util.py index 896cba2c..b0d44575 100644 --- a/copyparty/broker_util.py +++ b/copyparty/broker_util.py @@ -1,14 +1,28 @@ # coding: utf-8 from __future__ import print_function, unicode_literals - +import argparse import traceback -from .util import Pebkac, Queue +from queue import Queue + +from .__init__ import TYPE_CHECKING +from .authsrv import AuthSrv +from .util import Pebkac + +try: + from typing import Any, Optional, Union + + from .util import RootLogger +except: + pass + +if TYPE_CHECKING: + from .httpsrv import HttpSrv class ExceptionalQueue(Queue, object): - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: Optional[float] = None) -> Any: rv = super(ExceptionalQueue, self).get(block, timeout) if isinstance(rv, list): @@ -21,7 +35,26 @@ class ExceptionalQueue(Queue, object): return rv -def try_exec(want_retval, func, *args): +class BrokerCli(object): + """ + helps mypy understand httpsrv.broker but still fails a few levels deeper, + for example resolving httpconn.* in httpcli -- see lines tagged #mypy404 + """ + + def __init__(self) -> None: + self.log: RootLogger = None + self.args: argparse.Namespace = None + self.asrv: AuthSrv = None + self.httpsrv: "HttpSrv" = None + + def ask(self, dest: str, *args: Any) -> ExceptionalQueue: + return ExceptionalQueue(1) + + def say(self, dest: str, *args: Any) -> None: + pass + + +def try_exec(want_retval: Union[bool, int], func: Any, *args: list[Any]) -> Any: try: return func(*args) diff --git a/copyparty/ftpd.py b/copyparty/ftpd.py index 828c0db6..a47586de 100644 --- a/copyparty/ftpd.py +++ b/copyparty/ftpd.py @@ -1,23 +1,23 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import sys -import stat -import time -import logging import argparse +import logging +import os +import stat +import sys import threading +import time -from pyftpdlib.authorizers import DummyAuthorizer, AuthenticationFailed +from pyftpdlib.authorizers import AuthenticationFailed, DummyAuthorizer from pyftpdlib.filesystems import AbstractedFS, FilesystemError from pyftpdlib.handlers import FTPHandler -from pyftpdlib.servers import FTPServer from pyftpdlib.log import config_logging +from pyftpdlib.servers import FTPServer -from .__init__ import E, PY2 -from .util import Pebkac, fsenc, exclude_dotfiles +from .__init__ import PY2, TYPE_CHECKING, E from .bos import bos +from .util import Pebkac, exclude_dotfiles, fsenc try: from pyftpdlib.ioloop import IOLoop @@ -28,58 +28,63 @@ except ImportError: from pyftpdlib.ioloop import IOLoop -try: - from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .svchub import SvcHub - if TYPE_CHECKING: - from .svchub import SvcHub -except ImportError: +try: + import typing + from typing import Any, Optional +except: pass class FtpAuth(DummyAuthorizer): - def __init__(self): + def __init__(self, hub: "SvcHub") -> None: super(FtpAuth, self).__init__() - self.hub = None # type: SvcHub + self.hub = hub - def validate_authentication(self, username, password, handler): + def validate_authentication( + self, username: str, password: str, handler: Any + ) -> None: asrv = self.hub.asrv if username == "anonymous": password = "" uname = "*" if password: - uname = asrv.iacct.get(password, None) + uname = asrv.iacct.get(password, "") handler.username = uname if password and not uname: raise AuthenticationFailed("Authentication failed.") - def get_home_dir(self, username): + def get_home_dir(self, username: str) -> str: return "/" - def has_user(self, username): + def has_user(self, username: str) -> bool: asrv = self.hub.asrv return username in asrv.acct - def has_perm(self, username, perm, path=None): + def has_perm(self, username: str, perm: int, path: Optional[str] = None) -> bool: return True # handled at filesystem layer - def get_perms(self, username): + def get_perms(self, username: str) -> str: return "elradfmwMT" - def get_msg_login(self, username): + def get_msg_login(self, username: str) -> str: return "sup {}".format(username) - def get_msg_quit(self, username): + def get_msg_quit(self, username: str) -> str: return "cya" class FtpFs(AbstractedFS): - def __init__(self, root, cmd_channel): + def __init__( + self, root: str, cmd_channel: Any + ) -> None: # pylint: disable=super-init-not-called self.h = self.cmd_channel = cmd_channel # type: FTPHandler - self.hub = cmd_channel.hub # type: SvcHub + self.hub: "SvcHub" = cmd_channel.hub self.args = cmd_channel.args self.uname = self.hub.asrv.iacct.get(cmd_channel.password, "*") @@ -90,7 +95,14 @@ class FtpFs(AbstractedFS): self.listdirinfo = self.listdir self.chdir(".") - def v2a(self, vpath, r=False, w=False, m=False, d=False): + def v2a( + self, + vpath: str, + r: bool = False, + w: bool = False, + m: bool = False, + d: bool = False, + ) -> str: try: vpath = vpath.replace("\\", "/").lstrip("/") vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, r, w, m, d) @@ -101,25 +113,32 @@ class FtpFs(AbstractedFS): except Pebkac as ex: raise FilesystemError(str(ex)) - def rv2a(self, vpath, r=False, w=False, m=False, d=False): + def rv2a( + self, + vpath: str, + r: bool = False, + w: bool = False, + m: bool = False, + d: bool = False, + ) -> str: return self.v2a(os.path.join(self.cwd, vpath), r, w, m, d) - def ftp2fs(self, ftppath): + def ftp2fs(self, ftppath: str) -> str: # return self.v2a(ftppath) return ftppath # self.cwd must be vpath - def fs2ftp(self, fspath): + def fs2ftp(self, fspath: str) -> str: # raise NotImplementedError() return fspath - def validpath(self, path): + def validpath(self, path: str) -> bool: if "/.hist/" in path: if "/up2k." in path or path.endswith("/dir.txt"): raise FilesystemError("access to this file is forbidden") return True - def open(self, filename, mode): + def open(self, filename: str, mode: str) -> typing.IO[Any]: r = "r" in mode w = "w" in mode or "a" in mode or "+" in mode @@ -130,24 +149,24 @@ class FtpFs(AbstractedFS): self.validpath(ap) return open(fsenc(ap), mode) - def chdir(self, path): + def chdir(self, path: str) -> None: self.cwd = join(self.cwd, path) x = self.hub.asrv.vfs.can_access(self.cwd.lstrip("/"), self.h.username) self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x - def mkdir(self, path): + def mkdir(self, path: str) -> None: ap = self.rv2a(path, w=True) bos.mkdir(ap) - def listdir(self, path): + def listdir(self, path: str) -> list[str]: vpath = join(self.cwd, path).lstrip("/") try: vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, True, False) - fsroot, vfs_ls, vfs_virt = vfs.ls( + fsroot, vfs_ls1, vfs_virt = vfs.ls( rem, self.uname, not self.args.no_scandir, [[True], [False, True]] ) - vfs_ls = [x[0] for x in vfs_ls] + vfs_ls = [x[0] for x in vfs_ls1] vfs_ls.extend(vfs_virt.keys()) if not self.args.ed: @@ -164,11 +183,11 @@ class FtpFs(AbstractedFS): r = {x.split("/")[0]: 1 for x in self.hub.asrv.vfs.all_vols.keys()} return list(sorted(list(r.keys()))) - def rmdir(self, path): + def rmdir(self, path: str) -> None: ap = self.rv2a(path, d=True) bos.rmdir(ap) - def remove(self, path): + def remove(self, path: str) -> None: if self.args.no_del: raise FilesystemError("the delete feature is disabled in server config") @@ -178,13 +197,13 @@ class FtpFs(AbstractedFS): except Exception as ex: raise FilesystemError(str(ex)) - def rename(self, src, dst): + def rename(self, src: str, dst: str) -> None: if not self.can_move: raise FilesystemError("not allowed for user " + self.h.username) if self.args.no_mv: - m = "the rename/move feature is disabled in server config" - raise FilesystemError(m) + t = "the rename/move feature is disabled in server config" + raise FilesystemError(t) svp = join(self.cwd, src).lstrip("/") dvp = join(self.cwd, dst).lstrip("/") @@ -193,10 +212,10 @@ class FtpFs(AbstractedFS): except Exception as ex: raise FilesystemError(str(ex)) - def chmod(self, path, mode): + def chmod(self, path: str, mode: str) -> None: pass - def stat(self, path): + def stat(self, path: str) -> os.stat_result: try: ap = self.rv2a(path, r=True) return bos.stat(ap) @@ -208,59 +227,59 @@ class FtpFs(AbstractedFS): return st - def utime(self, path, timeval): + def utime(self, path: str, timeval: float) -> None: ap = self.rv2a(path, w=True) return bos.utime(ap, (timeval, timeval)) - def lstat(self, path): + def lstat(self, path: str) -> os.stat_result: ap = self.rv2a(path) return bos.lstat(ap) - def isfile(self, path): + def isfile(self, path: str) -> bool: st = self.stat(path) return stat.S_ISREG(st.st_mode) - def islink(self, path): + def islink(self, path: str) -> bool: ap = self.rv2a(path) return bos.path.islink(ap) - def isdir(self, path): + def isdir(self, path: str) -> bool: try: st = self.stat(path) return stat.S_ISDIR(st.st_mode) except: return True - def getsize(self, path): + def getsize(self, path: str) -> int: ap = self.rv2a(path) return bos.path.getsize(ap) - def getmtime(self, path): + def getmtime(self, path: str) -> float: ap = self.rv2a(path) return bos.path.getmtime(ap) - def realpath(self, path): + def realpath(self, path: str) -> str: return path - def lexists(self, path): + def lexists(self, path: str) -> bool: ap = self.rv2a(path) return bos.path.lexists(ap) - def get_user_by_uid(self, uid): + def get_user_by_uid(self, uid: int) -> str: return "root" - def get_group_by_uid(self, gid): + def get_group_by_uid(self, gid: int) -> str: return "root" class FtpHandler(FTPHandler): abstracted_fs = FtpFs - hub = None # type: SvcHub - args = None # type: argparse.Namespace + hub: "SvcHub" = None + args: argparse.Namespace = None - def __init__(self, conn, server, ioloop=None): - self.hub = FtpHandler.hub # type: SvcHub - self.args = FtpHandler.args # type: argparse.Namespace + def __init__(self, conn: Any, server: Any, ioloop: Any = None) -> None: + self.hub: "SvcHub" = FtpHandler.hub + self.args: argparse.Namespace = FtpHandler.args if PY2: FTPHandler.__init__(self, conn, server, ioloop) @@ -268,9 +287,10 @@ class FtpHandler(FTPHandler): super(FtpHandler, self).__init__(conn, server, ioloop) # abspath->vpath mapping to resolve log_transfer paths - self.vfs_map = {} + self.vfs_map: dict[str, str] = {} - def ftp_STOR(self, file, mode="w"): + def ftp_STOR(self, file: str, mode: str = "w") -> Any: + # Optional[str] vp = join(self.fs.cwd, file).lstrip("/") ap = self.fs.v2a(vp) self.vfs_map[ap] = vp @@ -279,7 +299,16 @@ class FtpHandler(FTPHandler): # print("ftp_STOR: {} {} OK".format(vp, mode)) return ret - def log_transfer(self, cmd, filename, receive, completed, elapsed, bytes): + def log_transfer( + self, + cmd: str, + filename: bytes, + receive: bool, + completed: bool, + elapsed: float, + bytes: int, + ) -> Any: + # None ap = filename.decode("utf-8", "replace") vp = self.vfs_map.pop(ap, None) # print("xfer_end: {} => {}".format(ap, vp)) @@ -312,7 +341,7 @@ except: class Ftpd(object): - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub self.args = hub.args @@ -321,24 +350,23 @@ class Ftpd(object): hs.append([FtpHandler, self.args.ftp]) if self.args.ftps: try: - h = SftpHandler + h1 = SftpHandler except: - m = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n" - print(m.format(sys.executable)) + t = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n" + print(t.format(sys.executable)) sys.exit(1) - h.certfile = os.path.join(E.cfg, "cert.pem") - h.tls_control_required = True - h.tls_data_required = True + h1.certfile = os.path.join(E.cfg, "cert.pem") + h1.tls_control_required = True + h1.tls_data_required = True - hs.append([h, self.args.ftps]) + hs.append([h1, self.args.ftps]) - for h in hs: - h, lp = h - h.hub = hub - h.args = hub.args - h.authorizer = FtpAuth() - h.authorizer.hub = hub + for h_lp in hs: + h2, lp = h_lp + h2.hub = hub + h2.args = hub.args + h2.authorizer = FtpAuth(hub) if self.args.ftp_pr: p1, p2 = [int(x) for x in self.args.ftp_pr.split("-")] @@ -350,10 +378,10 @@ class Ftpd(object): else: p1 += d + 1 - h.passive_ports = list(range(p1, p2 + 1)) + h2.passive_ports = list(range(p1, p2 + 1)) if self.args.ftp_nat: - h.masquerade_address = self.args.ftp_nat + h2.masquerade_address = self.args.ftp_nat if self.args.ftp_dbg: config_logging(level=logging.DEBUG) @@ -363,11 +391,11 @@ class Ftpd(object): for h, lp in hs: FTPServer((ip, int(lp)), h, ioloop) - t = threading.Thread(target=ioloop.loop) - t.daemon = True - t.start() + thr = threading.Thread(target=ioloop.loop) + thr.daemon = True + thr.start() -def join(p1, p2): +def join(p1: str, p2: str) -> str: w = os.path.join(p1, p2.replace("\\", "/")) return os.path.normpath(w).replace("\\", "/") diff --git a/copyparty/httpcli.py b/copyparty/httpcli.py index 5368bf62..cfc2bbf7 100644 --- a/copyparty/httpcli.py +++ b/copyparty/httpcli.py @@ -1,18 +1,23 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import stat -import gzip -import time -import copy -import json +import argparse # typechk import base64 -import string +import calendar +import copy +import gzip +import json +import os +import re import socket +import stat +import string +import threading # typechk +import time from datetime import datetime from operator import itemgetter -import calendar + +import jinja2 # typechk try: import lzma @@ -24,13 +29,62 @@ try: except: pass -from .__init__ import E, PY2, WINDOWS, ANYWIN, unicode -from .util import * # noqa # pylint: disable=unused-wildcard-import +from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS, E, unicode +from .authsrv import VFS # typechk from .bos import bos -from .authsrv import AuthSrv -from .szip import StreamZip from .star import StreamTar +from .sutil import StreamArc # typechk +from .szip import StreamZip +from .util import ( + HTTP_TS_FMT, + HTTPCODE, + META_NOBOTS, + MultipartParser, + Pebkac, + UnrecvEOF, + alltrace, + exclude_dotfiles, + fsenc, + gen_filekey, + gencookie, + get_spd, + guess_mime, + gzip_orig_sz, + hashcopy, + html_bescape, + html_escape, + http_ts, + humansize, + min_ex, + quotep, + read_header, + read_socket, + read_socket_chunked, + read_socket_unbounded, + relchk, + ren_open, + s3enc, + sanitize_fn, + sendfile_kern, + sendfile_py, + undot, + unescape_cookie, + unquote, + unquotep, + vol_san, + vsplit, + yieldfile, +) +try: + from typing import Any, Generator, Match, Optional, Pattern, Type, Union +except: + pass + +if TYPE_CHECKING: + from .httpconn import HttpConn + +_ = (argparse, threading) NO_CACHE = {"Cache-Control": "no-cache"} @@ -40,27 +94,60 @@ class HttpCli(object): Spawned by HttpConn to process one http transaction """ - def __init__(self, conn): + def __init__(self, conn: "HttpConn") -> None: + assert conn.sr + self.t0 = time.time() self.conn = conn - self.mutex = conn.mutex - self.s = conn.s # type: socket - self.sr = conn.sr # type: Unrecv + self.mutex = conn.mutex # mypy404 + self.s = conn.s + self.sr = conn.sr self.ip = conn.addr[0] - self.addr = conn.addr # type: tuple[str, int] - self.args = conn.args - self.asrv = conn.asrv # type: AuthSrv - self.ico = conn.ico - self.thumbcli = conn.thumbcli - self.u2fh = conn.u2fh - self.log_func = conn.log_func - self.log_src = conn.log_src - self.tls = hasattr(self.s, "cipher") + self.addr: tuple[str, int] = conn.addr + self.args = conn.args # mypy404 + self.asrv = conn.asrv # mypy404 + self.ico = conn.ico # mypy404 + self.thumbcli = conn.thumbcli # mypy404 + self.u2fh = conn.u2fh # mypy404 + self.log_func = conn.log_func # mypy404 + self.log_src = conn.log_src # mypy404 + self.tls: bool = hasattr(self.s, "cipher") + + # placeholders; assigned by run() + self.keepalive = False + self.is_https = False + self.headers: dict[str, str] = {} + self.mode = " " + self.req = " " + self.http_ver = " " + self.ua = " " + self.is_rclone = False + self.is_ancient = False + self.dip = " " + self.ouparam: dict[str, str] = {} + self.uparam: dict[str, str] = {} + self.cookies: dict[str, str] = {} + self.vpath = " " + self.uname = " " + self.rvol = [" "] + self.wvol = [" "] + self.mvol = [" "] + self.dvol = [" "] + self.gvol = [" "] + self.do_log = True + self.can_read = False + self.can_write = False + self.can_move = False + self.can_delete = False + self.can_get = False + # post + self.parser: Optional[MultipartParser] = None + # end placeholders self.bufsz = 1024 * 32 - self.hint = None + self.hint = "" self.trailing_slash = True - self.out_headerlist = [] + self.out_headerlist: list[tuple[str, str]] = [] self.out_headers = { "Access-Control-Allow-Origin": "*", "Cache-Control": "no-store; max-age=0", @@ -71,44 +158,44 @@ class HttpCli(object): self.out_headers["X-Robots-Tag"] = "noindex, nofollow" self.html_head = h - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: ptn = self.asrv.re_pwd if ptn and ptn.search(msg): msg = ptn.sub(self.unpwd, msg) self.log_func(self.log_src, msg, c) - def unpwd(self, m): + def unpwd(self, m: Match[str]) -> str: a, b = m.groups() return "=\033[7m {} \033[27m{}".format(self.asrv.iacct[a], b) - def _check_nonfatal(self, ex, post): + def _check_nonfatal(self, ex: Pebkac, post: bool) -> bool: if post: return ex.code < 300 return ex.code < 400 or ex.code in [404, 429] - def _assert_safe_rem(self, rem): + def _assert_safe_rem(self, rem: str) -> None: # sanity check to prevent any disasters if rem.startswith("/") or rem.startswith("../") or "/../" in rem: raise Exception("that was close") - def j2(self, name, **ka): + def j2s(self, name: str, **ka: Any) -> str: tpl = self.conn.hsrv.j2[name] - if ka: - ka["ts"] = self.conn.hsrv.cachebuster() - ka["svcname"] = self.args.doctitle - ka["html_head"] = self.html_head - return tpl.render(**ka) + ka["ts"] = self.conn.hsrv.cachebuster() + ka["svcname"] = self.args.doctitle + ka["html_head"] = self.html_head + return tpl.render(**ka) # type: ignore - return tpl + def j2j(self, name: str) -> jinja2.Template: + return self.conn.hsrv.j2[name] - def run(self): + def run(self) -> bool: """returns true if connection can be reused""" self.keepalive = False self.is_https = False self.headers = {} - self.hint = None + self.hint = "" try: headerlines = read_header(self.sr) if not headerlines: @@ -125,8 +212,8 @@ class HttpCli(object): # normalize incoming headers to lowercase; # outgoing headers however are Correct-Case for header_line in headerlines[1:]: - k, v = header_line.split(":", 1) - self.headers[k.lower()] = v.strip() + k, zs = header_line.split(":", 1) + self.headers[k.lower()] = zs.strip() except: msg = " ]\n#[ ".join(headerlines) raise Pebkac(400, "bad headers:\n#[ " + msg + " ]") @@ -147,23 +234,26 @@ class HttpCli(object): self.is_rclone = self.ua.startswith("rclone/") self.is_ancient = self.ua.startswith("Mozilla/4.") - v = self.headers.get("connection", "").lower() - self.keepalive = not v.startswith("close") and self.http_ver != "HTTP/1.0" - self.is_https = (self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls) + zs = self.headers.get("connection", "").lower() + self.keepalive = not zs.startswith("close") and self.http_ver != "HTTP/1.0" + self.is_https = ( + self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls + ) n = self.args.rproxy if n: - v = self.headers.get("x-forwarded-for") - if v and self.conn.addr[0] in ["127.0.0.1", "::1"]: + zso = self.headers.get("x-forwarded-for") + if zso and self.conn.addr[0] in ["127.0.0.1", "::1"]: if n > 0: n -= 1 - vs = v.split(",") + zsl = zso.split(",") try: - self.ip = vs[n].strip() + self.ip = zsl[n].strip() except: - self.ip = vs[0].strip() - self.log("rproxy={} oob x-fwd {}".format(self.args.rproxy, v), c=3) + self.ip = zsl[0].strip() + t = "rproxy={} oob x-fwd {}" + self.log(t.format(self.args.rproxy, zso), c=3) self.log_src = self.conn.set_rproxy(self.ip) @@ -175,9 +265,9 @@ class HttpCli(object): keys = list(sorted(self.headers.keys())) for k in keys: - v = self.headers.get(k) - if v is not None: - self.log("[H] {}: \033[33m[{}]".format(k, v), 6) + zso = self.headers.get(k) + if zso is not None: + self.log("[H] {}: \033[33m[{}]".format(k, zso), 6) if "&" in self.req and "?" not in self.req: self.hint = "did you mean '?' instead of '&'" @@ -193,20 +283,22 @@ class HttpCli(object): vpath = undot(vpath) for k in arglist.split("&"): if "=" in k: - k, v = k.split("=", 1) - uparam[k.lower()] = v.strip() + k, zs = k.split("=", 1) + uparam[k.lower()] = zs.strip() else: - uparam[k.lower()] = False + uparam[k.lower()] = "" - self.ouparam = {k: v for k, v in uparam.items()} + self.ouparam = {k: zs for k, zs in uparam.items()} - cookies = self.headers.get("cookie") or {} - if cookies: - cookies = [x.split("=", 1) for x in cookies.split(";") if "=" in x] - cookies = {k.strip(): unescape_cookie(v) for k, v in cookies} + zso = self.headers.get("cookie") + if zso: + zsll = [x.split("=", 1) for x in zso.split(";") if "=" in x] + cookies = {k.strip(): unescape_cookie(zs) for k, zs in zsll} for kc, ku in [["cppwd", "pw"], ["b", "b"]]: if kc in cookies and ku not in uparam: uparam[ku] = cookies[kc] + else: + cookies = {} if len(uparam) > 10 or len(cookies) > 50: raise Pebkac(400, "u wot m8") @@ -223,22 +315,22 @@ class HttpCli(object): self.log("invalid relpath [{}]".format(self.vpath)) return self.tx_404() and self.keepalive - pwd = None - ba = self.headers.get("authorization") - if ba: + pwd = "" + zso = self.headers.get("authorization") + if zso: try: - ba = ba.split(" ")[1].encode("ascii") - ba = base64.b64decode(ba).decode("utf-8") + zb = zso.split(" ")[1].encode("ascii") + zs = base64.b64decode(zb).decode("utf-8") # try "pwd", "x:pwd", "pwd:x" - for ba in [ba] + ba.split(":", 1)[::-1]: - if self.asrv.iacct.get(ba): - pwd = ba + for zs in [zs] + zs.split(":", 1)[::-1]: + if self.asrv.iacct.get(zs): + pwd = zs break except: pass pwd = uparam.get("pw") or pwd - self.uname = self.asrv.iacct.get(pwd, "*") + self.uname = self.asrv.iacct.get(pwd) or "*" self.rvol = self.asrv.vfs.aread[self.uname] self.wvol = self.asrv.vfs.awrite[self.uname] self.mvol = self.asrv.vfs.amove[self.uname] @@ -249,12 +341,13 @@ class HttpCli(object): self.out_headerlist.append(("Set-Cookie", self.get_pwd_cookie(pwd)[0])) if self.is_rclone: - uparam["raw"] = False - uparam["dots"] = False - uparam["b"] = False - cookies["b"] = False + uparam["raw"] = "" + uparam["dots"] = "" + uparam["b"] = "" + cookies["b"] = "" - self.do_log = not self.conn.lf_url or not self.conn.lf_url.search(self.req) + ptn: Optional[Pattern[str]] = self.conn.lf_url # mypy404 + self.do_log = not ptn or not ptn.search(self.req) x = self.asrv.vfs.can_access(self.vpath, self.uname) self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x @@ -272,9 +365,10 @@ class HttpCli(object): raise Pebkac(400, 'invalid HTTP mode "{0}"'.format(self.mode)) except Exception as ex: - pex = ex if not hasattr(ex, "code"): pex = Pebkac(500) + else: + pex = ex # type: ignore try: post = self.mode in ["POST", "PUT"] or "content-length" in self.headers @@ -294,7 +388,7 @@ class HttpCli(object): except Pebkac: return False - def permit_caching(self): + def permit_caching(self) -> None: cache = self.uparam.get("cache") if cache is None: self.out_headers.update(NO_CACHE) @@ -303,11 +397,17 @@ class HttpCli(object): n = "604800" if cache == "i" else cache or "69" self.out_headers["Cache-Control"] = "max-age=" + n - def k304(self): + def k304(self) -> bool: k304 = self.cookies.get("k304") return k304 == "y" or ("; Trident/" in self.ua and not k304) - def send_headers(self, length, status=200, mime=None, headers=None): + def send_headers( + self, + length: Optional[int], + status: int = 200, + mime: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + ) -> None: response = ["{} {} {}".format(self.http_ver, status, HTTPCODE[status])] if length is not None: @@ -327,10 +427,11 @@ class HttpCli(object): if not mime: mime = self.out_headers.get("Content-Type", "text/html; charset=utf-8") + assert mime self.out_headers["Content-Type"] = mime - for k, v in list(self.out_headers.items()) + self.out_headerlist: - response.append("{}: {}".format(k, v)) + for k, zs in list(self.out_headers.items()) + self.out_headerlist: + response.append("{}: {}".format(k, zs)) try: # best practice to separate headers and body into different packets @@ -338,11 +439,19 @@ class HttpCli(object): except: raise Pebkac(400, "client d/c while replying headers") - def reply(self, body, status=200, mime=None, headers=None, volsan=False): + def reply( + self, + body: bytes, + status: int = 200, + mime: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + volsan: bool = False, + ) -> bytes: # TODO something to reply with user-supplied values safely if volsan: - body = vol_san(self.asrv.vfs.all_vols.values(), body) + vols = list(self.asrv.vfs.all_vols.values()) + body = vol_san(vols, body) self.send_headers(len(body), status, mime, headers) @@ -354,17 +463,19 @@ class HttpCli(object): return body - def loud_reply(self, body, *args, **kwargs): + def loud_reply(self, body: str, *args: Any, **kwargs: Any) -> None: if not kwargs.get("mime"): kwargs["mime"] = "text/plain; charset=utf-8" self.log(body.rstrip()) self.reply(body.encode("utf-8") + b"\r\n", *list(args), **kwargs) - def urlq(self, add, rm): + def urlq(self, add: dict[str, str], rm: list[str]) -> str: """ generates url query based on uparam (b, pw, all others) removing anything in rm, adding pairs in add + + also list faster than set until ~20 items """ if self.is_rclone: @@ -372,28 +483,28 @@ class HttpCli(object): cmap = {"pw": "cppwd"} kv = { - k: v - for k, v in self.uparam.items() - if k not in rm and self.cookies.get(cmap.get(k, k)) != v + k: zs + for k, zs in self.uparam.items() + if k not in rm and self.cookies.get(cmap.get(k, k)) != zs } kv.update(add) if not kv: return "" - r = ["{}={}".format(k, quotep(v)) if v else k for k, v in kv.items()] + r = ["{}={}".format(k, quotep(zs)) if zs else k for k, zs in kv.items()] return "?" + "&".join(r) def redirect( self, - vpath, - suf="", - msg="aight", - flavor="go to", - click=True, - status=200, - use302=False, - ): - html = self.j2( + vpath: str, + suf: str = "", + msg: str = "aight", + flavor: str = "go to", + click: bool = True, + status: int = 200, + use302: bool = False, + ) -> bool: + html = self.j2s( "msg", h2='{} /{}'.format( quotep(vpath) + suf, flavor, html_escape(vpath, crlf=True) + suf @@ -407,7 +518,9 @@ class HttpCli(object): else: self.reply(html, status=status) - def handle_get(self): + return True + + def handle_get(self) -> bool: if self.do_log: logmsg = "{:4} {}".format(self.mode, self.req) @@ -434,13 +547,13 @@ class HttpCli(object): self.log("inaccessible: [{}]".format(self.vpath)) return self.tx_404(True) - self.uparam["h"] = False + self.uparam["h"] = "" if "tree" in self.uparam: return self.tx_tree() if "delete" in self.uparam: - return self.handle_rm() + return self.handle_rm([]) if "move" in self.uparam: return self.handle_mv() @@ -486,7 +599,7 @@ class HttpCli(object): return self.tx_browser() - def handle_options(self): + def handle_options(self) -> bool: if self.do_log: self.log("OPTIONS " + self.req) @@ -501,7 +614,7 @@ class HttpCli(object): ) return True - def handle_put(self): + def handle_put(self) -> bool: self.log("PUT " + self.req) if self.headers.get("expect", "").lower() == "100-continue": @@ -512,7 +625,7 @@ class HttpCli(object): return self.handle_stash() - def handle_post(self): + def handle_post(self) -> bool: self.log("POST " + self.req) if self.headers.get("expect", "").lower() == "100-continue": @@ -549,16 +662,16 @@ class HttpCli(object): reader, _ = self.get_body_reader() for buf in reader: orig = buf.decode("utf-8", "replace") - m = "urlform_raw {} @ {}\n {}\n" - self.log(m.format(len(orig), self.vpath, orig)) + t = "urlform_raw {} @ {}\n {}\n" + self.log(t.format(len(orig), self.vpath, orig)) try: - plain = unquote(buf.replace(b"+", b" ")) - plain = plain.decode("utf-8", "replace") + zb = unquote(buf.replace(b"+", b" ")) + plain = zb.decode("utf-8", "replace") if buf.startswith(b"msg="): plain = plain[4:] - m = "urlform_dec {} @ {}\n {}\n" - self.log(m.format(len(plain), self.vpath, plain)) + t = "urlform_dec {} @ {}\n {}\n" + self.log(t.format(len(plain), self.vpath, plain)) except Exception as ex: self.log(repr(ex)) @@ -569,7 +682,7 @@ class HttpCli(object): raise Pebkac(405, "don't know how to handle POST({})".format(ctype)) - def get_body_reader(self): + def get_body_reader(self) -> tuple[Generator[bytes, None, None], int]: if "chunked" in self.headers.get("transfer-encoding", "").lower(): return read_socket_chunked(self.sr), -1 @@ -580,7 +693,8 @@ class HttpCli(object): else: return read_socket(self.sr, remains), remains - def dump_to_file(self): + def dump_to_file(self) -> tuple[int, str, str, int, str, str]: + # post_sz, sha_hex, sha_b64, remains, path, url reader, remains = self.get_body_reader() vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True) lim = vfs.get_dbv(rem)[0].lim @@ -595,7 +709,7 @@ class HttpCli(object): bos.makedirs(fdir) - open_ka = {"fun": open} + open_ka: dict[str, Any] = {"fun": open} open_a = ["wb", 512 * 1024] # user-request || config-force @@ -607,7 +721,7 @@ class HttpCli(object): ): fb = {"gz": 9, "xz": 0} # default/fallback level lv = {} # selected level - alg = None # selected algo (gz=preferred) + alg = "" # selected algo (gz=preferred) # user-prefs first if "gz" in self.uparam or "pk" in self.uparam: # def.pk @@ -615,8 +729,8 @@ class HttpCli(object): if "xz" in self.uparam: alg = "xz" if alg: - v = self.uparam.get(alg) - lv[alg] = fb[alg] if v is None else int(v) + zso = self.uparam.get(alg) + lv[alg] = fb[alg] if zso is None else int(zso) if alg not in vfs.flags: alg = "gz" if "gz" in vfs.flags else "xz" @@ -633,7 +747,7 @@ class HttpCli(object): except: pass - lv[alg] = lv.get(alg) or fb.get(alg) + lv[alg] = lv.get(alg) or fb.get(alg) or 0 self.log("compressing with {} level {}".format(alg, lv.get(alg))) if alg == "gz": @@ -656,8 +770,8 @@ class HttpCli(object): if not fn: fn = "put" + suffix - with ren_open(fn, *open_a, **params) as f: - f, fn = f["orz"] + with ren_open(fn, *open_a, **params) as zfw: + f, fn = zfw["orz"] path = os.path.join(fdir, fn) post_sz, sha_hex, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp) @@ -674,8 +788,7 @@ class HttpCli(object): return post_sz, sha_hex, sha_b64, remains, path, "" vfs, rem = vfs.get_dbv(rem) - self.conn.hsrv.broker.put( - False, + self.conn.hsrv.broker.say( "up2k.hash_file", vfs.realpath, vfs.flags, @@ -705,16 +818,16 @@ class HttpCli(object): return post_sz, sha_hex, sha_b64, remains, path, url - def handle_stash(self): + def handle_stash(self) -> bool: post_sz, sha_hex, sha_b64, remains, path, url = self.dump_to_file() spd = self._spd(post_sz) - m = "{} wrote {}/{} bytes to {} # {}" - self.log(m.format(spd, post_sz, remains, path, sha_b64[:28])) # 21 - m = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url) - self.reply(m.encode("utf-8")) + t = "{} wrote {}/{} bytes to {} # {}" + self.log(t.format(spd, post_sz, remains, path, sha_b64[:28])) # 21 + t = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url) + self.reply(t.encode("utf-8")) return True - def _spd(self, nbytes, add=True): + def _spd(self, nbytes: int, add: bool = True) -> str: if add: self.conn.nbyte += nbytes @@ -722,7 +835,7 @@ class HttpCli(object): spd2 = get_spd(self.conn.nbyte, self.conn.t0) return "{} {} n{}".format(spd1, spd2, self.conn.nreq) - def handle_post_multipart(self): + def handle_post_multipart(self) -> bool: self.parser = MultipartParser(self.log, self.sr, self.headers) self.parser.parse() @@ -749,7 +862,8 @@ class HttpCli(object): raise Pebkac(422, 'invalid action "{}"'.format(act)) - def handle_zip_post(self): + def handle_zip_post(self) -> bool: + assert self.parser for k in ["zip", "tar"]: v = self.uparam.get(k) if v is not None: @@ -759,17 +873,17 @@ class HttpCli(object): raise Pebkac(422, "need zip or tar keyword") vn, rem = self.asrv.vfs.get(self.vpath, self.uname, True, False) - items = self.parser.require("files", 1024 * 1024) - if not items: + zs = self.parser.require("files", 1024 * 1024) + if not zs: raise Pebkac(422, "need files list") - items = items.replace("\r", "").split("\n") + items = zs.replace("\r", "").split("\n") items = [unquotep(x) for x in items if items] self.parser.drop() return self.tx_zip(k, v, vn, rem, items, self.args.ed) - def handle_post_json(self): + def handle_post_json(self) -> bool: try: remains = int(self.headers["content-length"]) except: @@ -836,14 +950,14 @@ class HttpCli(object): except: raise Pebkac(500, min_ex()) - x = self.conn.hsrv.broker.put(True, "up2k.handle_json", body) + x = self.conn.hsrv.broker.ask("up2k.handle_json", body) ret = x.get() ret = json.dumps(ret) self.log(ret) self.reply(ret.encode("utf-8"), mime="application/json") return True - def handle_search(self, body): + def handle_search(self, body: dict[str, Any]) -> bool: idx = self.conn.get_u2idx() if not hasattr(idx, "p_end"): raise Pebkac(500, "sqlite3 is not available on the server; cannot search") @@ -857,15 +971,15 @@ class HttpCli(object): continue seen[vfs] = True - vols.append([vfs.vpath, vfs.realpath, vfs.flags]) + vols.append((vfs.vpath, vfs.realpath, vfs.flags)) t0 = time.time() if idx.p_end: penalty = 0.7 t_idle = t0 - idx.p_end if idx.p_dur > 0.7 and t_idle < penalty: - m = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}" - raise Pebkac(429, m.format(penalty, idx.p_dur, t_idle)) + t = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}" + raise Pebkac(429, t.format(penalty, idx.p_dur, t_idle)) if "srch" in body: # search by up2k hashlist @@ -873,8 +987,8 @@ class HttpCli(object): vbody["hash"] = len(vbody["hash"]) self.log("qj: " + repr(vbody)) hits = idx.fsearch(vols, body) - msg = repr(hits) - taglist = {} + msg: Any = repr(hits) + taglist: list[str] = [] else: # search by query params q = body["q"] @@ -900,7 +1014,7 @@ class HttpCli(object): self.reply(r, mime="application/json") return True - def handle_post_binary(self): + def handle_post_binary(self) -> bool: try: remains = int(self.headers["content-length"]) except: @@ -915,7 +1029,7 @@ class HttpCli(object): vfs, _ = self.asrv.vfs.get(self.vpath, self.uname, False, True) ptop = (vfs.dbv or vfs).realpath - x = self.conn.hsrv.broker.put(True, "up2k.handle_chunk", ptop, wark, chash) + x = self.conn.hsrv.broker.ask("up2k.handle_chunk", ptop, wark, chash) response = x.get() chunksize, cstart, path, lastmod = response @@ -946,8 +1060,8 @@ class HttpCli(object): post_sz, _, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp) if sha_b64 != chash: - m = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}" - raise Pebkac(400, m.format(post_sz, chash, sha_b64)) + t = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}" + raise Pebkac(400, t.format(post_sz, chash, sha_b64)) if len(cstart) > 1 and path != os.devnull: self.log( @@ -974,15 +1088,15 @@ class HttpCli(object): with self.mutex: self.u2fh.put(path, f) finally: - x = self.conn.hsrv.broker.put(True, "up2k.release_chunk", ptop, wark, chash) + x = self.conn.hsrv.broker.ask("up2k.release_chunk", ptop, wark, chash) x.get() # block client until released - x = self.conn.hsrv.broker.put(True, "up2k.confirm_chunk", ptop, wark, chash) - x = x.get() + x = self.conn.hsrv.broker.ask("up2k.confirm_chunk", ptop, wark, chash) + ztis = x.get() try: - num_left, fin_path = x + num_left, fin_path = ztis except: - self.loud_reply(x, status=500) + self.loud_reply(ztis, status=500) return False if not num_left and fpool: @@ -991,7 +1105,7 @@ class HttpCli(object): # windows cant rename open files if ANYWIN and path != fin_path and not self.args.nw: - self.conn.hsrv.broker.put(True, "up2k.finish_upload", ptop, wark).get() + self.conn.hsrv.broker.ask("up2k.finish_upload", ptop, wark).get() if not ANYWIN and not num_left: times = (int(time.time()), int(lastmod)) @@ -1006,7 +1120,8 @@ class HttpCli(object): self.reply(b"thank") return True - def handle_login(self): + def handle_login(self) -> bool: + assert self.parser pwd = self.parser.require("cppwd", 64) self.parser.drop() @@ -1021,11 +1136,11 @@ class HttpCli(object): dst += quotep(self.vpath) ck, msg = self.get_pwd_cookie(pwd) - html = self.j2("msg", h1=msg, h2='ack', redir=dst) + html = self.j2s("msg", h1=msg, h2='ack', redir=dst) self.reply(html.encode("utf-8"), headers={"Set-Cookie": ck}) return True - def get_pwd_cookie(self, pwd): + def get_pwd_cookie(self, pwd: str) -> tuple[str, str]: if pwd in self.asrv.iacct: msg = "login ok" dur = int(60 * 60 * self.args.logout) @@ -1038,9 +1153,10 @@ class HttpCli(object): if self.is_ancient: r = r.rsplit(" ", 1)[0] - return [r, msg] + return r, msg - def handle_mkdir(self): + def handle_mkdir(self) -> bool: + assert self.parser new_dir = self.parser.require("name", 512) self.parser.drop() @@ -1075,7 +1191,8 @@ class HttpCli(object): self.redirect(vpath) return True - def handle_new_md(self): + def handle_new_md(self) -> bool: + assert self.parser new_file = self.parser.require("name", 512) self.parser.drop() @@ -1102,7 +1219,8 @@ class HttpCli(object): self.redirect(vpath, "?edit") return True - def handle_plain_upload(self): + def handle_plain_upload(self) -> bool: + assert self.parser nullwrite = self.args.nw vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True) self._assert_safe_rem(rem) @@ -1116,17 +1234,21 @@ class HttpCli(object): if not nullwrite: bos.makedirs(fdir_base) - files = [] + files: list[tuple[int, str, str, str, str, str]] = [] + # sz, sha_hex, sha_b64, p_file, fname, abspath errmsg = "" t0 = time.time() try: + assert self.parser.gen for nfile, (p_field, p_file, p_data) in enumerate(self.parser.gen): if not p_file: self.log("discarding incoming file without filename") # fallthrough fdir = fdir_base - fname = sanitize_fn(p_file, "", [".prologue.html", ".epilogue.html"]) + fname = sanitize_fn( + p_file or "", "", [".prologue.html", ".epilogue.html"] + ) if p_file and not nullwrite: if not bos.path.isdir(fdir): raise Pebkac(404, "that folder does not exist") @@ -1143,8 +1265,8 @@ class HttpCli(object): lim.chk_nup(self.ip) try: - with ren_open(fname, "wb", 512 * 1024, **open_args) as f: - f, fname = f["orz"] + with ren_open(fname, "wb", 512 * 1024, **open_args) as zfw: + f, fname = zfw["orz"] abspath = os.path.join(fdir, fname) self.log("writing to {}".format(abspath)) sz, sha_hex, sha_b64 = hashcopy(p_data, f, self.args.s_wr_slp) @@ -1158,12 +1280,14 @@ class HttpCli(object): lim.chk_sz(sz) except: bos.unlink(abspath) + fname = os.devnull raise - files.append([sz, sha_hex, sha_b64, p_file, fname, abspath]) + files.append( + (sz, sha_hex, sha_b64, p_file or "(discarded)", fname, abspath) + ) dbv, vrem = vfs.get_dbv(rem) - self.conn.hsrv.broker.put( - False, + self.conn.hsrv.broker.say( "up2k.hash_file", dbv.realpath, dbv.flags, @@ -1192,7 +1316,7 @@ class HttpCli(object): except Pebkac as ex: errmsg = vol_san( - self.asrv.vfs.all_vols.values(), unicode(ex).encode("utf-8") + list(self.asrv.vfs.all_vols.values()), unicode(ex).encode("utf-8") ).decode("utf-8") td = max(0.1, time.time() - t0) @@ -1205,7 +1329,12 @@ class HttpCli(object): status = "ERROR" msg = "{} // {} bytes // {:.3f} MiB/s\n".format(status, sz_total, spd) - jmsg = {"status": status, "sz": sz_total, "mbps": round(spd, 3), "files": []} + jmsg: dict[str, Any] = { + "status": status, + "sz": sz_total, + "mbps": round(spd, 3), + "files": [], + } if errmsg: msg += errmsg + "\n" @@ -1260,17 +1389,17 @@ class HttpCli(object): ft = "{}\n{}\n{}\n".format(ft, msg.rstrip(), errmsg) f.write(ft.encode("utf-8")) - status = 400 if errmsg else 200 + sc = 400 if errmsg else 200 if "j" in self.uparam: jtxt = json.dumps(jmsg, indent=2, sort_keys=True).encode("utf-8", "replace") - self.reply(jtxt, mime="application/json", status=status) + self.reply(jtxt, mime="application/json", status=sc) else: self.redirect( self.vpath, msg=msg, flavor="return to", click=False, - status=status, + status=sc, ) if errmsg: @@ -1279,7 +1408,8 @@ class HttpCli(object): self.parser.drop() return True - def handle_text_upload(self): + def handle_text_upload(self) -> bool: + assert self.parser try: cli_lastmod3 = int(self.parser.require("lastmod", 16)) except: @@ -1314,7 +1444,8 @@ class HttpCli(object): self.reply(response.encode("utf-8")) return True - srv_lastmod = srv_lastmod3 = -1 + srv_lastmod = -1.0 + srv_lastmod3 = -1 try: st = bos.stat(fp) srv_lastmod = st.st_mtime @@ -1329,7 +1460,7 @@ class HttpCli(object): if not same_lastmod: # some filesystems/transports limit precision to 1sec, hopefully floored same_lastmod = ( - srv_lastmod == int(srv_lastmod) + srv_lastmod == int(cli_lastmod3 / 1000) and cli_lastmod3 > srv_lastmod3 and cli_lastmod3 - srv_lastmod3 < 1000 ) @@ -1360,6 +1491,7 @@ class HttpCli(object): pass bos.rename(fp, os.path.join(mdir, ".hist", mfile2)) + assert self.parser.gen p_field, _, p_data = next(self.parser.gen) if p_field != "body": raise Pebkac(400, "expected body, got {}".format(p_field)) @@ -1388,7 +1520,7 @@ class HttpCli(object): self.reply(response.encode("utf-8")) return True - def _chk_lastmod(self, file_ts): + def _chk_lastmod(self, file_ts: int) -> tuple[str, bool]: file_lastmod = http_ts(file_ts) cli_lastmod = self.headers.get("if-modified-since") if cli_lastmod: @@ -1408,7 +1540,7 @@ class HttpCli(object): return file_lastmod, True - def tx_file(self, req_path): + def tx_file(self, req_path: str) -> bool: status = 200 logmsg = "{:4} {} ".format("", self.req) logtail = "" @@ -1417,7 +1549,7 @@ class HttpCli(object): # if request is for foo.js, check if we have foo.js.{gz,br} file_ts = 0 - editions = {} + editions: dict[str, tuple[str, int]] = {} for ext in ["", ".gz", ".br"]: try: fs_path = req_path + ext @@ -1425,8 +1557,8 @@ class HttpCli(object): if stat.S_ISDIR(st.st_mode): continue - file_ts = max(file_ts, st.st_mtime) - editions[ext or "plain"] = [fs_path, st.st_size] + file_ts = max(file_ts, int(st.st_mtime)) + editions[ext or "plain"] = (fs_path, st.st_size) except: pass if not self.vpath.startswith(".cpr/"): @@ -1526,8 +1658,8 @@ class HttpCli(object): use_sendfile = False if decompress: - open_func = gzip.open - open_args = [fsenc(fs_path), "rb"] + open_func: Any = gzip.open + open_args: list[Any] = [fsenc(fs_path), "rb"] # Content-Length := original file size upper = gzip_orig_sz(fs_path) else: @@ -1551,7 +1683,7 @@ class HttpCli(object): if "txt" in self.uparam: mime = "text/plain; charset={}".format(self.uparam["txt"] or "utf-8") elif "mime" in self.uparam: - mime = self.uparam.get("mime") + mime = str(self.uparam.get("mime")) else: mime = guess_mime(req_path) @@ -1583,19 +1715,18 @@ class HttpCli(object): return ret - def tx_zip(self, fmt, uarg, vn, rem, items, dots): + def tx_zip( + self, fmt: str, uarg: str, vn: VFS, rem: str, items: list[str], dots: bool + ) -> bool: if self.args.no_zip: raise Pebkac(400, "not enabled") logmsg = "{:4} {} ".format("", self.req) self.keepalive = False - if not uarg: - uarg = "" - if fmt == "tar": mime = "application/x-tar" - packer = StreamTar + packer: Type[StreamArc] = StreamTar else: mime = "application/zip" packer = StreamZip @@ -1609,24 +1740,25 @@ class HttpCli(object): safe = (string.ascii_letters + string.digits).replace("%", "") afn = "".join([x if x in safe.replace('"', "") else "_" for x in fn]) bascii = unicode(safe).encode("utf-8") - ufn = fn.encode("utf-8", "xmlcharrefreplace") - if PY2: - ufn = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in ufn] - else: - ufn = [ + zb = fn.encode("utf-8", "xmlcharrefreplace") + if not PY2: + zbl = [ chr(x).encode("utf-8") if x in bascii else "%{:02x}".format(x).encode("ascii") - for x in ufn + for x in zb ] - ufn = b"".join(ufn).decode("ascii") + else: + zbl = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in zb] + + ufn = b"".join(zbl).decode("ascii") cdis = "attachment; filename=\"{}.{}\"; filename*=UTF-8''{}.{}" cdis = cdis.format(afn, fmt, ufn, fmt) self.log(cdis) self.send_headers(None, mime=mime, headers={"Content-Disposition": cdis}) - fgen = vn.zipgen(rem, items, self.uname, dots, not self.args.no_scandir) + fgen = vn.zipgen(rem, set(items), self.uname, dots, not self.args.no_scandir) # for f in fgen: print(repr({k: f[k] for k in ["vp", "ap"]})) bgen = packer(self.log, fgen, utf8="utf" in uarg, pre_crc="crc" in uarg) bsent = 0 @@ -1645,7 +1777,7 @@ class HttpCli(object): self.log("{}, {}".format(logmsg, spd)) return True - def tx_ico(self, ext, exact=False): + def tx_ico(self, ext: str, exact: bool = False) -> bool: self.permit_caching() if ext.endswith("/"): ext = "folder" @@ -1674,7 +1806,7 @@ class HttpCli(object): self.reply(ico, mime=mime, headers={"Last-Modified": lm}) return True - def tx_md(self, fs_path): + def tx_md(self, fs_path: str) -> bool: logmsg = "{:4} {} ".format("", self.req) if not self.can_write: @@ -1683,7 +1815,7 @@ class HttpCli(object): tpl = "mde" if "edit2" in self.uparam else "md" html_path = os.path.join(E.mod, "web", "{}.html".format(tpl)) - template = self.j2(tpl) + template = self.j2j(tpl) st = bos.stat(fs_path) ts_md = st.st_mtime @@ -1694,7 +1826,7 @@ class HttpCli(object): sz_md = 0 for buf in yieldfile(fs_path): sz_md += len(buf) - for c, v in [[b"&", 4], [b"<", 3], [b">", 3]]: + for c, v in [(b"&", 4), (b"<", 3), (b">", 3)]: sz_md += (len(buf) - len(buf.replace(c, b""))) * v file_ts = max(ts_md, ts_html, E.t0) @@ -1720,8 +1852,8 @@ class HttpCli(object): "md": boundary, "arg_base": arg_base, } - html = template.render(**targs).encode("utf-8", "replace") - html = html.split(boundary.encode("utf-8")) + zs = template.render(**targs).encode("utf-8", "replace") + html = zs.split(boundary.encode("utf-8")) if len(html) != 2: raise Exception("boundary appears in " + html_path) @@ -1750,7 +1882,7 @@ class HttpCli(object): return True - def tx_mounts(self): + def tx_mounts(self) -> bool: suf = self.urlq({}, ["h"]) avol = [x for x in self.wvol if x in self.rvol] rvol, wvol, avol = [ @@ -1759,7 +1891,7 @@ class HttpCli(object): ] if avol and not self.args.no_rescan: - x = self.conn.hsrv.broker.put(True, "up2k.get_state") + x = self.conn.hsrv.broker.ask("up2k.get_state") vs = json.loads(x.get()) vstate = {("/" + k).rstrip("/") + "/": v for k, v in vs["volstate"].items()} else: @@ -1787,11 +1919,11 @@ class HttpCli(object): for v in wvol: txt += "\n " + v - txt = txt.encode("utf-8", "replace") + b"\n" - self.reply(txt, mime="text/plain; charset=utf-8") + zb = txt.encode("utf-8", "replace") + b"\n" + self.reply(zb, mime="text/plain; charset=utf-8") return True - html = self.j2( + html = self.j2s( "splash", this=self, qvpath=quotep(self.vpath), @@ -1809,38 +1941,41 @@ class HttpCli(object): self.reply(html.encode("utf-8")) return True - def set_k304(self): + def set_k304(self) -> bool: ck = gencookie("k304", self.uparam["k304"], 60 * 60 * 24 * 365) self.out_headerlist.append(("Set-Cookie", ck)) self.redirect("", "?h#cc") + return True - def set_am_js(self): + def set_am_js(self) -> bool: v = "n" if self.uparam["am_js"] == "n" else "y" ck = gencookie("js", v, 60 * 60 * 24 * 365) self.out_headerlist.append(("Set-Cookie", ck)) self.reply(b"promoted\n") + return True - def set_cfg_reset(self): + def set_cfg_reset(self) -> bool: for k in ("k304", "js", "cppwd"): self.out_headerlist.append(("Set-Cookie", gencookie(k, "x", None))) self.redirect("", "?h#cc") + return True - def tx_404(self, is_403=False): + def tx_404(self, is_403: bool = False) -> bool: rc = 404 if self.args.vague_403: - m = '
or maybe you don\'t have access -- try logging in or go home
' + t = 'or maybe you don\'t have access -- try logging in or go home
' elif is_403: - m = 'you\'ll have to log in or go home
' + t = 'you\'ll have to log in or go home
' rc = 403 else: - m = '{}\n{}".format(time.time(), html_escape(alltrace())) self.reply(ret.encode("utf-8")) + return True - def tx_tree(self): + def tx_tree(self) -> bool: top = self.uparam["tree"] or "" dst = self.vpath if top in [".", ".."]: @@ -1898,12 +2034,12 @@ class HttpCli(object): dst = dst[len(top) + 1 :] ret = self.gen_tree(top, dst) - ret = json.dumps(ret) - self.reply(ret.encode("utf-8"), mime="application/json") + zs = json.dumps(ret) + self.reply(zs.encode("utf-8"), mime="application/json") return True - def gen_tree(self, top, target): - ret = {} + def gen_tree(self, top: str, target: str) -> dict[str, Any]: + ret: dict[str, Any] = {} excl = None if target: excl, target = (target.split("/", 1) + [""])[:2] @@ -1921,26 +2057,26 @@ class HttpCli(object): for v in self.rvol: d1, d2 = v.rsplit("/", 1) if "/" in v else ["", v] if d1 == top: - vfs_virt[d2] = 0 + vfs_virt[d2] = vn # typechk, value never read dirs = [] - vfs_ls = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)] + dirnames = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)] if not self.args.ed or "dots" not in self.uparam: - vfs_ls = exclude_dotfiles(vfs_ls) + dirnames = exclude_dotfiles(dirnames) - for fn in [x for x in vfs_ls if x != excl]: + for fn in [x for x in dirnames if x != excl]: dirs.append(quotep(fn)) - for x in vfs_virt.keys(): + for x in vfs_virt: if x != excl: dirs.append(x) ret["a"] = dirs return ret - def tx_ups(self): + def tx_ups(self) -> bool: if not self.args.unpost: raise Pebkac(400, "the unpost feature is disabled in server config") @@ -1952,7 +2088,7 @@ class HttpCli(object): lm = "ups [{}]".format(filt) self.log(lm) - ret = [] + ret: list[dict[str, Any]] = [] t0 = time.time() lim = time.time() - self.args.unpost for vol in self.asrv.vfs.all_vols.values(): @@ -1968,17 +2104,18 @@ class HttpCli(object): ret.append({"vp": quotep(vp), "sz": sz, "at": at}) if len(ret) > 3000: - ret.sort(key=lambda x: x["at"], reverse=True) + ret.sort(key=lambda x: x["at"], reverse=True) # type: ignore ret = ret[:2000] - ret.sort(key=lambda x: x["at"], reverse=True) + ret.sort(key=lambda x: x["at"], reverse=True) # type: ignore ret = ret[:2000] jtxt = json.dumps(ret, indent=2, sort_keys=True).encode("utf-8", "replace") self.log("{} #{} {:.2f}sec".format(lm, len(ret), time.time() - t0)) self.reply(jtxt, mime="application/json") + return True - def handle_rm(self, req=None): + def handle_rm(self, req: list[str]) -> bool: if not req and not self.can_delete: raise Pebkac(403, "not allowed for user " + self.uname) @@ -1988,10 +2125,11 @@ class HttpCli(object): if not req: req = [self.vpath] - x = self.conn.hsrv.broker.put(True, "up2k.handle_rm", self.uname, self.ip, req) + x = self.conn.hsrv.broker.ask("up2k.handle_rm", self.uname, self.ip, req) self.loud_reply(x.get()) + return True - def handle_mv(self): + def handle_mv(self) -> bool: if not self.can_move: raise Pebkac(403, "not allowed for user " + self.uname) @@ -2006,12 +2144,11 @@ class HttpCli(object): # x-www-form-urlencoded (url query part) uses # either + or %20 for 0x20 so handle both dst = unquotep(dst.replace("+", " ")) - x = self.conn.hsrv.broker.put( - True, "up2k.handle_mv", self.uname, self.vpath, dst - ) + x = self.conn.hsrv.broker.ask("up2k.handle_mv", self.uname, self.vpath, dst) self.loud_reply(x.get()) + return True - def tx_ls(self, ls): + def tx_ls(self, ls: dict[str, Any]) -> bool: dirs = ls["dirs"] files = ls["files"] arg = self.uparam["ls"] @@ -2055,17 +2192,17 @@ class HttpCli(object): x["name"] = n fmt = fmt.format(len(nfmt.format(biggest))) - ret = [ + retl = [ "# {}: {}".format(x, ls[x]) for x in ["acct", "perms", "srvinf"] if x in ls ] - ret += [ + retl += [ fmt.format(x["dt"], x["sz"], x["name"]) for y in [dirs, files] for x in y ] - ret = "\n".join(ret) + ret = "\n".join(retl) mime = "text/plain; charset=utf-8" else: [x.pop(k) for k in ["name", "dt"] for y in [dirs, files] for x in y] @@ -2076,7 +2213,7 @@ class HttpCli(object): self.reply(ret.encode("utf-8", "replace") + b"\n", mime=mime) return True - def tx_browser(self): + def tx_browser(self) -> bool: vpath = "" vpnodes = [["", "/"]] if self.vpath: @@ -2164,7 +2301,7 @@ class HttpCli(object): if WINDOWS: try: bfree = ctypes.c_ulonglong(0) - ctypes.windll.kernel32.GetDiskFreeSpaceExW( + ctypes.windll.kernel32.GetDiskFreeSpaceExW( # type: ignore ctypes.c_wchar_p(abspath), None, None, ctypes.pointer(bfree) ) srv_info.append(humansize(bfree.value) + " free") @@ -2179,7 +2316,7 @@ class HttpCli(object): except: pass - srv_info = " // ".join(srv_info) + srv_infot = " // ".join(srv_info) perms = [] if self.can_read: @@ -2223,7 +2360,7 @@ class HttpCli(object): "dirs": [], "files": [], "taglist": [], - "srvinf": srv_info, + "srvinf": srv_infot, "acct": self.uname, "idx": ("e2d" in vn.flags), "perms": perms, @@ -2251,7 +2388,7 @@ class HttpCli(object): "logues": logues, "readme": readme, "title": html_escape(self.vpath, crlf=True), - "srv_info": srv_info, + "srv_info": srv_infot, "lang": self.args.lang, "dtheme": self.args.theme, "themes": self.args.themes, @@ -2267,7 +2404,7 @@ class HttpCli(object): if "zip" in self.uparam or "tar" in self.uparam: raise Pebkac(403) - html = self.j2(tpl, **j2a) + html = self.j2s(tpl, **j2a) self.reply(html.encode("utf-8", "replace")) return True @@ -2280,11 +2417,12 @@ class HttpCli(object): rem, self.uname, not self.args.no_scandir, [[True], [False, True]] ) stats = {k: v for k, v in vfs_ls} - vfs_ls = [x[0] for x in vfs_ls] - vfs_ls.extend(vfs_virt.keys()) + ls_names = [x[0] for x in vfs_ls] + ls_names.extend(list(vfs_virt.keys())) # check for old versions of files, - hist = {} # [num-backups, most-recent, hist-path] + # [num-backups, most-recent, hist-path] + hist: dict[str, tuple[int, float, str]] = {} histdir = os.path.join(fsroot, ".hist") ptn = re.compile(r"(.*)\.([0-9]+\.[0-9]{3})(\.[^\.]+)$") try: @@ -2294,14 +2432,14 @@ class HttpCli(object): continue fn = m.group(1) + m.group(3) - n, ts, _ = hist.get(fn, [0, 0, ""]) - hist[fn] = [n + 1, max(ts, float(m.group(2))), hfn] + n, ts, _ = hist.get(fn, (0, 0, "")) + hist[fn] = (n + 1, max(ts, float(m.group(2))), hfn) except: pass # show dotfiles if permitted and requested if not self.args.ed or "dots" not in self.uparam: - vfs_ls = exclude_dotfiles(vfs_ls) + ls_names = exclude_dotfiles(ls_names) icur = None if "e2t" in vn.flags: @@ -2312,7 +2450,7 @@ class HttpCli(object): dirs = [] files = [] - for fn in vfs_ls: + for fn in ls_names: base = "" href = fn if not is_ls and not is_js and not self.trailing_slash and vpath: @@ -2339,14 +2477,14 @@ class HttpCli(object): margin = 'zip'.format(quotep(href)) elif fn in hist: margin = '#{}'.format( - base, html_escape(hist[fn][2], quote=True, crlf=True), hist[fn][0] + base, html_escape(hist[fn][2], quot=True, crlf=True), hist[fn][0] ) else: margin = "-" sz = inf.st_size - dt = datetime.utcfromtimestamp(inf.st_mtime) - dt = dt.strftime("%Y-%m-%d %H:%M:%S") + zd = datetime.utcfromtimestamp(inf.st_mtime) + dt = zd.strftime("%Y-%m-%d %H:%M:%S") try: ext = "---" if is_dir else fn.rsplit(".", 1)[1] @@ -2380,11 +2518,11 @@ class HttpCli(object): files.append(item) item["rd"] = rem - taglist = {} - for f in files: - fn = f["name"] - rd = f["rd"] - del f["rd"] + tagset: set[str] = set() + for fe in files: + fn = fe["name"] + rd = fe["rd"] + del fe["rd"] if not icur: break @@ -2403,12 +2541,12 @@ class HttpCli(object): args = s3enc(idx.mem_cur, rd, fn) r = icur.execute(q, args).fetchone() except: - m = "tag list error, {}/{}\n{}" - self.log(m.format(rd, fn, min_ex())) + t = "tag list error, {}/{}\n{}" + self.log(t.format(rd, fn, min_ex())) break - tags = {} - f["tags"] = tags + tags: dict[str, Any] = {} + fe["tags"] = tags if not r: continue @@ -2417,17 +2555,19 @@ class HttpCli(object): q = "select k, v from mt where w = ? and +k != 'x'" try: for k, v in icur.execute(q, (w,)): - taglist[k] = True + tagset.add(k) tags[k] = v except: - m = "tag read error, {}/{} [{}]:\n{}" - self.log(m.format(rd, fn, w, min_ex())) + t = "tag read error, {}/{} [{}]:\n{}" + self.log(t.format(rd, fn, w, min_ex())) break if icur: - taglist = [k for k in vn.flags.get("mte", "").split(",") if k in taglist] - for f in dirs: - f["tags"] = {} + taglist = [k for k in vn.flags.get("mte", "").split(",") if k in tagset] + for fe in dirs: + fe["tags"] = {} + else: + taglist = list(tagset) if is_ls: ls_ret["dirs"] = dirs @@ -2480,6 +2620,6 @@ class HttpCli(object): if self.args.css_browser: j2a["css"] = self.args.css_browser - html = self.j2(tpl, **j2a) + html = self.j2s(tpl, **j2a) self.reply(html.encode("utf-8", "replace")) return True diff --git a/copyparty/httpconn.py b/copyparty/httpconn.py index 85f5f701..067e2d31 100644 --- a/copyparty/httpconn.py +++ b/copyparty/httpconn.py @@ -1,25 +1,36 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re +import argparse # typechk import os -import time +import re import socket +import threading # typechk +import time -HAVE_SSL = True try: + HAVE_SSL = True import ssl except: HAVE_SSL = False -from .__init__ import E -from .util import Unrecv +from . import util as Util +from .__init__ import TYPE_CHECKING, E +from .authsrv import AuthSrv # typechk from .httpcli import HttpCli -from .u2idx import U2idx +from .ico import Ico +from .mtag import HAVE_FFMPEG from .th_cli import ThumbCli from .th_srv import HAVE_PIL, HAVE_VIPS -from .mtag import HAVE_FFMPEG -from .ico import Ico +from .u2idx import U2idx + +try: + from typing import Optional, Pattern, Union +except: + pass + +if TYPE_CHECKING: + from .httpsrv import HttpSrv class HttpConn(object): @@ -28,32 +39,37 @@ class HttpConn(object): creates an HttpCli for each request (Connection: Keep-Alive) """ - def __init__(self, sck, addr, hsrv): + def __init__( + self, sck: socket.socket, addr: tuple[str, int], hsrv: "HttpSrv" + ) -> None: self.s = sck - self.sr = None # Type: Unrecv + self.sr: Optional[Util._Unrecv] = None self.addr = addr self.hsrv = hsrv - self.mutex = hsrv.mutex - self.args = hsrv.args - self.asrv = hsrv.asrv + self.mutex: threading.Lock = hsrv.mutex # mypy404 + self.args: argparse.Namespace = hsrv.args # mypy404 + self.asrv: AuthSrv = hsrv.asrv # mypy404 self.cert_path = hsrv.cert_path - self.u2fh = hsrv.u2fh + self.u2fh: Util.FHC = hsrv.u2fh # mypy404 enth = (HAVE_PIL or HAVE_VIPS or HAVE_FFMPEG) and not self.args.no_thumb - self.thumbcli = ThumbCli(hsrv) if enth else None - self.ico = Ico(self.args) + self.thumbcli: Optional[ThumbCli] = ThumbCli(hsrv) if enth else None # mypy404 + self.ico: Ico = Ico(self.args) # mypy404 - self.t0 = time.time() + self.t0: float = time.time() # mypy404 self.stopping = False - self.nreq = 0 - self.nbyte = 0 - self.u2idx = None - self.log_func = hsrv.log - self.lf_url = re.compile(self.args.lf_url) if self.args.lf_url else None + self.nreq: int = 0 # mypy404 + self.nbyte: int = 0 # mypy404 + self.u2idx: Optional[U2idx] = None + self.log_func: Util.RootLogger = hsrv.log # mypy404 + self.log_src: str = "httpconn" # mypy404 + self.lf_url: Optional[Pattern[str]] = ( + re.compile(self.args.lf_url) if self.args.lf_url else None + ) # mypy404 self.set_rproxy() - def shutdown(self): + def shutdown(self) -> None: self.stopping = True try: self.s.shutdown(socket.SHUT_RDWR) @@ -61,7 +77,7 @@ class HttpConn(object): except: pass - def set_rproxy(self, ip=None): + def set_rproxy(self, ip: Optional[str] = None) -> str: if ip is None: color = 36 ip = self.addr[0] @@ -74,35 +90,35 @@ class HttpConn(object): self.log_src = "{} \033[{}m{}".format(ip, color, self.addr[1]).ljust(26) return self.log_src - def respath(self, res_name): + def respath(self, res_name: str) -> str: return os.path.join(E.mod, "web", res_name) - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func(self.log_src, msg, c) - def get_u2idx(self): + def get_u2idx(self) -> U2idx: if not self.u2idx: self.u2idx = U2idx(self) return self.u2idx - def _detect_https(self): + def _detect_https(self) -> bool: method = None if self.cert_path: try: method = self.s.recv(4, socket.MSG_PEEK) except socket.timeout: - return + return False except AttributeError: # jython does not support msg_peek; forget about https method = self.s.recv(4) - self.sr = Unrecv(self.s, self.log) + self.sr = Util.Unrecv(self.s, self.log) self.sr.buf = method # jython used to do this, they stopped since it's broken # but reimplementing sendall is out of scope for now if not getattr(self.s, "sendall", None): - self.s.sendall = self.s.send + self.s.sendall = self.s.send # type: ignore if len(method) != 4: err = "need at least 4 bytes in the first packet; got {}".format( @@ -112,17 +128,18 @@ class HttpConn(object): self.log(err) self.s.send(b"HTTP/1.1 400 Bad Request\r\n\r\n" + err.encode("utf-8")) - return + return False return method not in [None, b"GET ", b"HEAD", b"POST", b"PUT ", b"OPTI"] - def run(self): + def run(self) -> None: self.sr = None if self.args.https_only: is_https = True elif self.args.http_only or not HAVE_SSL: is_https = False else: + # raise Exception("asdf") is_https = self._detect_https() if is_https: @@ -151,14 +168,15 @@ class HttpConn(object): self.s = ctx.wrap_socket(self.s, server_side=True) msg = [ "\033[1;3{:d}m{}".format(c, s) - for c, s in zip([0, 5, 0], self.s.cipher()) + for c, s in zip([0, 5, 0], self.s.cipher()) # type: ignore ] self.log(" ".join(msg) + "\033[0m") if self.args.ssl_dbg and hasattr(self.s, "shared_ciphers"): - overlap = [y[::-1] for y in self.s.shared_ciphers()] - lines = [str(x) for x in (["TLS cipher overlap:"] + overlap)] - self.log("\n".join(lines)) + ciphers = self.s.shared_ciphers() + assert ciphers + overlap = [str(y[::-1]) for y in ciphers] + self.log("TLS cipher overlap:" + "\n".join(overlap)) for k, v in [ ["compression", self.s.compression()], ["ALPN proto", self.s.selected_alpn_protocol()], @@ -183,7 +201,7 @@ class HttpConn(object): return if not self.sr: - self.sr = Unrecv(self.s, self.log) + self.sr = Util.Unrecv(self.s, self.log) while not self.stopping: self.nreq += 1 diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py index 04d2146c..fd449b5f 100644 --- a/copyparty/httpsrv.py +++ b/copyparty/httpsrv.py @@ -1,13 +1,15 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import sys -import time -import math import base64 +import math +import os import socket +import sys import threading +import time + +import queue try: import jinja2 @@ -26,15 +28,18 @@ except ImportError: ) sys.exit(1) -from .__init__ import E, PY2, MACOS -from .util import FHC, spack, min_ex, start_stackmon, start_log_thrs +from .__init__ import MACOS, TYPE_CHECKING, E from .bos import bos from .httpconn import HttpConn +from .util import FHC, min_ex, spack, start_log_thrs, start_stackmon -if PY2: - import Queue as queue -else: - import queue +if TYPE_CHECKING: + from .broker_util import BrokerCli + +try: + from typing import Any, Optional +except: + pass class HttpSrv(object): @@ -43,7 +48,7 @@ class HttpSrv(object): relying on MpSrv for performance (HttpSrv is just plain threads) """ - def __init__(self, broker, nid): + def __init__(self, broker: "BrokerCli", nid: Optional[int]) -> None: self.broker = broker self.nid = nid self.args = broker.args @@ -58,17 +63,19 @@ class HttpSrv(object): self.tp_nthr = 0 # actual self.tp_ncli = 0 # fading - self.tp_time = None # latest worker collect - self.tp_q = None if self.args.no_htp else queue.LifoQueue() - self.t_periodic = None + self.tp_time = 0.0 # latest worker collect + self.tp_q: Optional[queue.LifoQueue[Any]] = ( + None if self.args.no_htp else queue.LifoQueue() + ) + self.t_periodic: Optional[threading.Thread] = None self.u2fh = FHC() - self.srvs = [] + self.srvs: list[socket.socket] = [] self.ncli = 0 # exact - self.clients = {} # laggy + self.clients: set[HttpConn] = set() # laggy self.nclimax = 0 - self.cb_ts = 0 - self.cb_v = 0 + self.cb_ts = 0.0 + self.cb_v = "" env = jinja2.Environment() env.loader = jinja2.FileSystemLoader(os.path.join(E.mod, "web")) @@ -82,7 +89,7 @@ class HttpSrv(object): if bos.path.exists(cert_path): self.cert_path = cert_path else: - self.cert_path = None + self.cert_path = "" if self.tp_q: self.start_threads(4) @@ -94,19 +101,19 @@ class HttpSrv(object): if self.args.log_thrs: start_log_thrs(self.log, self.args.log_thrs, nid) - self.th_cfg = {} # type: dict[str, Any] + self.th_cfg: dict[str, Any] = {} t = threading.Thread(target=self.post_init) t.daemon = True t.start() - def post_init(self): + def post_init(self) -> None: try: - x = self.broker.put(True, "thumbsrv.getcfg") + x = self.broker.ask("thumbsrv.getcfg") self.th_cfg = x.get() except: pass - def start_threads(self, n): + def start_threads(self, n: int) -> None: self.tp_nthr += n if self.args.log_htp: self.log(self.name, "workers += {} = {}".format(n, self.tp_nthr), 6) @@ -119,15 +126,16 @@ class HttpSrv(object): thr.daemon = True thr.start() - def stop_threads(self, n): + def stop_threads(self, n: int) -> None: self.tp_nthr -= n if self.args.log_htp: self.log(self.name, "workers -= {} = {}".format(n, self.tp_nthr), 6) + assert self.tp_q for _ in range(n): self.tp_q.put(None) - def periodic(self): + def periodic(self) -> None: while True: time.sleep(2 if self.tp_ncli or self.ncli else 10) with self.mutex: @@ -141,7 +149,7 @@ class HttpSrv(object): self.t_periodic = None return - def listen(self, sck, nlisteners): + def listen(self, sck: socket.socket, nlisteners: int) -> None: ip, port = sck.getsockname() self.srvs.append(sck) self.nclimax = math.ceil(self.args.nc * 1.0 / nlisteners) @@ -153,15 +161,15 @@ class HttpSrv(object): t.daemon = True t.start() - def thr_listen(self, srv_sck): + def thr_listen(self, srv_sck: socket.socket) -> None: """listens on a shared tcp server""" ip, port = srv_sck.getsockname() fno = srv_sck.fileno() msg = "subscribed @ {}:{} f{}".format(ip, port, fno) self.log(self.name, msg) - def fun(): - self.broker.put(False, "cb_httpsrv_up") + def fun() -> None: + self.broker.say("cb_httpsrv_up") threading.Thread(target=fun).start() @@ -185,21 +193,21 @@ class HttpSrv(object): continue if self.args.log_conn: - m = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format( + t = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format( "-" * 3, ip, port % 8, port ) - self.log("%s %s" % addr, m, c="1;30") + self.log("%s %s" % addr, t, c="1;30") self.accept(sck, addr) - def accept(self, sck, addr): + def accept(self, sck: socket.socket, addr: tuple[str, int]) -> None: """takes an incoming tcp connection and creates a thread to handle it""" now = time.time() if now - (self.tp_time or now) > 300: - m = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}" - self.log(self.name, m.format(self.tp_time, now, self.tp_nthr, self.ncli), 1) - self.tp_time = None + t = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}" + self.log(self.name, t.format(self.tp_time, now, self.tp_nthr, self.ncli), 1) + self.tp_time = 0 self.tp_q = None with self.mutex: @@ -209,10 +217,10 @@ class HttpSrv(object): if self.nid: name += "-{}".format(self.nid) - t = threading.Thread(target=self.periodic, name=name) - self.t_periodic = t - t.daemon = True - t.start() + thr = threading.Thread(target=self.periodic, name=name) + self.t_periodic = thr + thr.daemon = True + thr.start() if self.tp_q: self.tp_time = self.tp_time or now @@ -224,8 +232,8 @@ class HttpSrv(object): return if not self.args.no_htp: - m = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n" - self.log(self.name, m, 1) + t = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n" + self.log(self.name, t, 1) thr = threading.Thread( target=self.thr_client, @@ -235,14 +243,15 @@ class HttpSrv(object): thr.daemon = True thr.start() - def thr_poolw(self): + def thr_poolw(self) -> None: + assert self.tp_q while True: task = self.tp_q.get() if not task: break with self.mutex: - self.tp_time = None + self.tp_time = 0 try: sck, addr = task @@ -255,7 +264,7 @@ class HttpSrv(object): except: self.log(self.name, "thr_client: " + min_ex(), 3) - def shutdown(self): + def shutdown(self) -> None: self.stopping = True for srv in self.srvs: try: @@ -263,7 +272,7 @@ class HttpSrv(object): except: pass - clients = list(self.clients.keys()) + clients = list(self.clients) for cli in clients: try: cli.shutdown() @@ -279,13 +288,13 @@ class HttpSrv(object): self.log(self.name, "ok bye") - def thr_client(self, sck, addr): + def thr_client(self, sck: socket.socket, addr: tuple[str, int]) -> None: """thread managing one tcp client""" sck.settimeout(120) cli = HttpConn(sck, addr, self) with self.mutex: - self.clients[cli] = 0 + self.clients.add(cli) fno = sck.fileno() try: @@ -328,10 +337,10 @@ class HttpSrv(object): raise finally: with self.mutex: - del self.clients[cli] + self.clients.remove(cli) self.ncli -= 1 - def cachebuster(self): + def cachebuster(self) -> str: if time.time() - self.cb_ts < 1: return self.cb_v diff --git a/copyparty/ico.py b/copyparty/ico.py index 58076c89..f403e4b5 100644 --- a/copyparty/ico.py +++ b/copyparty/ico.py @@ -1,28 +1,28 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import hashlib +import argparse # typechk import colorsys +import hashlib from .__init__ import PY2 class Ico(object): - def __init__(self, args): + def __init__(self, args: argparse.Namespace) -> None: self.args = args - def get(self, ext, as_thumb): + def get(self, ext: str, as_thumb: bool) -> tuple[str, bytes]: """placeholder to make thumbnails not break""" - h = hashlib.md5(ext.encode("utf-8")).digest()[:2] + zb = hashlib.md5(ext.encode("utf-8")).digest()[:2] if PY2: - h = [ord(x) for x in h] + zb = [ord(x) for x in zb] - c1 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 0.3) - c2 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 1) - c = list(c1) + list(c2) - c = [int(x * 255) for x in c] - c = "".join(["{:02x}".format(x) for x in c]) + c1 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 0.3) + c2 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 1) + ci = [int(x * 255) for x in list(c1) + list(c2)] + c = "".join(["{:02x}".format(x) for x in ci]) h = 30 if not self.args.th_no_crop and as_thumb: @@ -37,6 +37,6 @@ class Ico(object): fill="#{}" font-family="monospace" font-size="14px" style="letter-spacing:.5px">{} """ - svg = svg.format(h, c[:6], c[6:], ext).encode("utf-8") + svg = svg.format(h, c[:6], c[6:], ext) - return ["image/svg+xml", svg] + return "image/svg+xml", svg.encode("utf-8") diff --git a/copyparty/mtag.py b/copyparty/mtag.py index 60a7584f..dd1e2ec3 100644 --- a/copyparty/mtag.py +++ b/copyparty/mtag.py @@ -1,18 +1,26 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import sys +import argparse import json +import os import shutil import subprocess as sp +import sys from .__init__ import PY2, WINDOWS, unicode -from .util import fsenc, uncyg, runcmd, retchk, REKOBO_LKEY from .bos import bos +from .util import REKOBO_LKEY, fsenc, retchk, runcmd, uncyg + +try: + from typing import Any, Union + + from .util import RootLogger +except: + pass -def have_ff(cmd): +def have_ff(cmd: str) -> bool: if PY2: print("# checking {}".format(cmd)) cmd = (cmd + " -version").encode("ascii").split(b" ") @@ -30,7 +38,7 @@ HAVE_FFPROBE = have_ff("ffprobe") class MParser(object): - def __init__(self, cmdline): + def __init__(self, cmdline: str) -> None: self.tag, args = cmdline.split("=", 1) self.tags = self.tag.split(",") @@ -73,7 +81,9 @@ class MParser(object): raise Exception() -def ffprobe(abspath, timeout=10): +def ffprobe( + abspath: str, timeout: int = 10 +) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]: cmd = [ b"ffprobe", b"-hide_banner", @@ -87,15 +97,15 @@ def ffprobe(abspath, timeout=10): return parse_ffprobe(so) -def parse_ffprobe(txt): +def parse_ffprobe(txt: str) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]: """ffprobe -show_format -show_streams""" streams = [] fmt = {} g = {} for ln in [x.rstrip("\r") for x in txt.split("\n")]: try: - k, v = ln.split("=", 1) - g[k] = v + sk, sv = ln.split("=", 1) + g[sk] = sv continue except: pass @@ -109,8 +119,8 @@ def parse_ffprobe(txt): fmt = g streams = [fmt] + streams - ret = {} # processed - md = {} # raw tags + ret: dict[str, Any] = {} # processed + md: dict[str, list[Any]] = {} # raw tags is_audio = fmt.get("format_name") in ["mp3", "ogg", "flac", "wav"] if fmt.get("filename", "").split(".")[-1].lower() in ["m4a", "aac"]: @@ -161,43 +171,43 @@ def parse_ffprobe(txt): kvm = [["duration", ".dur"], ["bit_rate", ".q"]] for sk, rk in kvm: - v = strm.get(sk) - if v is None: + v1 = strm.get(sk) + if v1 is None: continue if rk.startswith("."): try: - v = float(v) + zf = float(v1) v2 = ret.get(rk) - if v2 is None or v > v2: - ret[rk] = v + if v2 is None or zf > v2: + ret[rk] = zf except: # sqlite doesnt care but the code below does - if v not in ["N/A"]: - ret[rk] = v + if v1 not in ["N/A"]: + ret[rk] = v1 else: - ret[rk] = v + ret[rk] = v1 if ret.get("vc") == "ansi": # shellscript return {}, {} for strm in streams: - for k, v in strm.items(): - if not k.startswith("TAG:"): + for sk, sv in strm.items(): + if not sk.startswith("TAG:"): continue - k = k[4:].strip() - v = v.strip() - if k and v and k not in md: - md[k] = [v] + sk = sk[4:].strip() + sv = sv.strip() + if sk and sv and sk not in md: + md[sk] = [sv] - for k in [".q", ".vq", ".aq"]: - if k in ret: - ret[k] /= 1000 # bit_rate=320000 + for sk in [".q", ".vq", ".aq"]: + if sk in ret: + ret[sk] /= 1000 # bit_rate=320000 - for k in [".q", ".vq", ".aq", ".resw", ".resh"]: - if k in ret: - ret[k] = int(ret[k]) + for sk in [".q", ".vq", ".aq", ".resw", ".resh"]: + if sk in ret: + ret[sk] = int(ret[sk]) if ".fps" in ret: fps = ret[".fps"] @@ -219,13 +229,13 @@ def parse_ffprobe(txt): if ".resw" in ret and ".resh" in ret: ret["res"] = "{}x{}".format(ret[".resw"], ret[".resh"]) - ret = {k: [0, v] for k, v in ret.items()} + zd = {k: (0, v) for k, v in ret.items()} - return ret, md + return zd, md class MTag(object): - def __init__(self, log_func, args): + def __init__(self, log_func: RootLogger, args: argparse.Namespace) -> None: self.log_func = log_func self.args = args self.usable = True @@ -242,7 +252,7 @@ class MTag(object): if self.backend == "mutagen": self.get = self.get_mutagen try: - import mutagen + import mutagen # noqa: F401 # pylint: disable=unused-import,import-outside-toplevel except: self.log("could not load Mutagen, trying FFprobe instead", c=3) self.backend = "ffprobe" @@ -339,31 +349,33 @@ class MTag(object): } # self.get = self.compare - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("mtag", msg, c) - def normalize_tags(self, ret, md): - for k, v in dict(md).items(): - if not v: + def normalize_tags( + self, parser_output: dict[str, tuple[int, Any]], md: dict[str, list[Any]] + ) -> dict[str, Union[str, float]]: + for sk, tv in dict(md).items(): + if not tv: continue - k = k.lower().split("::")[0].strip() - mk = self.rmap.get(k) - if not mk: + sk = sk.lower().split("::")[0].strip() + key_mapping = self.rmap.get(sk) + if not key_mapping: continue - pref, mk = mk - if mk not in ret or ret[mk][0] > pref: - ret[mk] = [pref, v[0]] + priority, alias = key_mapping + if alias not in parser_output or parser_output[alias][0] > priority: + parser_output[alias] = (priority, tv[0]) - # take first value - ret = {k: unicode(v[1]).strip() for k, v in ret.items()} + # take first value (lowest priority / most preferred) + ret = {sk: unicode(tv[1]).strip() for sk, tv in parser_output.items()} # track 3/7 => track 3 - for k, v in ret.items(): - if k[0] == ".": - v = v.split("/")[0].strip().lstrip("0") - ret[k] = v or 0 + for sk, tv in ret.items(): + if sk[0] == ".": + sv = str(tv).split("/")[0].strip().lstrip("0") + ret[sk] = sv or 0 # normalize key notation to rkeobo okey = ret.get("key") @@ -373,7 +385,7 @@ class MTag(object): return ret - def compare(self, abspath): + def compare(self, abspath: str) -> dict[str, Union[str, float]]: if abspath.endswith(".au"): return {} @@ -411,7 +423,7 @@ class MTag(object): return r1 - def get_mutagen(self, abspath): + def get_mutagen(self, abspath: str) -> dict[str, Union[str, float]]: if not bos.path.isfile(abspath): return {} @@ -425,7 +437,7 @@ class MTag(object): return self.get_ffprobe(abspath) if self.can_ffprobe else {} sz = bos.path.getsize(abspath) - ret = {".q": [0, int((sz / md.info.length) / 128)]} + ret = {".q": (0, int((sz / md.info.length) / 128))} for attr, k, norm in [ ["codec", "ac", unicode], @@ -456,24 +468,24 @@ class MTag(object): if k == "ac" and v.startswith("mp4a.40."): v = "aac" - ret[k] = [0, norm(v)] + ret[k] = (0, norm(v)) return self.normalize_tags(ret, md) - def get_ffprobe(self, abspath): + def get_ffprobe(self, abspath: str) -> dict[str, Union[str, float]]: if not bos.path.isfile(abspath): return {} ret, md = ffprobe(abspath) return self.normalize_tags(ret, md) - def get_bin(self, parsers, abspath): + def get_bin(self, parsers: dict[str, MParser], abspath: str) -> dict[str, Any]: if not bos.path.isfile(abspath): return {} pypath = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - pypath = [str(pypath)] + [str(x) for x in sys.path if x] - pypath = str(os.pathsep.join(pypath)) + zsl = [str(pypath)] + [str(x) for x in sys.path if x] + pypath = str(os.pathsep.join(zsl)) env = os.environ.copy() env["PYTHONPATH"] = pypath @@ -491,9 +503,9 @@ class MTag(object): else: cmd = ["nice"] + cmd - cmd = [fsenc(x) for x in cmd] - rc, v, err = runcmd(cmd, **args) - retchk(rc, cmd, err, self.log, 5, self.args.mtag_v) + bcmd = [fsenc(x) for x in cmd] + rc, v, err = runcmd(bcmd, **args) # type: ignore + retchk(rc, bcmd, err, self.log, 5, self.args.mtag_v) v = v.strip() if not v: continue @@ -501,10 +513,10 @@ class MTag(object): if "," not in tagname: ret[tagname] = v else: - v = json.loads(v) + zj = json.loads(v) for tag in tagname.split(","): - if tag and tag in v: - ret[tag] = v[tag] + if tag and tag in zj: + ret[tag] = zj[tag] except: pass diff --git a/copyparty/star.py b/copyparty/star.py index 804ad724..21c2703c 100644 --- a/copyparty/star.py +++ b/copyparty/star.py @@ -4,20 +4,29 @@ from __future__ import print_function, unicode_literals import tarfile import threading -from .sutil import errdesc -from .util import Queue, fsenc, min_ex +from queue import Queue + from .bos import bos +from .sutil import StreamArc, errdesc +from .util import fsenc, min_ex + +try: + from typing import Any, Generator, Optional + + from .util import NamedLogger +except: + pass -class QFile(object): +class QFile(object): # inherit io.StringIO for painful typing """file-like object which buffers writes into a queue""" - def __init__(self): - self.q = Queue(64) - self.bq = [] + def __init__(self) -> None: + self.q: Queue[Optional[bytes]] = Queue(64) + self.bq: list[bytes] = [] self.nq = 0 - def write(self, buf): + def write(self, buf: Optional[bytes]) -> None: if buf is None or self.nq >= 240 * 1024: self.q.put(b"".join(self.bq)) self.bq = [] @@ -30,27 +39,32 @@ class QFile(object): self.nq += len(buf) -class StreamTar(object): +class StreamTar(StreamArc): """construct in-memory tar file from the given path""" - def __init__(self, log, fgen, **kwargs): + def __init__( + self, + log: NamedLogger, + fgen: Generator[dict[str, Any], None, None], + **kwargs: Any + ): + super(StreamTar, self).__init__(log, fgen) + self.ci = 0 self.co = 0 self.qfile = QFile() - self.log = log - self.fgen = fgen - self.errf = None + self.errf: dict[str, Any] = {} # python 3.8 changed to PAX_FORMAT as default, # waste of space and don't care about the new features fmt = tarfile.GNU_FORMAT - self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt) + self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt) # type: ignore w = threading.Thread(target=self._gen, name="star-gen") w.daemon = True w.start() - def gen(self): + def gen(self) -> Generator[Optional[bytes], None, None]: while True: buf = self.qfile.q.get() if not buf: @@ -63,7 +77,7 @@ class StreamTar(object): if self.errf: bos.unlink(self.errf["ap"]) - def ser(self, f): + def ser(self, f: dict[str, Any]) -> None: name = f["vp"] src = f["ap"] fsi = f["st"] @@ -76,21 +90,21 @@ class StreamTar(object): inf.gid = 0 self.ci += inf.size - with open(fsenc(src), "rb", 512 * 1024) as f: - self.tar.addfile(inf, f) + with open(fsenc(src), "rb", 512 * 1024) as fo: + self.tar.addfile(inf, fo) - def _gen(self): + def _gen(self) -> None: errors = [] for f in self.fgen: if "err" in f: - errors.append([f["vp"], f["err"]]) + errors.append((f["vp"], f["err"])) continue try: self.ser(f) except: ex = min_ex(5, True).replace("\n", "\n-- ") - errors.append([f["vp"], ex]) + errors.append((f["vp"], ex)) if errors: self.errf, txt = errdesc(errors) diff --git a/copyparty/stolen/surrogateescape.py b/copyparty/stolen/surrogateescape.py index 4b06ed28..b1ff8886 100644 --- a/copyparty/stolen/surrogateescape.py +++ b/copyparty/stolen/surrogateescape.py @@ -12,23 +12,28 @@ Original source: misc/python/surrogateescape.py in https://bitbucket.org/haypo/m # This code is released under the Python license and the BSD 2-clause license -import platform import codecs +import platform import sys PY3 = sys.version_info[0] > 2 WINDOWS = platform.system() == "Windows" FS_ERRORS = "surrogateescape" +try: + from typing import Any +except: + pass -def u(text): + +def u(text: Any) -> str: if PY3: return text else: return text.decode("unicode_escape") -def b(data): +def b(data: Any) -> bytes: if PY3: return data.encode("latin1") else: @@ -43,7 +48,7 @@ else: bytes_chr = chr -def surrogateescape_handler(exc): +def surrogateescape_handler(exc: Any) -> tuple[str, int]: """ Pure Python implementation of the PEP 383: the "surrogateescape" error handler of Python 3. Undecodable bytes will be replaced by a Unicode @@ -74,7 +79,7 @@ class NotASurrogateError(Exception): pass -def replace_surrogate_encode(mystring): +def replace_surrogate_encode(mystring: str) -> str: """ Returns a (unicode) string, not the more logical bytes, because the codecs register_error functionality expects this. @@ -100,7 +105,7 @@ def replace_surrogate_encode(mystring): return str().join(decoded) -def replace_surrogate_decode(mybytes): +def replace_surrogate_decode(mybytes: bytes) -> str: """ Returns a (unicode) string """ @@ -121,7 +126,7 @@ def replace_surrogate_decode(mybytes): return str().join(decoded) -def encodefilename(fn): +def encodefilename(fn: str) -> bytes: if FS_ENCODING == "ascii": # ASCII encoder of Python 2 expects that the error handler returns a # Unicode string encodable to ASCII, whereas our surrogateescape error @@ -161,7 +166,7 @@ def encodefilename(fn): return fn.encode(FS_ENCODING, FS_ERRORS) -def decodefilename(fn): +def decodefilename(fn: bytes) -> str: return fn.decode(FS_ENCODING, FS_ERRORS) @@ -181,7 +186,7 @@ if WINDOWS and not PY3: FS_ENCODING = codecs.lookup(FS_ENCODING).name -def register_surrogateescape(): +def register_surrogateescape() -> None: """ Registers the surrogateescape error handler on Python 2 (only) """ diff --git a/copyparty/sutil.py b/copyparty/sutil.py index 22de0066..506e389f 100644 --- a/copyparty/sutil.py +++ b/copyparty/sutil.py @@ -6,8 +6,29 @@ from datetime import datetime from .bos import bos +try: + from typing import Any, Generator, Optional -def errdesc(errors): + from .util import NamedLogger +except: + pass + + +class StreamArc(object): + def __init__( + self, + log: NamedLogger, + fgen: Generator[dict[str, Any], None, None], + **kwargs: Any + ): + self.log = log + self.fgen = fgen + + def gen(self) -> Generator[Optional[bytes], None, None]: + pass + + +def errdesc(errors: list[tuple[str, str]]) -> tuple[dict[str, Any], list[str]]: report = ["copyparty failed to add the following files to the archive:", ""] for fn, err in errors: diff --git a/copyparty/svchub.py b/copyparty/svchub.py index 44680513..e9b5aadf 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -1,41 +1,51 @@ # coding: utf-8 from __future__ import print_function, unicode_literals +import argparse +import calendar import os -import sys -import time import shlex -import string import signal import socket +import string +import sys import threading +import time from datetime import datetime, timedelta -import calendar -from .__init__ import E, PY2, WINDOWS, ANYWIN, MACOS, VT100, unicode -from .util import mp, start_log_thrs, start_stackmon, min_ex, ansi_re +try: + from types import FrameType + + import typing + from typing import Optional, Union +except: + pass + +from .__init__ import ANYWIN, MACOS, PY2, VT100, WINDOWS, E, unicode from .authsrv import AuthSrv -from .tcpsrv import TcpSrv -from .up2k import Up2k -from .th_srv import ThumbSrv, HAVE_PIL, HAVE_VIPS, HAVE_WEBP from .mtag import HAVE_FFMPEG, HAVE_FFPROBE +from .tcpsrv import TcpSrv +from .th_srv import HAVE_PIL, HAVE_VIPS, HAVE_WEBP, ThumbSrv +from .up2k import Up2k +from .util import ansi_re, min_ex, mp, start_log_thrs, start_stackmon class SvcHub(object): """ Hosts all services which cannot be parallelized due to reliance on monolithic resources. Creates a Broker which does most of the heavy stuff; hosted services can use this to perform work: - hub.broker.put(want_reply, destination, args_list). + hub.broker.(destination, args_list). Either BrokerThr (plain threads) or BrokerMP (multiprocessing) is used depending on configuration. Nothing is returned synchronously; if you want any value returned from the call, put() can return a queue (if want_reply=True) which has a blocking get() with the response. """ - def __init__(self, args, argv, printed): + def __init__(self, args: argparse.Namespace, argv: list[str], printed: str) -> None: self.args = args self.argv = argv - self.logf = None + self.logf: Optional[typing.TextIO] = None + self.logf_base_fn = "" self.stop_req = False self.reload_req = False self.stopping = False @@ -59,16 +69,16 @@ class SvcHub(object): if not args.use_fpool and args.j != 1: args.no_fpool = True - m = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems" - self.log("root", m.format(args.j)) + t = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems" + self.log("root", t.format(args.j)) if not args.no_fpool and args.j != 1: - m = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior" + t = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior" if ANYWIN: - m = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead' + t = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead' args.no_fpool = True - self.log("root", m, c=3) + self.log("root", t, c=3) bri = "zy"[args.theme % 2 :][:1] ch = "abcdefghijklmnopqrstuvwx"[int(args.theme / 2)] @@ -96,8 +106,8 @@ class SvcHub(object): self.args.th_dec = list(decs.keys()) self.thumbsrv = None if not args.no_thumb: - m = "decoder preference: {}".format(", ".join(self.args.th_dec)) - self.log("thumb", m) + t = "decoder preference: {}".format(", ".join(self.args.th_dec)) + self.log("thumb", t) if "pil" in self.args.th_dec and not HAVE_WEBP: msg = "disabling webp thumbnails because either libwebp is not available or your Pillow is too old" @@ -131,11 +141,11 @@ class SvcHub(object): if self.check_mp_enable(): from .broker_mp import BrokerMp as Broker else: - from .broker_thr import BrokerThr as Broker + from .broker_thr import BrokerThr as Broker # type: ignore self.broker = Broker(self) - def thr_httpsrv_up(self): + def thr_httpsrv_up(self) -> None: time.sleep(1 if self.args.ign_ebind_all else 5) expected = self.broker.num_workers * self.tcpsrv.nsrv failed = expected - self.httpsrv_up @@ -145,20 +155,20 @@ class SvcHub(object): if self.args.ign_ebind_all: if not self.tcpsrv.srv: for _ in range(self.broker.num_workers): - self.broker.put(False, "cb_httpsrv_up") + self.broker.say("cb_httpsrv_up") return if self.args.ign_ebind and self.tcpsrv.srv: return - m = "{}/{} workers failed to start" - m = m.format(failed, expected) - self.log("root", m, 1) + t = "{}/{} workers failed to start" + t = t.format(failed, expected) + self.log("root", t, 1) self.retcode = 1 os.kill(os.getpid(), signal.SIGTERM) - def cb_httpsrv_up(self): + def cb_httpsrv_up(self) -> None: self.httpsrv_up += 1 if self.httpsrv_up != self.broker.num_workers: return @@ -171,9 +181,9 @@ class SvcHub(object): thr.daemon = True thr.start() - def _logname(self): + def _logname(self) -> str: dt = datetime.utcnow() - fn = self.args.lo + fn = str(self.args.lo) for fs in "YmdHMS": fs = "%" + fs if fs in fn: @@ -181,7 +191,7 @@ class SvcHub(object): return fn - def _setup_logfile(self, printed): + def _setup_logfile(self, printed: str) -> None: base_fn = fn = sel_fn = self._logname() if fn != self.args.lo: ctr = 0 @@ -203,8 +213,6 @@ class SvcHub(object): lh = codecs.open(fn, "w", encoding="utf-8", errors="replace") - lh.base_fn = base_fn - argv = [sys.executable] + self.argv if hasattr(shlex, "quote"): argv = [shlex.quote(x) for x in argv] @@ -215,9 +223,10 @@ class SvcHub(object): printed += msg lh.write("t0: {:.3f}\nargv: {}\n\n{}".format(E.t0, " ".join(argv), printed)) self.logf = lh + self.logf_base_fn = base_fn print(msg, end="") - def run(self): + def run(self) -> None: self.tcpsrv.run() thr = threading.Thread(target=self.thr_httpsrv_up) @@ -252,7 +261,7 @@ class SvcHub(object): else: self.stop_thr() - def reload(self): + def reload(self) -> str: if self.reloading: return "cannot reload; already in progress" @@ -262,7 +271,7 @@ class SvcHub(object): t.start() return "reload initiated" - def _reload(self): + def _reload(self) -> None: self.log("root", "reload scheduled") with self.up2k.mutex: self.asrv.reload() @@ -271,7 +280,7 @@ class SvcHub(object): self.reloading = False - def stop_thr(self): + def stop_thr(self) -> None: while not self.stop_req: with self.stop_cond: self.stop_cond.wait(9001) @@ -282,7 +291,7 @@ class SvcHub(object): self.shutdown() - def signal_handler(self, sig, frame): + def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None: if self.stopping: return @@ -294,7 +303,7 @@ class SvcHub(object): with self.stop_cond: self.stop_cond.notify_all() - def shutdown(self): + def shutdown(self) -> None: if self.stopping: return @@ -337,7 +346,7 @@ class SvcHub(object): sys.exit(ret) - def _log_disabled(self, src, msg, c=0): + def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: if not self.logf: return @@ -349,8 +358,8 @@ class SvcHub(object): if now >= self.next_day: self._set_next_day() - def _set_next_day(self): - if self.next_day and self.logf and self.logf.base_fn != self._logname(): + def _set_next_day(self) -> None: + if self.next_day and self.logf and self.logf_base_fn != self._logname(): self.logf.close() self._setup_logfile("") @@ -364,7 +373,7 @@ class SvcHub(object): dt = dt.replace(hour=0, minute=0, second=0) self.next_day = calendar.timegm(dt.utctimetuple()) - def _log_enabled(self, src, msg, c=0): + def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: """handles logging from all components""" with self.log_mutex: now = time.time() @@ -401,7 +410,7 @@ class SvcHub(object): if self.logf: self.logf.write(msg) - def check_mp_support(self): + def check_mp_support(self) -> str: vmin = sys.version_info[1] if WINDOWS: msg = "need python 3.3 or newer for multiprocessing;" @@ -415,16 +424,16 @@ class SvcHub(object): return msg try: - x = mp.Queue(1) - x.put(["foo", "bar"]) + x: mp.Queue[tuple[str, str]] = mp.Queue(1) + x.put(("foo", "bar")) if x.get()[0] != "foo": raise Exception() except: return "multiprocessing is not supported on your platform;" - return None + return "" - def check_mp_enable(self): + def check_mp_enable(self) -> bool: if self.args.j == 1: return False @@ -447,18 +456,18 @@ class SvcHub(object): self.log("svchub", "cannot efficiently use multiple CPU cores") return False - def sd_notify(self): + def sd_notify(self) -> None: try: - addr = os.getenv("NOTIFY_SOCKET") - if not addr: + zb = os.getenv("NOTIFY_SOCKET") + if not zb: return - addr = unicode(addr) + addr = unicode(zb) if addr.startswith("@"): addr = "\0" + addr[1:] - m = "".join(x for x in addr if x in string.printable) - self.log("sd_notify", m) + t = "".join(x for x in addr if x in string.printable) + self.log("sd_notify", t) sck = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) sck.connect(addr) diff --git a/copyparty/szip.py b/copyparty/szip.py index 177ff3ea..178f8a61 100644 --- a/copyparty/szip.py +++ b/copyparty/szip.py @@ -1,16 +1,23 @@ # coding: utf-8 from __future__ import print_function, unicode_literals +import calendar import time import zlib -import calendar -from .sutil import errdesc -from .util import yieldfile, sanitize_fn, spack, sunpack, min_ex from .bos import bos +from .sutil import StreamArc, errdesc +from .util import min_ex, sanitize_fn, spack, sunpack, yieldfile + +try: + from typing import Any, Generator, Optional + + from .util import NamedLogger +except: + pass -def dostime2unix(buf): +def dostime2unix(buf: bytes) -> int: t, d = sunpack(b" bytes: tt = time.gmtime(ts + 1) dy, dm, dd, th, tm, ts = list(tt)[:6] @@ -41,14 +48,22 @@ def unixtime2dos(ts): return b"\x00\x00\x21\x00" -def gen_fdesc(sz, crc32, z64): +def gen_fdesc(sz: int, crc32: int, z64: bool) -> bytes: ret = b"\x50\x4b\x07\x08" fmt = b" bytes: """ does regular file headers and the central directory meme if h_pos is set @@ -67,8 +82,8 @@ def gen_hdr(h_pos, fn, sz, lastmod, utf8, crc32, pre_crc): # confusingly this doesn't bump if h_pos req_ver = b"\x2d\x00" if z64 else b"\x0a\x00" - if crc32: - crc32 = spack(b" tuple[bytes, bool]: """ summary of all file headers, usually the zipfile footer unless something clamps @@ -154,10 +171,12 @@ def gen_ecdr(items, cdir_pos, cdir_end): # 2b comment length ret += b"\x00\x00" - return [ret, need_64] + return ret, need_64 -def gen_ecdr64(items, cdir_pos, cdir_end): +def gen_ecdr64( + items: list[tuple[str, int, int, int, int]], cdir_pos: int, cdir_end: int +) -> bytes: """ z64 end of central directory added when numfiles or a headerptr clamps @@ -181,7 +200,7 @@ def gen_ecdr64(items, cdir_pos, cdir_end): return ret -def gen_ecdr64_loc(ecdr64_pos): +def gen_ecdr64_loc(ecdr64_pos: int) -> bytes: """ z64 end of central directory locator points to ecdr64 @@ -196,21 +215,27 @@ def gen_ecdr64_loc(ecdr64_pos): return ret -class StreamZip(object): - def __init__(self, log, fgen, utf8=False, pre_crc=False): - self.log = log - self.fgen = fgen +class StreamZip(StreamArc): + def __init__( + self, + log: NamedLogger, + fgen: Generator[dict[str, Any], None, None], + utf8: bool = False, + pre_crc: bool = False, + ) -> None: + super(StreamZip, self).__init__(log, fgen) + self.utf8 = utf8 self.pre_crc = pre_crc self.pos = 0 - self.items = [] + self.items: list[tuple[str, int, int, int, int]] = [] - def _ct(self, buf): + def _ct(self, buf: bytes) -> bytes: self.pos += len(buf) return buf - def ser(self, f): + def ser(self, f: dict[str, Any]) -> Generator[bytes, None, None]: name = f["vp"] src = f["ap"] st = f["st"] @@ -218,9 +243,8 @@ class StreamZip(object): sz = st.st_size ts = st.st_mtime - crc = None + crc = 0 if self.pre_crc: - crc = 0 for buf in yieldfile(src): crc = zlib.crc32(buf, crc) @@ -230,7 +254,6 @@ class StreamZip(object): buf = gen_hdr(None, name, sz, ts, self.utf8, crc, self.pre_crc) yield self._ct(buf) - crc = crc or 0 for buf in yieldfile(src): if not self.pre_crc: crc = zlib.crc32(buf, crc) @@ -239,7 +262,7 @@ class StreamZip(object): crc &= 0xFFFFFFFF - self.items.append([name, sz, ts, crc, h_pos]) + self.items.append((name, sz, ts, crc, h_pos)) z64 = sz >= 4 * 1024 * 1024 * 1024 @@ -247,11 +270,11 @@ class StreamZip(object): buf = gen_fdesc(sz, crc, z64) yield self._ct(buf) - def gen(self): + def gen(self) -> Generator[bytes, None, None]: errors = [] for f in self.fgen: if "err" in f: - errors.append([f["vp"], f["err"]]) + errors.append((f["vp"], f["err"])) continue try: @@ -259,7 +282,7 @@ class StreamZip(object): yield x except: ex = min_ex(5, True).replace("\n", "\n-- ") - errors.append([f["vp"], ex]) + errors.append((f["vp"], ex)) if errors: errf, txt = errdesc(errors) diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py index 7d2cb3e4..ae4c5f86 100644 --- a/copyparty/tcpsrv.py +++ b/copyparty/tcpsrv.py @@ -2,12 +2,15 @@ from __future__ import print_function, unicode_literals import re -import sys import socket +import sys -from .__init__ import MACOS, ANYWIN, unicode +from .__init__ import ANYWIN, MACOS, TYPE_CHECKING, unicode from .util import chkcmd +if TYPE_CHECKING: + from .svchub import SvcHub + class TcpSrv(object): """ @@ -15,16 +18,16 @@ class TcpSrv(object): which then uses the least busy HttpSrv to handle it """ - def __init__(self, hub): + def __init__(self, hub: "SvcHub"): self.hub = hub self.args = hub.args self.log = hub.log self.stopping = False - self.srv = [] + self.srv: list[socket.socket] = [] self.nsrv = 0 - ok = {} + ok: dict[str, list[int]] = {} for ip in self.args.i: ok[ip] = [] for port in self.args.p: @@ -34,8 +37,8 @@ class TcpSrv(object): ok[ip].append(port) except Exception as ex: if self.args.ign_ebind or self.args.ign_ebind_all: - m = "could not listen on {}:{}: {}" - self.log("tcpsrv", m.format(ip, port, ex), c=3) + t = "could not listen on {}:{}: {}" + self.log("tcpsrv", t.format(ip, port, ex), c=3) else: raise @@ -55,9 +58,9 @@ class TcpSrv(object): eps[x] = "external" msgs = [] - title_tab = {} + title_tab: dict[str, dict[str, int]] = {} title_vars = [x[1:] for x in self.args.wintitle.split(" ") if x.startswith("$")] - m = "available @ {}://{}:{}/ (\033[33m{}\033[0m)" + t = "available @ {}://{}:{}/ (\033[33m{}\033[0m)" for ip, desc in sorted(eps.items(), key=lambda x: x[1]): for port in sorted(self.args.p): if port not in ok.get(ip, ok.get("0.0.0.0", [])): @@ -69,7 +72,7 @@ class TcpSrv(object): elif self.args.https_only or port == 443: proto = "https" - msgs.append(m.format(proto, ip, port, desc)) + msgs.append(t.format(proto, ip, port, desc)) if not self.args.wintitle: continue @@ -98,13 +101,13 @@ class TcpSrv(object): if msgs: msgs[-1] += "\n" - for m in msgs: - self.log("tcpsrv", m) + for t in msgs: + self.log("tcpsrv", t) if self.args.wintitle: self._set_wintitle(title_tab) - def _listen(self, ip, port): + def _listen(self, ip: str, port: int) -> None: srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -120,7 +123,7 @@ class TcpSrv(object): raise raise Exception(e) - def run(self): + def run(self) -> None: for srv in self.srv: srv.listen(self.args.nc) ip, port = srv.getsockname() @@ -130,9 +133,9 @@ class TcpSrv(object): if self.args.q: print(msg) - self.hub.broker.put(False, "listen", srv) + self.hub.broker.say("listen", srv) - def shutdown(self): + def shutdown(self) -> None: self.stopping = True try: for srv in self.srv: @@ -142,14 +145,14 @@ class TcpSrv(object): self.log("tcpsrv", "ok bye") - def ips_linux_ifconfig(self): + def ips_linux_ifconfig(self) -> dict[str, str]: # for termux try: txt, _ = chkcmd(["ifconfig"]) except: return {} - eps = {} + eps: dict[str, str] = {} dev = None ip = None up = None @@ -171,7 +174,7 @@ class TcpSrv(object): return eps - def ips_linux(self): + def ips_linux(self) -> dict[str, str]: try: txt, _ = chkcmd(["ip", "addr"]) except: @@ -180,21 +183,21 @@ class TcpSrv(object): r = re.compile(r"^\s+inet ([^ ]+)/.* (.*)") ri = re.compile(r"^\s*[0-9]+\s*:.*") up = False - eps = {} + eps: dict[str, str] = {} for ln in txt.split("\n"): if ri.match(ln): up = "UP" in re.split("[>,< ]", ln) try: - ip, dev = r.match(ln.rstrip()).groups() + ip, dev = r.match(ln.rstrip()).groups() # type: ignore eps[ip] = dev + ("" if up else ", \033[31mLINK-DOWN") except: pass return eps - def ips_macos(self): - eps = {} + def ips_macos(self) -> dict[str, str]: + eps: dict[str, str] = {} try: txt, _ = chkcmd(["ifconfig"]) except: @@ -202,7 +205,7 @@ class TcpSrv(object): rdev = re.compile(r"^([^ ]+):") rip = re.compile(r"^\tinet ([0-9\.]+) ") - dev = None + dev = "UNKNOWN" for ln in txt.split("\n"): m = rdev.match(ln) if m: @@ -211,17 +214,17 @@ class TcpSrv(object): m = rip.match(ln) if m: eps[m.group(1)] = dev - dev = None + dev = "UNKNOWN" return eps - def ips_windows_ipconfig(self): - eps = {} - offs = {} + def ips_windows_ipconfig(self) -> tuple[dict[str, str], set[str]]: + eps: dict[str, str] = {} + offs: set[str] = set() try: txt, _ = chkcmd(["ipconfig"]) except: - return eps + return eps, offs rdev = re.compile(r"(^[^ ].*):$") rip = re.compile(r"^ +IPv?4? [^:]+: *([0-9\.]{7,15})$") @@ -231,12 +234,12 @@ class TcpSrv(object): m = rdev.match(ln) if m: if dev and dev not in eps.values(): - offs[dev] = 1 + offs.add(dev) dev = m.group(1).split(" adapter ", 1)[-1] if dev and roff.match(ln): - offs[dev] = 1 + offs.add(dev) dev = None m = rip.match(ln) @@ -245,12 +248,12 @@ class TcpSrv(object): dev = None if dev and dev not in eps.values(): - offs[dev] = 1 + offs.add(dev) return eps, offs - def ips_windows_netsh(self): - eps = {} + def ips_windows_netsh(self) -> dict[str, str]: + eps: dict[str, str] = {} try: txt, _ = chkcmd("netsh interface ip show address".split()) except: @@ -270,7 +273,7 @@ class TcpSrv(object): return eps - def detect_interfaces(self, listen_ips): + def detect_interfaces(self, listen_ips: list[str]) -> dict[str, str]: if MACOS: eps = self.ips_macos() elif ANYWIN: @@ -317,7 +320,7 @@ class TcpSrv(object): return eps - def _set_wintitle(self, vs): + def _set_wintitle(self, vs: dict[str, dict[str, int]]) -> None: vs["all"] = vs.get("all", {"Local-Only": 1}) vs["pub"] = vs.get("pub", vs["all"]) diff --git a/copyparty/th_cli.py b/copyparty/th_cli.py index e8a18e0e..9eb49a4f 100644 --- a/copyparty/th_cli.py +++ b/copyparty/th_cli.py @@ -3,13 +3,23 @@ from __future__ import print_function, unicode_literals import os -from .util import Cooldown -from .th_srv import thumb_path, HAVE_WEBP +from .__init__ import TYPE_CHECKING +from .authsrv import VFS from .bos import bos +from .th_srv import HAVE_WEBP, thumb_path +from .util import Cooldown + +try: + from typing import Optional, Union +except: + pass + +if TYPE_CHECKING: + from .httpsrv import HttpSrv class ThumbCli(object): - def __init__(self, hsrv): + def __init__(self, hsrv: "HttpSrv") -> None: self.broker = hsrv.broker self.log_func = hsrv.log self.args = hsrv.args @@ -34,10 +44,10 @@ class ThumbCli(object): d = next((x for x in self.args.th_dec if x in ("vips", "pil")), None) self.can_webp = HAVE_WEBP or d == "vips" - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("thumbcli", msg, c) - def get(self, dbv, rem, mtime, fmt): + def get(self, dbv: VFS, rem: str, mtime: float, fmt: str) -> Optional[str]: ptop = dbv.realpath ext = rem.rsplit(".")[-1].lower() if ext not in self.thumbable or "dthumb" in dbv.flags: @@ -106,17 +116,17 @@ class ThumbCli(object): if ret: tdir = os.path.dirname(tpath) if self.cooldown.poke(tdir): - self.broker.put(False, "thumbsrv.poke", tdir) + self.broker.say("thumbsrv.poke", tdir) if want_opus: # audio files expire individually if self.cooldown.poke(tpath): - self.broker.put(False, "thumbsrv.poke", tpath) + self.broker.say("thumbsrv.poke", tpath) return ret if abort: return None - x = self.broker.put(True, "thumbsrv.get", ptop, rem, mtime, fmt) - return x.get() + x = self.broker.ask("thumbsrv.get", ptop, rem, mtime, fmt) + return x.get() # type: ignore diff --git a/copyparty/th_srv.py b/copyparty/th_srv.py index dc3965c7..2a9e5f02 100644 --- a/copyparty/th_srv.py +++ b/copyparty/th_srv.py @@ -1,18 +1,28 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import time -import shutil import base64 import hashlib -import threading +import os +import shutil import subprocess as sp +import threading +import time -from .util import fsenc, vsplit, statdir, runcmd, Queue, Cooldown, BytesIO, min_ex +from queue import Queue + +from .__init__ import TYPE_CHECKING from .bos import bos from .mtag import HAVE_FFMPEG, HAVE_FFPROBE, ffprobe +from .util import BytesIO, Cooldown, fsenc, min_ex, runcmd, statdir, vsplit +try: + from typing import Optional, Union +except: + pass + +if TYPE_CHECKING: + from .svchub import SvcHub HAVE_PIL = False HAVE_HEIF = False @@ -20,7 +30,7 @@ HAVE_AVIF = False HAVE_WEBP = False try: - from PIL import Image, ImageOps, ExifTags + from PIL import ExifTags, Image, ImageOps HAVE_PIL = True try: @@ -47,14 +57,13 @@ except: pass try: - import pyvips - HAVE_VIPS = True + import pyvips except: HAVE_VIPS = False -def thumb_path(histpath, rem, mtime, fmt): +def thumb_path(histpath: str, rem: str, mtime: float, fmt: str) -> str: # base16 = 16 = 256 # b64-lc = 38 = 1444 # base64 = 64 = 4096 @@ -80,7 +89,7 @@ def thumb_path(histpath, rem, mtime, fmt): class ThumbSrv(object): - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub self.asrv = hub.asrv self.args = hub.args @@ -91,17 +100,17 @@ class ThumbSrv(object): self.poke_cd = Cooldown(self.args.th_poke) self.mutex = threading.Lock() - self.busy = {} + self.busy: dict[str, list[threading.Condition]] = {} self.stopping = False self.nthr = max(1, self.args.th_mt) - self.q = Queue(self.nthr * 4) + self.q: Queue[Optional[tuple[str, str]]] = Queue(self.nthr * 4) for n in range(self.nthr): - t = threading.Thread( + thr = threading.Thread( target=self.worker, name="thumb-{}-{}".format(n, self.nthr) ) - t.daemon = True - t.start() + thr.daemon = True + thr.start() want_ff = not self.args.no_vthumb or not self.args.no_athumb if want_ff and (not HAVE_FFMPEG or not HAVE_FFPROBE): @@ -122,7 +131,7 @@ class ThumbSrv(object): t.start() self.fmt_pil, self.fmt_vips, self.fmt_ffi, self.fmt_ffv, self.fmt_ffa = [ - {x: True for x in y.split(",")} + set(y.split(",")) for y in [ self.args.th_r_pil, self.args.th_r_vips, @@ -134,37 +143,37 @@ class ThumbSrv(object): if not HAVE_HEIF: for f in "heif heifs heic heics".split(" "): - self.fmt_pil.pop(f, None) + self.fmt_pil.discard(f) if not HAVE_AVIF: for f in "avif avifs".split(" "): - self.fmt_pil.pop(f, None) + self.fmt_pil.discard(f) - self.thumbable = {} + self.thumbable: set[str] = set() if "pil" in self.args.th_dec: - self.thumbable.update(self.fmt_pil) + self.thumbable |= self.fmt_pil if "vips" in self.args.th_dec: - self.thumbable.update(self.fmt_vips) + self.thumbable |= self.fmt_vips if "ff" in self.args.th_dec: - for t in [self.fmt_ffi, self.fmt_ffv, self.fmt_ffa]: - self.thumbable.update(t) + for zss in [self.fmt_ffi, self.fmt_ffv, self.fmt_ffa]: + self.thumbable |= zss - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("thumb", msg, c) - def shutdown(self): + def shutdown(self) -> None: self.stopping = True for _ in range(self.nthr): self.q.put(None) - def stopped(self): + def stopped(self) -> bool: with self.mutex: return not self.nthr - def get(self, ptop, rem, mtime, fmt): + def get(self, ptop: str, rem: str, mtime: float, fmt: str) -> Optional[str]: histpath = self.asrv.vfs.histtab.get(ptop) if not histpath: self.log("no histpath for [{}]".format(ptop)) @@ -191,7 +200,7 @@ class ThumbSrv(object): do_conv = True if do_conv: - self.q.put([abspath, tpath]) + self.q.put((abspath, tpath)) self.log("conv {} \033[0m{}".format(tpath, abspath), c=6) while not self.stopping: @@ -212,7 +221,7 @@ class ThumbSrv(object): return None - def getcfg(self): + def getcfg(self) -> dict[str, set[str]]: return { "thumbable": self.thumbable, "pil": self.fmt_pil, @@ -222,7 +231,7 @@ class ThumbSrv(object): "ffa": self.fmt_ffa, } - def worker(self): + def worker(self) -> None: while not self.stopping: task = self.q.get() if not task: @@ -253,7 +262,7 @@ class ThumbSrv(object): except: msg = "{} could not create thumbnail of {}\n{}" msg = msg.format(fun.__name__, abspath, min_ex()) - c = 1 if " "Image.Image": # exif_transpose is expensive (loads full image + unconditional copy) r = max(*self.res) * 2 im.thumbnail((r, r), resample=Image.LANCZOS) @@ -295,7 +304,7 @@ class ThumbSrv(object): return im - def conv_pil(self, abspath, tpath): + def conv_pil(self, abspath: str, tpath: str) -> None: with Image.open(fsenc(abspath)) as im: try: im = self.fancy_pillow(im) @@ -324,7 +333,7 @@ class ThumbSrv(object): im.save(tpath, **args) - def conv_vips(self, abspath, tpath): + def conv_vips(self, abspath: str, tpath: str) -> None: crops = ["centre", "none"] if self.args.th_no_crop: crops = ["none"] @@ -342,18 +351,17 @@ class ThumbSrv(object): img.write_to_file(tpath, Q=40) - def conv_ffmpeg(self, abspath, tpath): + def conv_ffmpeg(self, abspath: str, tpath: str) -> None: ret, _ = ffprobe(abspath) if not ret: return ext = abspath.rsplit(".")[-1].lower() if ext in ["h264", "h265"] or ext in self.fmt_ffi: - seek = [] + seek: list[bytes] = [] else: dur = ret[".dur"][1] if ".dur" in ret else 4 - seek = "{:.0f}".format(dur / 3) - seek = [b"-ss", seek.encode("utf-8")] + seek = [b"-ss", "{:.0f}".format(dur / 3).encode("utf-8")] scale = "scale={0}:{1}:force_original_aspect_ratio=" if self.args.th_no_crop: @@ -361,7 +369,7 @@ class ThumbSrv(object): else: scale += "increase,crop={0}:{1},setsar=1:1" - scale = scale.format(*list(self.res)).encode("utf-8") + bscale = scale.format(*list(self.res)).encode("utf-8") # fmt: off cmd = [ b"ffmpeg", @@ -373,7 +381,7 @@ class ThumbSrv(object): cmd += [ b"-i", fsenc(abspath), b"-map", b"0:v:0", - b"-vf", scale, + b"-vf", bscale, b"-frames:v", b"1", b"-metadata:s:v:0", b"rotate=0", ] @@ -395,14 +403,14 @@ class ThumbSrv(object): cmd += [fsenc(tpath)] self._run_ff(cmd) - def _run_ff(self, cmd): + def _run_ff(self, cmd: list[bytes]) -> None: # self.log((b" ".join(cmd)).decode("utf-8")) ret, _, serr = runcmd(cmd, timeout=self.args.th_convt) if not ret: return - c = "1;30" - m = "FFmpeg failed (probably a corrupt video file):\n" + c: Union[str, int] = "1;30" + t = "FFmpeg failed (probably a corrupt video file):\n" if cmd[-1].lower().endswith(b".webp") and ( "Error selecting an encoder" in serr or "Automatic encoder selection failed" in serr @@ -410,14 +418,14 @@ class ThumbSrv(object): or "Please choose an encoder manually" in serr ): self.args.th_ff_jpg = True - m = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n" + t = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n" c = 1 if ( "Requested resampling engine is unavailable" in serr or "output pad on Parsed_aresample_" in serr ): - m = "FFmpeg failed because it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n" + t = "FFmpeg failed because it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n" c = 1 lines = serr.strip("\n").split("\n") @@ -428,10 +436,10 @@ class ThumbSrv(object): if len(txt) > 5000: txt = txt[:2500] + "...\nff: [...]\nff: ..." + txt[-2500:] - self.log(m + txt, c=c) + self.log(t + txt, c=c) raise sp.CalledProcessError(ret, (cmd[0], b"...", cmd[-1])) - def conv_spec(self, abspath, tpath): + def conv_spec(self, abspath: str, tpath: str) -> None: ret, _ = ffprobe(abspath) if "ac" not in ret: raise Exception("not audio") @@ -473,7 +481,7 @@ class ThumbSrv(object): cmd += [fsenc(tpath)] self._run_ff(cmd) - def conv_opus(self, abspath, tpath): + def conv_opus(self, abspath: str, tpath: str) -> None: if self.args.no_acode: raise Exception("disabled in server config") @@ -521,7 +529,7 @@ class ThumbSrv(object): # fmt: on self._run_ff(cmd) - def poke(self, tdir): + def poke(self, tdir: str) -> None: if not self.poke_cd.poke(tdir): return @@ -533,7 +541,7 @@ class ThumbSrv(object): except: pass - def cleaner(self): + def cleaner(self) -> None: interval = self.args.th_clean while True: time.sleep(interval) @@ -548,14 +556,14 @@ class ThumbSrv(object): self.log("\033[Jcln ok; rm {} dirs".format(ndirs)) - def clean(self, histpath): + def clean(self, histpath: str) -> int: ret = 0 for cat in ["th", "ac"]: - ret += self._clean(histpath, cat, None) + ret += self._clean(histpath, cat, "") return ret - def _clean(self, histpath, cat, thumbpath): + def _clean(self, histpath: str, cat: str, thumbpath: str) -> int: if not thumbpath: thumbpath = os.path.join(histpath, cat) @@ -564,10 +572,10 @@ class ThumbSrv(object): maxage = getattr(self.args, cat + "_maxage") now = time.time() prev_b64 = None - prev_fp = None + prev_fp = "" try: - ents = statdir(self.log, not self.args.no_scandir, False, thumbpath) - ents = sorted(list(ents)) + t1 = statdir(self.log_func, not self.args.no_scandir, False, thumbpath) + ents = sorted(list(t1)) except: return 0 diff --git a/copyparty/u2idx.py b/copyparty/u2idx.py index 073ed3d1..8a342042 100644 --- a/copyparty/u2idx.py +++ b/copyparty/u2idx.py @@ -1,34 +1,37 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import time import calendar +import os +import re import threading +import time from operator import itemgetter -from .__init__ import ANYWIN, unicode -from .util import absreal, s3dec, Pebkac, min_ex, gen_filekey, quotep +from .__init__ import ANYWIN, TYPE_CHECKING, unicode from .bos import bos from .up2k import up2k_wark_from_hashlist +from .util import HAVE_SQLITE3, Pebkac, absreal, gen_filekey, min_ex, quotep, s3dec - -try: - HAVE_SQLITE3 = True +if HAVE_SQLITE3: import sqlite3 -except: - HAVE_SQLITE3 = False - try: from pathlib import Path except: pass +try: + from typing import Any, Optional, Union +except: + pass + +if TYPE_CHECKING: + from .httpconn import HttpConn + class U2idx(object): - def __init__(self, conn): + def __init__(self, conn: "HttpConn") -> None: self.log_func = conn.log_func self.asrv = conn.asrv self.args = conn.args @@ -38,19 +41,21 @@ class U2idx(object): self.log("your python does not have sqlite3; searching will be disabled") return - self.active_id = None - self.active_cur = None - self.cur = {} - self.mem_cur = sqlite3.connect(":memory:") + self.active_id = "" + self.active_cur: Optional["sqlite3.Cursor"] = None + self.cur: dict[str, "sqlite3.Cursor"] = {} + self.mem_cur = sqlite3.connect(":memory:").cursor() self.mem_cur.execute(r"create table a (b text)") - self.p_end = None - self.p_dur = 0 + self.p_end = 0.0 + self.p_dur = 0.0 - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("u2idx", msg, c) - def fsearch(self, vols, body): + def fsearch( + self, vols: list[tuple[str, str, dict[str, Any]]], body: dict[str, Any] + ) -> list[dict[str, Any]]: """search by up2k hashlist""" if not HAVE_SQLITE3: return [] @@ -60,14 +65,14 @@ class U2idx(object): wark = up2k_wark_from_hashlist(self.args.salt, fsize, fhash) uq = "substr(w,1,16) = ? and w = ?" - uv = [wark[:16], wark] + uv: list[Union[str, int]] = [wark[:16], wark] try: return self.run_query(vols, uq, uv, True, False, 99999)[0] except: raise Pebkac(500, min_ex()) - def get_cur(self, ptop): + def get_cur(self, ptop: str) -> Optional["sqlite3.Cursor"]: if not HAVE_SQLITE3: return None @@ -103,13 +108,16 @@ class U2idx(object): self.cur[ptop] = cur return cur - def search(self, vols, uq, lim): + def search( + self, vols: list[tuple[str, str, dict[str, Any]]], uq: str, lim: int + ) -> tuple[list[dict[str, Any]], list[str]]: """search by query params""" if not HAVE_SQLITE3: - return [] + return [], [] q = "" - va = [] + v: Union[str, int] = "" + va: list[Union[str, int]] = [] have_up = False # query has up.* operands have_mt = False is_key = True @@ -202,7 +210,7 @@ class U2idx(object): "%Y", ]: try: - v = calendar.timegm(time.strptime(v, fmt)) + v = calendar.timegm(time.strptime(str(v), fmt)) break except: pass @@ -230,11 +238,12 @@ class U2idx(object): # lowercase tag searches m = ptn_lc.search(q) - if not m or not ptn_lcv.search(unicode(v)): + zs = unicode(v) + if not m or not ptn_lcv.search(zs): continue va.pop() - va.append(v.lower()) + va.append(zs.lower()) q = q[: m.start()] field, oper = m.groups() @@ -248,8 +257,16 @@ class U2idx(object): except Exception as ex: raise Pebkac(500, repr(ex)) - def run_query(self, vols, uq, uv, have_up, have_mt, lim): - done_flag = [] + def run_query( + self, + vols: list[tuple[str, str, dict[str, Any]]], + uq: str, + uv: list[Union[str, int]], + have_up: bool, + have_mt: bool, + lim: int, + ) -> tuple[list[dict[str, Any]], list[str]]: + done_flag: list[bool] = [] self.active_id = "{:.6f}_{}".format( time.time(), threading.current_thread().ident ) @@ -266,13 +283,11 @@ class U2idx(object): if not uq or not uv: uq = "select * from up" - uv = () + uv = [] elif have_mt: uq = "select up.*, substr(up.w,1,16) mtw from up where " + uq - uv = tuple(uv) else: uq = "select up.* from up where " + uq - uv = tuple(uv) self.log("qs: {!r} {!r}".format(uq, uv)) @@ -292,11 +307,10 @@ class U2idx(object): v = vtop + "/" vuv.append(v) - vuv = tuple(vuv) sret = [] fk = flags.get("fk") - c = cur.execute(uq, vuv) + c = cur.execute(uq, tuple(vuv)) for hit in c: w, ts, sz, rd, fn, ip, at = hit[:7] lim -= 1 @@ -340,7 +354,7 @@ class U2idx(object): # print("[{}] {}".format(ptop, sret)) done_flag.append(True) - self.active_id = None + self.active_id = "" # undupe hits from multiple metadata keys if len(ret) > 1: @@ -354,11 +368,12 @@ class U2idx(object): return ret, list(taglist.keys()) - def terminator(self, identifier, done_flag): + def terminator(self, identifier: str, done_flag: list[bool]) -> None: for _ in range(self.timeout): time.sleep(1) if done_flag: return if identifier == self.active_id: + assert self.active_cur self.active_cur.connection.interrupt() diff --git a/copyparty/up2k.py b/copyparty/up2k.py index 3f4e02e4..aff68917 100644 --- a/copyparty/up2k.py +++ b/copyparty/up2k.py @@ -1,60 +1,83 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import time -import math -import json -import gzip -import stat -import shutil import base64 +import gzip import hashlib -import threading -import traceback +import json +import math +import os +import re +import shutil +import stat import subprocess as sp +import threading +import time +import traceback from copy import deepcopy -from .__init__ import WINDOWS, ANYWIN, PY2 -from .util import ( - Pebkac, - Queue, - ProgressPrinter, - SYMTIME, - fsenc, - absreal, - sanitize_fn, - ren_open, - atomic_move, - quotep, - vsplit, - w8b64enc, - w8b64dec, - s3enc, - s3dec, - rmdirs, - statdir, - s2hms, - min_ex, -) -from .bos import bos -from .authsrv import AuthSrv, LEELOO_DALLAS -from .mtag import MTag, MParser +from queue import Queue -try: - HAVE_SQLITE3 = True +from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS +from .authsrv import LEELOO_DALLAS, VFS, AuthSrv +from .bos import bos +from .mtag import MParser, MTag +from .util import ( + HAVE_SQLITE3, + SYMTIME, + Pebkac, + ProgressPrinter, + absreal, + atomic_move, + fsenc, + min_ex, + quotep, + ren_open, + rmdirs, + s2hms, + s3dec, + s3enc, + sanitize_fn, + statdir, + vsplit, + w8b64dec, + w8b64enc, +) + +if HAVE_SQLITE3: import sqlite3 -except: - HAVE_SQLITE3 = False DB_VER = 5 +try: + from typing import Any, Optional, Pattern, Union +except: + pass + +if TYPE_CHECKING: + from .svchub import SvcHub + + +class Dbw(object): + def __init__(self, c: "sqlite3.Cursor", n: int, t: float) -> None: + self.c = c + self.n = n + self.t = t + + +class Mpqe(object): + def __init__(self, mtp: dict[str, MParser], entags: set[str], w: str, abspath: str): + # mtp empty = mtag + self.mtp = mtp + self.entags = entags + self.w = w + self.abspath = abspath + class Up2k(object): - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub - self.asrv = hub.asrv # type: AuthSrv + self.asrv: AuthSrv = hub.asrv self.args = hub.args self.log_func = hub.log @@ -62,25 +85,32 @@ class Up2k(object): self.salt = self.args.salt # state + self.gid = 0 self.mutex = threading.Lock() + self.pp: Optional[ProgressPrinter] = None self.rescan_cond = threading.Condition() - self.hashq = Queue() - self.tagq = Queue() + self.need_rescan: set[str] = set() + + self.registry: dict[str, dict[str, dict[str, Any]]] = {} + self.flags: dict[str, dict[str, Any]] = {} + self.droppable: dict[str, list[str]] = {} + self.volstate: dict[str, str] = {} + self.dupesched: dict[str, list[tuple[str, str, float]]] = {} + self.snap_persist_interval = 300 # persist unfinished index every 5 min + self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity + self.snap_prev: dict[str, Optional[tuple[int, float]]] = {} + + self.mtag: Optional[MTag] = None + self.entags: dict[str, set[str]] = {} + self.mtp_parsers: dict[str, dict[str, MParser]] = {} + self.pending_tags: list[tuple[set[str], str, str, dict[str, Any]]] = [] + self.hashq: Queue[tuple[str, str, str, str, float]] = Queue() + self.tagq: Queue[tuple[str, str, str, str]] = Queue() self.n_hashq = 0 self.n_tagq = 0 - self.gid = 0 - self.volstate = {} - self.need_rescan = {} - self.dupesched = {} - self.registry = {} - self.droppable = {} - self.entags = {} - self.flags = {} - self.cur = {} - self.mtag = None - self.pending_tags = None - self.mtp_parsers = {} + self.mpool_used = False + self.cur: dict[str, "sqlite3.Cursor"] = {} self.mem_cur = None self.sqlite_ver = None self.no_expr_idx = False @@ -94,7 +124,7 @@ class Up2k(object): if ANYWIN: # usually fails to set lastmod too quickly - self.lastmod_q = [] + self.lastmod_q: list[tuple[str, int, tuple[int, int]]] = [] thr = threading.Thread(target=self._lastmodder, name="up2k-lastmod") thr.daemon = True thr.start() @@ -108,7 +138,7 @@ class Up2k(object): if self.args.no_fastboot: self.deferred_init() - def init_vols(self): + def init_vols(self) -> None: if self.args.no_fastboot: return @@ -116,15 +146,15 @@ class Up2k(object): t.daemon = True t.start() - def reload(self): + def reload(self) -> None: self.gid += 1 self.log("reload #{} initiated".format(self.gid)) all_vols = self.asrv.vfs.all_vols self.rescan(all_vols, list(all_vols.keys()), True) - def deferred_init(self): + def deferred_init(self) -> None: all_vols = self.asrv.vfs.all_vols - have_e2d = self.init_indexes(all_vols) + have_e2d = self.init_indexes(all_vols, []) thr = threading.Thread(target=self._snapshot, name="up2k-snapshot") thr.daemon = True @@ -150,11 +180,11 @@ class Up2k(object): thr.daemon = True thr.start() - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("up2k", msg + "\033[K", c) - def get_state(self): - mtpq = 0 + def get_state(self) -> str: + mtpq: Union[int, str] = 0 q = "select count(w) from mt where k = 't:mtp'" got_lock = False if PY2 else self.mutex.acquire(timeout=0.5) if got_lock: @@ -165,19 +195,19 @@ class Up2k(object): pass self.mutex.release() else: - mtpq = "?" + mtpq = "(?)" ret = { "volstate": self.volstate, - "scanning": hasattr(self, "pp"), + "scanning": bool(self.pp), "hashq": self.n_hashq, "tagq": self.n_tagq, "mtpq": mtpq, } return json.dumps(ret, indent=4) - def rescan(self, all_vols, scan_vols, wait): - if not wait and hasattr(self, "pp"): + def rescan(self, all_vols: dict[str, VFS], scan_vols: list[str], wait: bool) -> str: + if not wait and self.pp: return "cannot initiate; scan is already in progress" args = (all_vols, scan_vols) @@ -188,11 +218,11 @@ class Up2k(object): ) t.daemon = True t.start() - return None + return "" - def _sched_rescan(self): + def _sched_rescan(self) -> None: volage = {} - cooldown = 0 + cooldown = 0.0 timeout = time.time() + 3 while True: timeout = max(timeout, cooldown) @@ -204,7 +234,7 @@ class Up2k(object): if now < cooldown: continue - if hasattr(self, "pp"): + if self.pp: cooldown = now + 5 continue @@ -220,19 +250,19 @@ class Up2k(object): deadline = volage[vp] + maxage if deadline <= now: - self.need_rescan[vp] = 1 + self.need_rescan.add(vp) timeout = min(timeout, deadline) - vols = list(sorted(self.need_rescan.keys())) - self.need_rescan = {} + vols = list(sorted(self.need_rescan)) + self.need_rescan.clear() if vols: cooldown = now + 10 err = self.rescan(self.asrv.vfs.all_vols, vols, False) if err: for v in vols: - self.need_rescan[v] = True + self.need_rescan.add(v) continue @@ -272,7 +302,7 @@ class Up2k(object): if vp: fvp = "{}/{}".format(vp, fvp) - self._handle_rm(LEELOO_DALLAS, None, fvp) + self._handle_rm(LEELOO_DALLAS, "", fvp) nrm += 1 if nrm: @@ -288,12 +318,12 @@ class Up2k(object): if hits: timeout = min(timeout, now + lifetime - (now - hits[0])) - def _vis_job_progress(self, job): + def _vis_job_progress(self, job: dict[str, Any]) -> str: perc = 100 - (len(job["need"]) * 100.0 / len(job["hash"])) path = os.path.join(job["ptop"], job["prel"], job["name"]) return "{:5.1f}% {}".format(perc, path) - def _vis_reg_progress(self, reg): + def _vis_reg_progress(self, reg: dict[str, dict[str, Any]]) -> list[str]: ret = [] for _, job in reg.items(): if job["need"]: @@ -301,7 +331,7 @@ class Up2k(object): return ret - def _expr_idx_filter(self, flags): + def _expr_idx_filter(self, flags: dict[str, Any]) -> tuple[bool, dict[str, Any]]: if not self.no_expr_idx: return False, flags @@ -311,19 +341,19 @@ class Up2k(object): return True, ret - def init_indexes(self, all_vols, scan_vols=None): + def init_indexes(self, all_vols: dict[str, VFS], scan_vols: list[str]) -> bool: gid = self.gid - while hasattr(self, "pp") and gid == self.gid: + while self.pp and gid == self.gid: time.sleep(0.1) if gid != self.gid: - return + return False if gid: self.log("reload #{} running".format(self.gid)) self.pp = ProgressPrinter() - vols = all_vols.values() + vols = list(all_vols.values()) t0 = time.time() have_e2d = False @@ -377,9 +407,9 @@ class Up2k(object): # e2ds(a) volumes first for vol in vols: - en = {} + en: set[str] = set() if "mte" in vol.flags: - en = {k: True for k in vol.flags["mte"].split(",")} + en = set(vol.flags["mte"].split(",")) self.entags[vol.realpath] = en @@ -393,11 +423,11 @@ class Up2k(object): need_vac[vol] = True if "e2ts" not in vol.flags: - m = "online, idle" + t = "online, idle" else: - m = "online (tags pending)" + t = "online (tags pending)" - self.volstate[vol.vpath] = m + self.volstate[vol.vpath] = t # open the rest + do any e2ts(a) needed_mutagen = False @@ -405,9 +435,9 @@ class Up2k(object): if "e2ts" not in vol.flags: continue - m = "online (reading tags)" - self.volstate[vol.vpath] = m - self.log("{} [{}]".format(m, vol.realpath)) + t = "online (reading tags)" + self.volstate[vol.vpath] = t + self.log("{} [{}]".format(t, vol.realpath)) nadd, nrm, success = self._build_tags_index(vol) if not success: @@ -419,7 +449,9 @@ class Up2k(object): self.volstate[vol.vpath] = "online (mtp soon)" for vol in need_vac: - cur, _ = self.register_vpath(vol.realpath, vol.flags) + reg = self.register_vpath(vol.realpath, vol.flags) + assert reg + cur, _ = reg with self.mutex: cur.connection.commit() cur.execute("vacuum") @@ -435,23 +467,25 @@ class Up2k(object): thr = None if self.mtag: - m = "online (running mtp)" + t = "online (running mtp)" if scan_vols: thr = threading.Thread(target=self._run_all_mtp, name="up2k-mtp-scan") thr.daemon = True else: - del self.pp - m = "online, idle" + self.pp = None + t = "online, idle" for vol in vols: - self.volstate[vol.vpath] = m + self.volstate[vol.vpath] = t if thr: thr.start() return have_e2d - def register_vpath(self, ptop, flags): + def register_vpath( + self, ptop: str, flags: dict[str, Any] + ) -> Optional[tuple["sqlite3.Cursor", str]]: histpath = self.asrv.vfs.histtab.get(ptop) if not histpath: self.log("no histpath for [{}]".format(ptop)) @@ -460,7 +494,7 @@ class Up2k(object): db_path = os.path.join(histpath, "up2k.db") if ptop in self.registry: try: - return [self.cur[ptop], db_path] + return self.cur[ptop], db_path except: return None @@ -514,14 +548,14 @@ class Up2k(object): else: drp = [x for x in drp if x in reg] - m = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or [])) - m = [m] + self._vis_reg_progress(reg) - self.log("\n".join(m)) + t = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or [])) + ta = [t] + self._vis_reg_progress(reg) + self.log("\n".join(ta)) self.flags[ptop] = flags self.registry[ptop] = reg self.droppable[ptop] = drp or [] - self.regdrop(ptop, None) + self.regdrop(ptop, "") if not HAVE_SQLITE3 or "e2d" not in flags or "d2d" in flags: return None @@ -530,23 +564,25 @@ class Up2k(object): try: cur = self._open_db(db_path) self.cur[ptop] = cur - return [cur, db_path] + return cur, db_path except: msg = "cannot use database at [{}]:\n{}" self.log(msg.format(ptop, traceback.format_exc())) return None - def _build_file_index(self, vol, all_vols): + def _build_file_index(self, vol: VFS, all_vols: list[VFS]) -> tuple[bool, bool]: do_vac = False top = vol.realpath rei = vol.flags.get("noidx") reh = vol.flags.get("nohash") with self.mutex: - cur, _ = self.register_vpath(top, vol.flags) + reg = self.register_vpath(top, vol.flags) + assert reg and self.pp + cur, _ = reg - dbw = [cur, 0, time.time()] - self.pp.n = next(dbw[0].execute("select count(w) from up"))[0] + db = Dbw(cur, 0, time.time()) + self.pp.n = next(db.c.execute("select count(w) from up"))[0] excl = [ vol.realpath + "/" + d.vpath[len(vol.vpath) :].lstrip("/") @@ -558,37 +594,47 @@ class Up2k(object): if WINDOWS: excl = [x.replace("/", "\\") for x in excl] - excl = set(excl) rtop = absreal(top) n_add = n_rm = 0 try: - n_add = self._build_dir(dbw, top, excl, top, rtop, rei, reh, []) - n_rm = self._drop_lost(dbw[0], top) + n_add = self._build_dir(db, top, set(excl), top, rtop, rei, reh, []) + n_rm = self._drop_lost(db.c, top) except: - m = "failed to index volume [{}]:\n{}" - self.log(m.format(top, min_ex()), c=1) + t = "failed to index volume [{}]:\n{}" + self.log(t.format(top, min_ex()), c=1) - if dbw[1]: - self.log("commit {} new files".format(dbw[1])) + if db.n: + self.log("commit {} new files".format(db.n)) - dbw[0].connection.commit() + db.c.connection.commit() - return True, n_add or n_rm or do_vac + return True, bool(n_add or n_rm or do_vac) - def _build_dir(self, dbw, top, excl, cdir, rcdir, rei, reh, seen): + def _build_dir( + self, + db: Dbw, + top: str, + excl: set[str], + cdir: str, + rcdir: str, + rei: Optional[Pattern[str]], + reh: Optional[Pattern[str]], + seen: list[str], + ) -> int: if rcdir in seen: - m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}" - self.log(m.format(seen[-1], rcdir, cdir), 3) + t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}" + self.log(t.format(seen[-1], rcdir, cdir), 3) return 0 seen = seen + [rcdir] + assert self.pp and self.mem_cur self.pp.msg = "a{} {}".format(self.pp.n, cdir) ret = 0 seen_files = {} # != inames; files-only for dropcheck g = statdir(self.log_func, not self.args.no_scandir, False, cdir) - g = sorted(g) - inames = {x[0]: 1 for x in g} - for iname, inf in g: + gl = sorted(g) + inames = {x[0]: 1 for x in gl} + for iname, inf in gl: abspath = os.path.join(cdir, iname) if rei and rei.search(abspath): continue @@ -605,10 +651,10 @@ class Up2k(object): continue # self.log(" dir: {}".format(abspath)) try: - ret += self._build_dir(dbw, top, excl, abspath, rap, rei, reh, seen) + ret += self._build_dir(db, top, excl, abspath, rap, rei, reh, seen) except: - m = "failed to index subdir [{}]:\n{}" - self.log(m.format(abspath, min_ex()), c=1) + t = "failed to index subdir [{}]:\n{}" + self.log(t.format(abspath, min_ex()), c=1) elif not stat.S_ISREG(inf.st_mode): self.log("skip type-{:x} file [{}]".format(inf.st_mode, abspath)) else: @@ -632,31 +678,31 @@ class Up2k(object): rd, fn = rp.rsplit("/", 1) if "/" in rp else ["", rp] sql = "select w, mt, sz from up where rd = ? and fn = ?" try: - c = dbw[0].execute(sql, (rd, fn)) + c = db.c.execute(sql, (rd, fn)) except: - c = dbw[0].execute(sql, s3enc(self.mem_cur, rd, fn)) + c = db.c.execute(sql, s3enc(self.mem_cur, rd, fn)) in_db = list(c.fetchall()) if in_db: self.pp.n -= 1 dw, dts, dsz = in_db[0] if len(in_db) > 1: - m = "WARN: multiple entries: [{}] => [{}] |{}|\n{}" + t = "WARN: multiple entries: [{}] => [{}] |{}|\n{}" rep_db = "\n".join([repr(x) for x in in_db]) - self.log(m.format(top, rp, len(in_db), rep_db)) + self.log(t.format(top, rp, len(in_db), rep_db)) dts = -1 if dts == lmod and dsz == sz and (nohash or dw[0] != "#"): continue - m = "reindex [{}] => [{}] ({}/{}) ({}/{})".format( + t = "reindex [{}] => [{}] ({}/{}) ({}/{})".format( top, rp, dts, lmod, dsz, sz ) - self.log(m) - self.db_rm(dbw[0], rd, fn) + self.log(t) + self.db_rm(db.c, rd, fn) ret += 1 - dbw[1] += 1 - in_db = None + db.n += 1 + in_db = [] self.pp.msg = "a{} {}".format(self.pp.n, abspath) @@ -674,15 +720,15 @@ class Up2k(object): wark = up2k_wark_from_hashlist(self.salt, sz, hashes) - self.db_add(dbw[0], wark, rd, fn, lmod, sz, "", 0) - dbw[1] += 1 + self.db_add(db.c, wark, rd, fn, lmod, sz, "", 0) + db.n += 1 ret += 1 - td = time.time() - dbw[2] - if dbw[1] >= 4096 or td >= 60: - self.log("commit {} new files".format(dbw[1])) - dbw[0].connection.commit() - dbw[1] = 0 - dbw[2] = time.time() + td = time.time() - db.t + if db.n >= 4096 or td >= 60: + self.log("commit {} new files".format(db.n)) + db.c.connection.commit() + db.n = 0 + db.t = time.time() # drop missing files rd = cdir[len(top) + 1 :].strip("/") @@ -691,25 +737,26 @@ class Up2k(object): q = "select fn from up where rd = ?" try: - c = dbw[0].execute(q, (rd,)) + c = db.c.execute(q, (rd,)) except: - c = dbw[0].execute(q, ("//" + w8b64enc(rd),)) + c = db.c.execute(q, ("//" + w8b64enc(rd),)) hits = [w8b64dec(x[2:]) if x.startswith("//") else x for (x,) in c] rm_files = [x for x in hits if x not in seen_files] n_rm = len(rm_files) for fn in rm_files: - self.db_rm(dbw[0], rd, fn) + self.db_rm(db.c, rd, fn) if n_rm: self.log("forgot {} deleted files".format(n_rm)) return ret - def _drop_lost(self, cur, top): + def _drop_lost(self, cur: "sqlite3.Cursor", top: str) -> int: rm = [] n_rm = 0 nchecked = 0 + assert self.pp # `_build_dir` did all the files, now do dirs ndirs = next(cur.execute("select count(distinct rd) from up"))[0] c = cur.execute("select distinct rd from up order by rd desc") @@ -743,13 +790,16 @@ class Up2k(object): return n_rm - def _build_tags_index(self, vol): + def _build_tags_index(self, vol: VFS) -> tuple[int, int, bool]: ptop = vol.realpath with self.mutex: - _, db_path = self.register_vpath(ptop, vol.flags) - entags = self.entags[ptop] - flags = self.flags[ptop] - cur = self.cur[ptop] + reg = self.register_vpath(ptop, vol.flags) + + assert reg and self.pp and self.mtag + _, db_path = reg + entags = self.entags[ptop] + flags = self.flags[ptop] + cur = self.cur[ptop] n_add = 0 n_rm = 0 @@ -794,7 +844,8 @@ class Up2k(object): if not self.mtag: return n_add, n_rm, False - mpool = False + mpool: Optional[Queue[Mpqe]] = None + if self.mtag.prefer_mt and self.args.mtag_mt > 1: mpool = self._start_mpool() @@ -819,11 +870,10 @@ class Up2k(object): abspath = os.path.join(ptop, rd, fn) self.pp.msg = "c{} {}".format(n_left, abspath) - args = [entags, w, abspath] if not mpool: - n_tags = self._tag_file(c3, *args) + n_tags = self._tag_file(c3, entags, w, abspath) else: - mpool.put(["mtag"] + args) + mpool.put(Mpqe({}, entags, w, abspath)) # not registry cursor; do not self.mutex: n_tags = len(self._flush_mpool(c3)) @@ -850,7 +900,7 @@ class Up2k(object): return n_add, n_rm, True - def _flush_mpool(self, wcur): + def _flush_mpool(self, wcur: "sqlite3.Cursor") -> list[str]: ret = [] for x in self.pending_tags: self._tag_file(wcur, *x) @@ -859,7 +909,7 @@ class Up2k(object): self.pending_tags = [] return ret - def _run_all_mtp(self): + def _run_all_mtp(self) -> None: gid = self.gid t0 = time.time() for ptop, flags in self.flags.items(): @@ -870,12 +920,12 @@ class Up2k(object): msg = "mtp finished in {:.2f} sec ({})" self.log(msg.format(td, s2hms(td, True))) - del self.pp + self.pp = None for k in list(self.volstate.keys()): if "OFFLINE" not in self.volstate[k]: self.volstate[k] = "online, idle" - def _run_one_mtp(self, ptop, gid): + def _run_one_mtp(self, ptop: str, gid: int) -> None: if gid != self.gid: return @@ -915,8 +965,8 @@ class Up2k(object): break q = "select w from mt where k = 't:mtp' limit ?" - warks = cur.execute(q, (batch_sz,)).fetchall() - warks = [x[0] for x in warks] + zq = cur.execute(q, (batch_sz,)).fetchall() + warks = [str(x[0]) for x in zq] jobs = [] for w in warks: q = "select rd, fn from up where substr(w,1,16)=? limit 1" @@ -925,8 +975,8 @@ class Up2k(object): abspath = os.path.join(ptop, rd, fn) q = "select k from mt where w = ?" - have = cur.execute(q, (w,)).fetchall() - have = [x[0] for x in have] + zq2 = cur.execute(q, (w,)).fetchall() + have: dict[str, Union[str, float]] = {x[0]: 1 for x in zq2} parsers = self._get_parsers(ptop, have, abspath) if not parsers: @@ -937,7 +987,7 @@ class Up2k(object): if w in in_progress: continue - jobs.append([parsers, None, w, abspath]) + jobs.append(Mpqe(parsers, set(), w, abspath)) in_progress[w] = True with self.mutex: @@ -997,7 +1047,9 @@ class Up2k(object): wcur.close() cur.close() - def _get_parsers(self, ptop, have, abspath): + def _get_parsers( + self, ptop: str, have: dict[str, Union[str, float]], abspath: str + ) -> dict[str, MParser]: try: all_parsers = self.mtp_parsers[ptop] except: @@ -1030,16 +1082,16 @@ class Up2k(object): parsers = {k: v for k, v in parsers.items() if v.force or k not in have} return parsers - def _start_mpool(self): + def _start_mpool(self) -> Queue[Mpqe]: # mp.pool.ThreadPool and concurrent.futures.ThreadPoolExecutor # both do crazy runahead so lets reinvent another wheel nw = max(1, self.args.mtag_mt) - - if self.pending_tags is None: + assert self.mtag + if not self.mpool_used: + self.mpool_used = True self.log("using {}x {}".format(nw, self.mtag.backend)) - self.pending_tags = [] - mpool = Queue(nw) + mpool: Queue[Mpqe] = Queue(nw) for _ in range(nw): thr = threading.Thread( target=self._tag_thr, args=(mpool,), name="up2k-mpool" @@ -1049,50 +1101,55 @@ class Up2k(object): return mpool - def _stop_mpool(self, mpool): + def _stop_mpool(self, mpool: Queue[Mpqe]) -> None: if not mpool: return for _ in range(mpool.maxsize): - mpool.put(None) + mpool.put(Mpqe({}, set(), "", "")) mpool.join() - def _tag_thr(self, q): + def _tag_thr(self, q: Queue[Mpqe]) -> None: + assert self.mtag while True: - task = q.get() - if not task: + qe = q.get() + if not qe.w: q.task_done() return try: - parser, entags, wark, abspath = task - if parser == "mtag": - tags = self.mtag.get(abspath) + if not qe.mtp: + tags = self.mtag.get(qe.abspath) else: - tags = self.mtag.get_bin(parser, abspath) + tags = self.mtag.get_bin(qe.mtp, qe.abspath) vtags = [ "\033[36m{} \033[33m{}".format(k, v) for k, v in tags.items() ] if vtags: - self.log("{}\033[0m [{}]".format(" ".join(vtags), abspath)) + self.log("{}\033[0m [{}]".format(" ".join(vtags), qe.abspath)) with self.mutex: - self.pending_tags.append([entags, wark, abspath, tags]) + self.pending_tags.append((qe.entags, qe.w, qe.abspath, tags)) except: ex = traceback.format_exc() - if parser == "mtag": - parser = self.mtag.backend - - self._log_tag_err(parser, abspath, ex) + self._log_tag_err(qe.mtp or self.mtag.backend, qe.abspath, ex) q.task_done() - def _log_tag_err(self, parser, abspath, ex): + def _log_tag_err(self, parser: Any, abspath: str, ex: Any) -> None: msg = "{} failed to read tags from {}:\n{}".format(parser, abspath, ex) self.log(msg.lstrip(), c=1 if " int: + assert self.mtag if tags is None: try: tags = self.mtag.get(abspath) @@ -1127,12 +1184,12 @@ class Up2k(object): return ret - def _orz(self, db_path): + def _orz(self, db_path: str) -> "sqlite3.Cursor": timeout = int(max(self.args.srch_time, 5) * 1.2) return sqlite3.connect(db_path, timeout, check_same_thread=False).cursor() # x.set_trace_callback(trace) - def _open_db(self, db_path): + def _open_db(self, db_path: str) -> "sqlite3.Cursor": existed = bos.path.exists(db_path) cur = self._orz(db_path) ver = self._read_ver(cur) @@ -1141,8 +1198,8 @@ class Up2k(object): if ver == 4: try: - m = "creating backup before upgrade: " - cur = self._backup_db(db_path, cur, ver, m) + t = "creating backup before upgrade: " + cur = self._backup_db(db_path, cur, ver, t) self._upgrade_v4(cur) ver = 5 except: @@ -1157,8 +1214,8 @@ class Up2k(object): self.log("WARN: could not list files; DB corrupt?\n" + min_ex()) if (ver or 0) > DB_VER: - m = "database is version {}, this copyparty only supports versions <= {}" - raise Exception(m.format(ver, DB_VER)) + t = "database is version {}, this copyparty only supports versions <= {}" + raise Exception(t.format(ver, DB_VER)) msg = "creating new DB (old is bad); backup: " if ver: @@ -1171,7 +1228,9 @@ class Up2k(object): bos.unlink(db_path) return self._create_db(db_path, None) - def _backup_db(self, db_path, cur, ver, msg): + def _backup_db( + self, db_path: str, cur: "sqlite3.Cursor", ver: Optional[int], msg: str + ) -> "sqlite3.Cursor": bak = "{}.bak.{:x}.v{}".format(db_path, int(time.time()), ver) self.log(msg + bak) try: @@ -1180,8 +1239,8 @@ class Up2k(object): cur.connection.backup(c2) return cur except: - m = "native sqlite3 backup failed; using fallback method:\n" - self.log(m + min_ex()) + t = "native sqlite3 backup failed; using fallback method:\n" + self.log(t + min_ex()) finally: c2.close() @@ -1192,7 +1251,7 @@ class Up2k(object): shutil.copy2(fsenc(db_path), fsenc(bak)) return self._orz(db_path) - def _read_ver(self, cur): + def _read_ver(self, cur: "sqlite3.Cursor") -> Optional[int]: for tab in ["ki", "kv"]: try: c = cur.execute(r"select v from {} where k = 'sver'".format(tab)) @@ -1202,8 +1261,11 @@ class Up2k(object): rows = c.fetchall() if rows: return int(rows[0][0]) + return None - def _create_db(self, db_path, cur): + def _create_db( + self, db_path: str, cur: Optional["sqlite3.Cursor"] + ) -> "sqlite3.Cursor": """ collision in 2^(n/2) files where n = bits (6 bits/ch) 10*6/2 = 2^30 = 1'073'741'824, 24.1mb idx 1<<(3*10) @@ -1236,7 +1298,7 @@ class Up2k(object): self.log("created DB at {}".format(db_path)) return cur - def _upgrade_v4(self, cur): + def _upgrade_v4(self, cur: "sqlite3.Cursor") -> None: for cmd in [ r"alter table up add column ip text", r"alter table up add column at int", @@ -1247,7 +1309,7 @@ class Up2k(object): cur.connection.commit() - def handle_json(self, cj): + def handle_json(self, cj: dict[str, Any]) -> dict[str, Any]: with self.mutex: if not self.register_vpath(cj["ptop"], cj["vcfg"]): if cj["ptop"] not in self.registry: @@ -1269,13 +1331,13 @@ class Up2k(object): if cur: if self.no_expr_idx: q = r"select * from up where w = ?" - argv = (wark,) + argv = [wark] else: q = r"select * from up where substr(w,1,16) = ? and w = ?" - argv = (wark[:16], wark) + argv = [wark[:16], wark] - alts = [] - cur = cur.execute(q, argv) + alts: list[tuple[int, int, dict[str, Any]]] = [] + cur = cur.execute(q, tuple(argv)) for _, dtime, dsize, dp_dir, dp_fn, ip, at in cur: if dp_dir.startswith("//") or dp_fn.startswith("//"): dp_dir, dp_fn = s3dec(dp_dir, dp_fn) @@ -1307,7 +1369,7 @@ class Up2k(object): + (2 if dp_dir == cj["prel"] else 0) + (1 if dp_fn == cj["name"] else 0) ) - alts.append([score, -len(alts), j]) + alts.append((score, -len(alts), j)) job = sorted(alts, reverse=True)[0][2] if alts else None if job and wark in reg: @@ -1344,7 +1406,7 @@ class Up2k(object): # registry is size-constrained + can only contain one unique wark; # let want_recheck trigger symlink (if still in reg) or reupload if cur: - dupe = [cj["prel"], cj["name"], cj["lmod"]] + dupe = (cj["prel"], cj["name"], cj["lmod"]) try: self.dupesched[src].append(dupe) except: @@ -1431,17 +1493,19 @@ class Up2k(object): "wark": wark, } - def _untaken(self, fdir, fname, ts, ip): + def _untaken(self, fdir: str, fname: str, ts: float, ip: str) -> str: if self.args.nw: return fname # TODO broker which avoid this race and # provides a new filename if taken (same as bup) suffix = "-{:.6f}-{}".format(ts, ip.replace(":", ".")) - with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as f: - return f["orz"][1] + with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as zfw: + return zfw["orz"][1] - def _symlink(self, src, dst, verbose=True, lmod=None): + def _symlink( + self, src: str, dst: str, verbose: bool = True, lmod: float = 0 + ) -> None: if verbose: self.log("linking dupe:\n {0}\n {1}".format(src, dst)) @@ -1475,9 +1539,9 @@ class Up2k(object): break nc += 1 if nc > 1: - lsrc = nsrc[nc:] + zsl = nsrc[nc:] hops = len(ndst[nc:]) - 1 - lsrc = "../" * hops + "/".join(lsrc) + lsrc = "../" * hops + "/".join(zsl) try: if self.args.hardlink: @@ -1498,11 +1562,13 @@ class Up2k(object): if lmod and (not linked or SYMTIME): times = (int(time.time()), int(lmod)) if ANYWIN: - self.lastmod_q.append([dst, 0, times]) + self.lastmod_q.append((dst, 0, times)) else: bos.utime(dst, times, False) - def handle_chunk(self, ptop, wark, chash): + def handle_chunk( + self, ptop: str, wark: str, chash: str + ) -> tuple[int, list[int], str, float]: with self.mutex: job = self.registry[ptop].get(wark) if not job: @@ -1523,8 +1589,8 @@ class Up2k(object): if chash in job["busy"]: nh = len(job["hash"]) idx = job["hash"].index(chash) - m = "that chunk is already being written to:\n {}\n {} {}/{}\n {}" - raise Pebkac(400, m.format(wark, chash, idx, nh, job["name"])) + t = "that chunk is already being written to:\n {}\n {} {}/{}\n {}" + raise Pebkac(400, t.format(wark, chash, idx, nh, job["name"])) job["busy"][chash] = 1 @@ -1535,17 +1601,17 @@ class Up2k(object): path = os.path.join(job["ptop"], job["prel"], job["tnam"]) - return [chunksize, ofs, path, job["lmod"]] + return chunksize, ofs, path, job["lmod"] - def release_chunk(self, ptop, wark, chash): + def release_chunk(self, ptop: str, wark: str, chash: str) -> bool: with self.mutex: job = self.registry[ptop].get(wark) if job: job["busy"].pop(chash, None) - return [True] + return True - def confirm_chunk(self, ptop, wark, chash): + def confirm_chunk(self, ptop: str, wark: str, chash: str) -> tuple[int, str]: with self.mutex: try: job = self.registry[ptop][wark] @@ -1553,14 +1619,14 @@ class Up2k(object): src = os.path.join(pdir, job["tnam"]) dst = os.path.join(pdir, job["name"]) except Exception as ex: - return "confirm_chunk, wark, " + repr(ex) + return "confirm_chunk, wark, " + repr(ex) # type: ignore job["busy"].pop(chash, None) try: job["need"].remove(chash) except Exception as ex: - return "confirm_chunk, chash, " + repr(ex) + return "confirm_chunk, chash, " + repr(ex) # type: ignore ret = len(job["need"]) if ret > 0: @@ -1576,35 +1642,35 @@ class Up2k(object): return ret, dst - def finish_upload(self, ptop, wark): + def finish_upload(self, ptop: str, wark: str) -> None: with self.mutex: self._finish_upload(ptop, wark) - def _finish_upload(self, ptop, wark): + def _finish_upload(self, ptop: str, wark: str) -> None: try: job = self.registry[ptop][wark] pdir = os.path.join(job["ptop"], job["prel"]) src = os.path.join(pdir, job["tnam"]) dst = os.path.join(pdir, job["name"]) except Exception as ex: - return "finish_upload, wark, " + repr(ex) + raise Pebkac(500, "finish_upload, wark, " + repr(ex)) # self.log("--- " + wark + " " + dst + " finish_upload atomic " + dst, 4) atomic_move(src, dst) times = (int(time.time()), int(job["lmod"])) if ANYWIN: - a = [dst, job["size"], times] - self.lastmod_q.append(a) + z1 = (dst, job["size"], times) + self.lastmod_q.append(z1) elif not job["hash"]: try: bos.utime(dst, times) except: pass - a = [job[x] for x in "ptop wark prel name lmod size addr".split()] - a += [job.get("at") or time.time()] - if self.idx_wark(*a): + z2 = [job[x] for x in "ptop wark prel name lmod size addr".split()] + z2 += [job.get("at") or time.time()] + if self.idx_wark(*z2): del self.registry[ptop][wark] else: self.regdrop(ptop, wark) @@ -1622,27 +1688,37 @@ class Up2k(object): self._symlink(dst, d2, lmod=lmod) if cur: self.db_rm(cur, rd, fn) - self.db_add(cur, wark, rd, fn, *a[-4:]) + self.db_add(cur, wark, rd, fn, *z2[-4:]) if cur: cur.connection.commit() - def regdrop(self, ptop, wark): - t = self.droppable[ptop] + def regdrop(self, ptop: str, wark: str) -> None: + olds = self.droppable[ptop] if wark: - t.append(wark) + olds.append(wark) - if len(t) <= self.args.reg_cap: + if len(olds) <= self.args.reg_cap: return - n = len(t) - int(self.args.reg_cap / 2) - m = "up2k-registry [{}] has {} droppables; discarding {}" - self.log(m.format(ptop, len(t), n)) - for k in t[:n]: + n = len(olds) - int(self.args.reg_cap / 2) + t = "up2k-registry [{}] has {} droppables; discarding {}" + self.log(t.format(ptop, len(olds), n)) + for k in olds[:n]: self.registry[ptop].pop(k, None) - self.droppable[ptop] = t[n:] + self.droppable[ptop] = olds[n:] - def idx_wark(self, ptop, wark, rd, fn, lmod, sz, ip, at): + def idx_wark( + self, + ptop: str, + wark: str, + rd: str, + fn: str, + lmod: float, + sz: int, + ip: str, + at: float, + ) -> bool: cur = self.cur.get(ptop) if not cur: return False @@ -1652,29 +1728,41 @@ class Up2k(object): cur.connection.commit() if "e2t" in self.flags[ptop]: - self.tagq.put([ptop, wark, rd, fn]) + self.tagq.put((ptop, wark, rd, fn)) self.n_tagq += 1 return True - def db_rm(self, db, rd, fn): + def db_rm(self, db: "sqlite3.Cursor", rd: str, fn: str) -> None: sql = "delete from up where rd = ? and fn = ?" try: db.execute(sql, (rd, fn)) except: + assert self.mem_cur db.execute(sql, s3enc(self.mem_cur, rd, fn)) - def db_add(self, db, wark, rd, fn, ts, sz, ip, at): + def db_add( + self, + db: "sqlite3.Cursor", + wark: str, + rd: str, + fn: str, + ts: float, + sz: int, + ip: str, + at: float, + ) -> None: sql = "insert into up values (?,?,?,?,?,?,?)" v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0)) try: db.execute(sql, v) except: + assert self.mem_cur rd, fn = s3enc(self.mem_cur, rd, fn) v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0)) db.execute(sql, v) - def handle_rm(self, uname, ip, vpaths): + def handle_rm(self, uname: str, ip: str, vpaths: list[str]) -> str: n_files = 0 ok = {} ng = {} @@ -1687,12 +1775,14 @@ class Up2k(object): ng[k] = 1 ng = {k: 1 for k in ng if k not in ok} - ok = len(ok) - ng = len(ng) + iok = len(ok) + ing = len(ng) - return "deleted {} files (and {}/{} folders)".format(n_files, ok, ok + ng) + return "deleted {} files (and {}/{} folders)".format(n_files, iok, iok + ing) - def _handle_rm(self, uname, ip, vpath): + def _handle_rm( + self, uname: str, ip: str, vpath: str + ) -> tuple[int, list[str], list[str]]: try: permsets = [[True, False, False, True]] vn, rem = self.asrv.vfs.get(vpath, uname, *permsets[0]) @@ -1709,18 +1799,18 @@ class Up2k(object): vn, rem = vn.get_dbv(rem) _, _, _, _, dip, dat = self._find_from_vpath(vn.realpath, rem) - m = "you cannot delete this: " + t = "you cannot delete this: " if not dip: - m += "file not found" + t += "file not found" elif dip != ip: - m += "not uploaded by (You)" + t += "not uploaded by (You)" elif dat < time.time() - self.args.unpost: - m += "uploaded too long ago" + t += "uploaded too long ago" else: - m = None + t = "" - if m: - raise Pebkac(400, m) + if t: + raise Pebkac(400, t) ptop = vn.realpath atop = vn.canonical(rem, False) @@ -1731,16 +1821,19 @@ class Up2k(object): raise Pebkac(400, "file not found on disk (already deleted?)") scandir = not self.args.no_scandir - if stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode): + if stat.S_ISDIR(st.st_mode): + g = vn.walk("", rem, [], uname, permsets, True, scandir, True) + if unpost: + raise Pebkac(400, "cannot unpost folders") + elif stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode): dbv, vrem = self.asrv.vfs.get(vpath, uname, *permsets[0]) dbv, vrem = dbv.get_dbv(vrem) voldir = vsplit(vrem)[0] vpath_dir = vsplit(vpath)[0] - g = [[dbv, voldir, vpath_dir, adir, [[fn, 0]], [], []]] + g = [(dbv, voldir, vpath_dir, adir, [(fn, 0)], [], {})] # type: ignore else: - g = vn.walk("", rem, [], uname, permsets, True, scandir, True) - if unpost: - raise Pebkac(400, "cannot unpost folders") + self.log("rm: skip type-{:x} file [{}]".format(st.st_mode, atop)) + return 0, [], [] n_files = 0 for dbv, vrem, _, adir, files, rd, vd in g: @@ -1766,7 +1859,7 @@ class Up2k(object): rm = rmdirs(self.log_func, scandir, True, atop, 1) return n_files, rm[0], rm[1] - def handle_mv(self, uname, svp, dvp): + def handle_mv(self, uname: str, svp: str, dvp: str) -> str: svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) sabs = svn.canonical(srem, False) @@ -1808,7 +1901,7 @@ class Up2k(object): rmdirs(self.log_func, scandir, True, sabs, 1) return "k" - def _mv_file(self, uname, svp, dvp): + def _mv_file(self, uname: str, svp: str, dvp: str) -> str: svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) @@ -1834,29 +1927,33 @@ class Up2k(object): if bos.path.islink(sabs): dlabs = absreal(sabs) - m = "moving symlink from [{}] to [{}], target [{}]" - self.log(m.format(sabs, dabs, dlabs)) + t = "moving symlink from [{}] to [{}], target [{}]" + self.log(t.format(sabs, dabs, dlabs)) mt = bos.path.getmtime(sabs, False) bos.unlink(sabs) self._symlink(dlabs, dabs, False, lmod=mt) # folders are too scary, schedule rescan of both vols - self.need_rescan[svn.vpath] = 1 - self.need_rescan[dvn.vpath] = 1 + self.need_rescan.add(svn.vpath) + self.need_rescan.add(dvn.vpath) with self.rescan_cond: self.rescan_cond.notify_all() return "k" - c1, w, ftime, fsize, ip, at = self._find_from_vpath(svn.realpath, srem) + c1, w, ftime_, fsize_, ip, at = self._find_from_vpath(svn.realpath, srem) c2 = self.cur.get(dvn.realpath) - if ftime is None: + if ftime_ is None: st = bos.stat(sabs) ftime = st.st_mtime fsize = st.st_size + else: + ftime = ftime_ + fsize = fsize_ or 0 if w: + assert c1 if c2 and c2 != c1: self._copy_tags(c1, c2, w) @@ -1865,7 +1962,7 @@ class Up2k(object): c1.connection.commit() if c2: - self.db_add(c2, w, drd, dfn, ftime, fsize, ip, at) + self.db_add(c2, w, drd, dfn, ftime, fsize, ip or "", at or 0) c2.connection.commit() else: self.log("not found in src db: [{}]".format(svp)) @@ -1873,7 +1970,9 @@ class Up2k(object): bos.rename(sabs, dabs) return "k" - def _copy_tags(self, csrc, cdst, wark): + def _copy_tags( + self, csrc: "sqlite3.Cursor", cdst: "sqlite3.Cursor", wark: str + ) -> None: """copy all tags for wark from src-db to dst-db""" w = wark[:16] @@ -1883,16 +1982,26 @@ class Up2k(object): for _, k, v in csrc.execute("select * from mt where w=?", (w,)): cdst.execute("insert into mt values(?,?,?)", (w, k, v)) - def _find_from_vpath(self, ptop, vrem): + def _find_from_vpath( + self, ptop: str, vrem: str + ) -> tuple[ + Optional["sqlite3.Cursor"], + Optional[str], + Optional[int], + Optional[int], + Optional[str], + Optional[int], + ]: cur = self.cur.get(ptop) if not cur: - return [None] * 6 + return None, None, None, None, None, None rd, fn = vsplit(vrem) q = "select w, mt, sz, ip, at from up where rd=? and fn=? limit 1" try: c = cur.execute(q, (rd, fn)) except: + assert self.mem_cur c = cur.execute(q, s3enc(self.mem_cur, rd, fn)) hit = c.fetchone() @@ -1901,14 +2010,21 @@ class Up2k(object): return cur, wark, ftime, fsize, ip, at return cur, None, None, None, None, None - def _forget_file(self, ptop, vrem, cur, wark, drop_tags): + def _forget_file( + self, + ptop: str, + vrem: str, + cur: Optional["sqlite3.Cursor"], + wark: Optional[str], + drop_tags: bool, + ) -> None: """forgets file in db, fixes symlinks, does not delete""" srd, sfn = vsplit(vrem) self.log("forgetting {}".format(vrem)) - if wark: + if wark and cur: self.log("found {} in db".format(wark)) if drop_tags: - if self._relink(wark, ptop, vrem, None): + if self._relink(wark, ptop, vrem, ""): drop_tags = False if drop_tags: @@ -1919,20 +2035,24 @@ class Up2k(object): reg = self.registry.get(ptop) if reg: - if not wark: - wark = [ + vdir = vsplit(vrem)[0] + wark = wark or next( + ( x for x, y in reg.items() - if sfn in [y["name"], y.get("tnam")] and y["prel"] == vrem - ] - - if wark and wark in reg: - m = "forgetting partial upload {} ({})" - p = self._vis_job_progress(wark) - self.log(m.format(wark, p)) + if sfn in [y["name"], y.get("tnam")] and y["prel"] == vdir + ), + "", + ) + job = reg.get(wark) if wark else None + if job: + t = "forgetting partial upload {} ({})" + p = self._vis_job_progress(job) + self.log(t.format(wark, p)) + assert wark del reg[wark] - def _relink(self, wark, sptop, srem, dabs): + def _relink(self, wark: str, sptop: str, srem: str, dabs: str) -> int: """ update symlinks from file at svn/srem to dabs (rename), or to first remaining full if no dabs (delete) @@ -1953,13 +2073,13 @@ class Up2k(object): if not dupes: return 0 - full = {} - links = {} + full: dict[str, tuple[str, str]] = {} + links: dict[str, tuple[str, str]] = {} for ptop, vp in dupes: ap = os.path.join(ptop, vp) try: d = links if bos.path.islink(ap) else full - d[ap] = [ptop, vp] + d[ap] = (ptop, vp) except: self.log("relink: not found: [{}]".format(ap)) @@ -1973,13 +2093,13 @@ class Up2k(object): bos.rename(sabs, slabs) bos.utime(slabs, (int(time.time()), int(mt)), False) self._symlink(slabs, sabs, False) - full[slabs] = [ptop, rem] + full[slabs] = (ptop, rem) sabs = slabs if not dabs: dabs = list(sorted(full.keys()))[0] - for alink in links.keys(): + for alink in links: lmod = None try: if alink != sabs and absreal(alink) != sabs: @@ -1991,11 +2111,11 @@ class Up2k(object): except: pass - self._symlink(dabs, alink, False, lmod=lmod) + self._symlink(dabs, alink, False, lmod=lmod or 0) return len(full) + len(links) - def _get_wark(self, cj): + def _get_wark(self, cj: dict[str, Any]) -> str: if len(cj["name"]) > 1024 or len(cj["hash"]) > 512 * 1024: # 16TiB raise Pebkac(400, "name or numchunks not according to spec") @@ -2020,15 +2140,14 @@ class Up2k(object): return wark - def _hashlist_from_file(self, path): - pp = self.pp if hasattr(self, "pp") else None + def _hashlist_from_file(self, path: str) -> list[str]: fsz = bos.path.getsize(path) csz = up2k_chunksize(fsz) ret = [] with open(fsenc(path), "rb", 512 * 1024) as f: while fsz > 0: - if pp: - pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path) + if self.pp: + self.pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path) hashobj = hashlib.sha512() rem = min(csz, fsz) @@ -2047,7 +2166,7 @@ class Up2k(object): return ret - def _new_upload(self, job): + def _new_upload(self, job: dict[str, Any]) -> None: self.registry[job["ptop"]][job["wark"]] = job pdir = os.path.join(job["ptop"], job["prel"]) job["name"] = self._untaken(pdir, job["name"], job["t0"], job["addr"]) @@ -2066,8 +2185,8 @@ class Up2k(object): dip = job["addr"].replace(":", ".") suffix = "-{:.6f}-{}".format(job["t0"], dip) - with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as f: - f, job["tnam"] = f["orz"] + with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as zfw: + f, job["tnam"] = zfw["orz"] if ( ANYWIN and self.args.sparse @@ -2086,7 +2205,7 @@ class Up2k(object): if not job["hash"]: self._finish_upload(job["ptop"], job["wark"]) - def _lastmodder(self): + def _lastmodder(self) -> None: while True: ready = self.lastmod_q self.lastmod_q = [] @@ -2098,8 +2217,8 @@ class Up2k(object): try: bos.utime(path, times, False) except: - m = "lmod: failed to utime ({}, {}):\n{}" - self.log(m.format(path, times, min_ex())) + t = "lmod: failed to utime ({}, {}):\n{}" + self.log(t.format(path, times, min_ex())) if self.args.sparse and self.args.sparse * 1024 * 1024 <= sz: try: @@ -2107,21 +2226,22 @@ class Up2k(object): except: self.log("could not unsparse [{}]".format(path), 3) - def _snapshot(self): - self.snap_persist_interval = 300 # persist unfinished index every 5 min - self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity - self.snap_prev = {} + def _snapshot(self) -> None: + slp = self.snap_persist_interval while True: - time.sleep(self.snap_persist_interval) - if not hasattr(self, "pp"): + time.sleep(slp) + if self.pp: + slp = 5 + else: + slp = self.snap_persist_interval self.do_snapshot() - def do_snapshot(self): + def do_snapshot(self) -> None: with self.mutex: for k, reg in self.registry.items(): self._snap_reg(k, reg) - def _snap_reg(self, ptop, reg): + def _snap_reg(self, ptop: str, reg: dict[str, dict[str, Any]]) -> None: now = time.time() histpath = self.asrv.vfs.histtab.get(ptop) if not histpath: @@ -2133,9 +2253,9 @@ class Up2k(object): if x["need"] and now - x["poke"] > self.snap_discard_interval ] if rm: - m = "dropping {} abandoned uploads in {}".format(len(rm), ptop) + t = "dropping {} abandoned uploads in {}".format(len(rm), ptop) vis = [self._vis_job_progress(x) for x in rm] - self.log("\n".join([m] + vis)) + self.log("\n".join([t] + vis)) for job in rm: del reg[job["wark"]] try: @@ -2159,8 +2279,8 @@ class Up2k(object): bos.unlink(path) return - newest = max(x["poke"] for _, x in reg.items()) if reg else 0 - etag = [len(reg), newest] + newest = float(max(x["poke"] for _, x in reg.items()) if reg else 0) + etag = (len(reg), newest) if etag == self.snap_prev.get(ptop): return @@ -2177,10 +2297,11 @@ class Up2k(object): self.log("snap: {} |{}|".format(path, len(reg.keys()))) self.snap_prev[ptop] = etag - def _tagger(self): + def _tagger(self) -> None: with self.mutex: self.n_tagq += 1 + assert self.mtag while True: with self.mutex: self.n_tagq -= 1 @@ -2218,7 +2339,7 @@ class Up2k(object): self.log("tagged {} ({}+{})".format(abspath, ntags1, len(tags) - ntags1)) - def _hasher(self): + def _hasher(self) -> None: with self.mutex: self.n_hashq += 1 @@ -2240,20 +2361,21 @@ class Up2k(object): with self.mutex: self.idx_wark(ptop, wark, rd, fn, inf.st_mtime, inf.st_size, ip, at) - def hash_file(self, ptop, flags, rd, fn, ip, at): + def hash_file( + self, ptop: str, flags: dict[str, Any], rd: str, fn: str, ip: str, at: float + ) -> None: with self.mutex: self.register_vpath(ptop, flags) - self.hashq.put([ptop, rd, fn, ip, at]) + self.hashq.put((ptop, rd, fn, ip, at)) self.n_hashq += 1 # self.log("hashq {} push {}/{}/{}".format(self.n_hashq, ptop, rd, fn)) - def shutdown(self): - if hasattr(self, "snap_prev"): - self.log("writing snapshot") - self.do_snapshot() + def shutdown(self) -> None: + self.log("writing snapshot") + self.do_snapshot() -def up2k_chunksize(filesize): +def up2k_chunksize(filesize: int) -> int: chunksize = 1024 * 1024 stepsize = 512 * 1024 while True: @@ -2266,18 +2388,17 @@ def up2k_chunksize(filesize): stepsize *= mul -def up2k_wark_from_hashlist(salt, filesize, hashes): +def up2k_wark_from_hashlist(salt: str, filesize: int, hashes: list[str]) -> str: """server-reproducible file identifier, independent of name or location""" - ident = [salt, str(filesize)] - ident.extend(hashes) - ident = "\n".join(ident) + values = [salt, str(filesize)] + hashes + vstr = "\n".join(values) - wark = hashlib.sha512(ident.encode("utf-8")).digest()[:33] + wark = hashlib.sha512(vstr.encode("utf-8")).digest()[:33] wark = base64.urlsafe_b64encode(wark) return wark.decode("ascii") -def up2k_wark_from_metadata(salt, sz, lastmod, rd, fn): +def up2k_wark_from_metadata(salt: str, sz: int, lastmod: int, rd: str, fn: str) -> str: ret = fsenc("{}\n{}\n{}\n{}\n{}".format(salt, lastmod, sz, rd, fn)) ret = base64.urlsafe_b64encode(hashlib.sha512(ret).digest()) return "#{}".format(ret.decode("ascii"))[:44] diff --git a/copyparty/util.py b/copyparty/util.py index a01fdfdb..2ca710f5 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -1,51 +1,79 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import sys -import stat -import time import base64 +import contextlib +import hashlib +import mimetypes +import os +import platform +import re import select -import struct import signal import socket -import hashlib -import platform -import traceback -import threading -import mimetypes -import contextlib +import stat +import struct import subprocess as sp # nosec -from datetime import datetime +import sys +import threading +import time +import traceback from collections import Counter +from datetime import datetime -from .__init__ import PY2, WINDOWS, ANYWIN, VT100 +from .__init__ import ANYWIN, PY2, TYPE_CHECKING, VT100, WINDOWS from .stolen import surrogateescape +try: + HAVE_SQLITE3 = True + import sqlite3 # pylint: disable=unused-import # typechk +except: + HAVE_SQLITE3 = False + +try: + import types + from collections.abc import Callable, Iterable + + import typing + from typing import Any, Generator, Optional, Protocol, Union + + class RootLogger(Protocol): + def __call__(self, src: str, msg: str, c: Union[int, str] = 0) -> None: + return None + + class NamedLogger(Protocol): + def __call__(self, msg: str, c: Union[int, str] = 0) -> None: + return None + +except: + pass + +if TYPE_CHECKING: + from .authsrv import VFS + + FAKE_MP = False try: - if FAKE_MP: - import multiprocessing.dummy as mp # noqa: F401 # pylint: disable=unused-import + if not FAKE_MP: + import multiprocessing as mp else: - import multiprocessing as mp # noqa: F401 # pylint: disable=unused-import + import multiprocessing.dummy as mp # type: ignore except ImportError: # support jython - mp = None + mp = None # type: ignore if not PY2: - from urllib.parse import unquote_to_bytes as unquote + from io import BytesIO from urllib.parse import quote_from_bytes as quote - from queue import Queue # pylint: disable=unused-import - from io import BytesIO # pylint: disable=unused-import + from urllib.parse import unquote_to_bytes as unquote else: - from urllib import unquote # pylint: disable=no-name-in-module + from StringIO import StringIO as BytesIO from urllib import quote # pylint: disable=no-name-in-module - from Queue import Queue # pylint: disable=import-error,no-name-in-module - from StringIO import StringIO as BytesIO # pylint: disable=unused-import + from urllib import unquote # pylint: disable=no-name-in-module +_: Any = (mp, BytesIO, quote, unquote) +__all__ = ["mp", "BytesIO", "quote", "unquote"] try: struct.unpack(b">i", b"idgi") @@ -53,20 +81,21 @@ try: sunpack = struct.unpack except: - def spack(f, *a, **ka): - return struct.pack(f.decode("ascii"), *a, **ka) + def spack(fmt: bytes, *a: Any) -> bytes: + return struct.pack(fmt.decode("ascii"), *a) - def sunpack(f, *a, **ka): - return struct.unpack(f.decode("ascii"), *a, **ka) + def sunpack(fmt: bytes, a: bytes) -> tuple[Any, ...]: + return struct.unpack(fmt.decode("ascii"), a) ansi_re = re.compile("\033\\[[^mK]*[mK]") surrogateescape.register_surrogateescape() -FS_ENCODING = sys.getfilesystemencoding() if WINDOWS and PY2: FS_ENCODING = "utf-8" +else: + FS_ENCODING = sys.getfilesystemencoding() SYMTIME = sys.version_info >= (3, 6) and os.utime in os.supports_follow_symlinks @@ -116,7 +145,7 @@ MIMES = { } -def _add_mimes(): +def _add_mimes() -> None: for ln in """text css html csv application json wasm xml pdf rtf zip image webp jpeg png gif bmp @@ -170,18 +199,18 @@ REKOBO_LKEY = {k.lower(): v for k, v in REKOBO_KEY.items()} class Cooldown(object): - def __init__(self, maxage): + def __init__(self, maxage: float) -> None: self.maxage = maxage self.mutex = threading.Lock() - self.hist = {} - self.oldest = 0 + self.hist: dict[str, float] = {} + self.oldest = 0.0 - def poke(self, key): + def poke(self, key: str) -> bool: with self.mutex: now = time.time() ret = False - pv = self.hist.get(key, 0) + pv: float = self.hist.get(key, 0) if now - pv > self.maxage: self.hist[key] = now ret = True @@ -204,12 +233,12 @@ class _Unrecv(object): undo any number of socket recv ops """ - def __init__(self, s, log): - self.s = s # type: socket.socket + def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None: + self.s = s self.log = log - self.buf = b"" + self.buf: bytes = b"" - def recv(self, nbytes): + def recv(self, nbytes: int) -> bytes: if self.buf: ret = self.buf[:nbytes] self.buf = self.buf[nbytes:] @@ -221,25 +250,25 @@ class _Unrecv(object): return ret - def recv_ex(self, nbytes, raise_on_trunc=True): + def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes: """read an exact number of bytes""" ret = b"" try: while nbytes > len(ret): ret += self.recv(nbytes - len(ret)) except OSError: - m = "client only sent {} of {} expected bytes".format(len(ret), nbytes) + t = "client only sent {} of {} expected bytes".format(len(ret), nbytes) if len(ret) <= 16: - m += "; got {!r}".format(ret) + t += "; got {!r}".format(ret) if raise_on_trunc: - raise UnrecvEOF(5, m) + raise UnrecvEOF(5, t) elif self.log: - self.log(m, 3) + self.log(t, 3) return ret - def unrecv(self, buf): + def unrecv(self, buf: bytes) -> None: self.buf = buf + self.buf @@ -248,28 +277,28 @@ class _LUnrecv(object): with expensive debug logging """ - def __init__(self, s, log): + def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None: self.s = s self.log = log self.buf = b"" - def recv(self, nbytes): + def recv(self, nbytes: int) -> bytes: if self.buf: ret = self.buf[:nbytes] self.buf = self.buf[nbytes:] - m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" - self.log(m.format(ret, self.buf)) + t = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" + print(t.format(ret, self.buf)) return ret ret = self.s.recv(nbytes) - m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m" - self.log(m.format(ret)) + t = "\033[0;7mur:recv\033[0;1;33m {}\033[0m" + print(t.format(ret)) if not ret: raise UnrecvEOF("client stopped sending data") return ret - def recv_ex(self, nbytes, raise_on_trunc=True): + def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes: """read an exact number of bytes""" try: ret = self.recv(nbytes) @@ -285,18 +314,18 @@ class _LUnrecv(object): err = True if err: - m = "client only sent {} of {} expected bytes".format(len(ret), nbytes) + t = "client only sent {} of {} expected bytes".format(len(ret), nbytes) if raise_on_trunc: - raise UnrecvEOF(m) + raise UnrecvEOF(t) elif self.log: - self.log(m, 3) + self.log(t, 3) return ret - def unrecv(self, buf): + def unrecv(self, buf: bytes) -> None: self.buf = buf + self.buf - m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" - self.log(m.format(buf, self.buf)) + t = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" + print(t.format(buf, self.buf)) Unrecv = _Unrecv @@ -304,14 +333,14 @@ Unrecv = _Unrecv class FHC(object): class CE(object): - def __init__(self, fh): - self.ts = 0 + def __init__(self, fh: typing.BinaryIO) -> None: + self.ts: float = 0 self.fhs = [fh] - def __init__(self): - self.cache = {} + def __init__(self) -> None: + self.cache: dict[str, FHC.CE] = {} - def close(self, path): + def close(self, path: str) -> None: try: ce = self.cache[path] except: @@ -322,7 +351,7 @@ class FHC(object): del self.cache[path] - def clean(self): + def clean(self) -> None: if not self.cache: return @@ -337,10 +366,10 @@ class FHC(object): self.cache = keep - def pop(self, path): + def pop(self, path: str) -> typing.BinaryIO: return self.cache[path].fhs.pop() - def put(self, path, fh): + def put(self, path: str, fh: typing.BinaryIO) -> None: try: ce = self.cache[path] ce.fhs.append(fh) @@ -356,14 +385,15 @@ class ProgressPrinter(threading.Thread): periodically print progress info without linefeeds """ - def __init__(self): + def __init__(self) -> None: threading.Thread.__init__(self, name="pp") self.daemon = True - self.msg = None + self.msg = "" self.end = False + self.n = -1 self.start() - def run(self): + def run(self) -> None: msg = None fmt = " {}\033[K\r" if VT100 else " {} $\r" while not self.end: @@ -384,7 +414,7 @@ class ProgressPrinter(threading.Thread): sys.stdout.flush() # necessary on win10 even w/ stderr btw -def uprint(msg): +def uprint(msg: str) -> None: try: print(msg, end="") except UnicodeEncodeError: @@ -394,17 +424,17 @@ def uprint(msg): print(msg.encode("ascii", "replace").decode(), end="") -def nuprint(msg): +def nuprint(msg: str) -> None: uprint("{}\n".format(msg)) -def rice_tid(): +def rice_tid() -> str: tid = threading.current_thread().ident c = sunpack(b"B" * 5, spack(b">Q", tid)[-5:]) return "".join("\033[1;37;48;5;{0}m{0:02x}".format(x) for x in c) + "\033[0m" -def trace(*args, **kwargs): +def trace(*args: Any, **kwargs: Any) -> None: t = time.time() stack = "".join( "\033[36m{}\033[33m{}".format(x[0].split(os.sep)[-1][:-3], x[1]) @@ -423,15 +453,15 @@ def trace(*args, **kwargs): nuprint(msg) -def alltrace(): - threads = {} +def alltrace() -> str: + threads: dict[str, types.FrameType] = {} names = dict([(t.ident, t.name) for t in threading.enumerate()]) for tid, stack in sys._current_frames().items(): name = "{} ({:x})".format(names.get(tid), tid) threads[name] = stack - rret = [] - bret = [] + rret: list[str] = [] + bret: list[str] = [] for name, stack in sorted(threads.items()): ret = ["\n\n# {}".format(name)] pad = None @@ -451,20 +481,20 @@ def alltrace(): return "\n".join(rret + bret) -def start_stackmon(arg_str, nid): +def start_stackmon(arg_str: str, nid: int) -> None: suffix = "-{}".format(nid) if nid else "" fp, f = arg_str.rsplit(",", 1) - f = int(f) + zi = int(f) t = threading.Thread( target=stackmon, - args=(fp, f, suffix), + args=(fp, zi, suffix), name="stackmon" + suffix, ) t.daemon = True t.start() -def stackmon(fp, ival, suffix): +def stackmon(fp: str, ival: float, suffix: str) -> None: ctr = 0 while True: ctr += 1 @@ -474,7 +504,9 @@ def stackmon(fp, ival, suffix): f.write(st.encode("utf-8", "replace")) -def start_log_thrs(logger, ival, nid): +def start_log_thrs( + logger: Callable[[str, str, int], None], ival: float, nid: int +) -> None: ival = float(ival) tname = lname = "log-thrs" if nid: @@ -490,7 +522,7 @@ def start_log_thrs(logger, ival, nid): t.start() -def log_thrs(log, ival, name): +def log_thrs(log: Callable[[str, str, int], None], ival: float, name: str) -> None: while True: time.sleep(ival) tv = [x.name for x in threading.enumerate()] @@ -507,7 +539,7 @@ def log_thrs(log, ival, name): log(name, "\033[0m \033[33m".join(tv), 3) -def vol_san(vols, txt): +def vol_san(vols: list["VFS"], txt: bytes) -> bytes: for vol in vols: txt = txt.replace(vol.realpath.encode("utf-8"), vol.vpath.encode("utf-8")) txt = txt.replace( @@ -518,24 +550,26 @@ def vol_san(vols, txt): return txt -def min_ex(max_lines=8, reverse=False): +def min_ex(max_lines: int = 8, reverse: bool = False) -> str: et, ev, tb = sys.exc_info() - tb = traceback.extract_tb(tb) + stb = traceback.extract_tb(tb) fmt = "{} @ {} <{}>: {}" - ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in tb] - ex.append("[{}] {}".format(et.__name__, ev)) + ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in stb] + ex.append("[{}] {}".format(et.__name__ if et else "(anonymous)", ev)) return "\n".join(ex[-max_lines:][:: -1 if reverse else 1]) @contextlib.contextmanager -def ren_open(fname, *args, **kwargs): +def ren_open( + fname: str, *args: Any, **kwargs: Any +) -> Generator[dict[str, tuple[typing.IO[Any], str]], None, None]: fun = kwargs.pop("fun", open) fdir = kwargs.pop("fdir", None) suffix = kwargs.pop("suffix", None) if fname == os.devnull: with fun(fname, *args, **kwargs) as f: - yield {"orz": [f, fname]} + yield {"orz": (f, fname)} return if suffix: @@ -575,7 +609,7 @@ def ren_open(fname, *args, **kwargs): with open(fsenc(fp2), "wb") as f2: f2.write(orig_name.encode("utf-8")) - yield {"orz": [f, fname]} + yield {"orz": (f, fname)} return except OSError as ex_: @@ -584,9 +618,9 @@ def ren_open(fname, *args, **kwargs): raise if not b64: - b64 = (bname + ext).encode("utf-8", "replace") - b64 = hashlib.sha512(b64).digest()[:12] - b64 = base64.urlsafe_b64encode(b64).decode("utf-8") + zs = (bname + ext).encode("utf-8", "replace") + zs = hashlib.sha512(zs).digest()[:12] + b64 = base64.urlsafe_b64encode(zs).decode("utf-8") badlen = len(fname) while len(fname) >= badlen: @@ -608,8 +642,8 @@ def ren_open(fname, *args, **kwargs): class MultipartParser(object): - def __init__(self, log_func, sr, http_headers): - self.sr = sr # type: Unrecv + def __init__(self, log_func: NamedLogger, sr: Unrecv, http_headers: dict[str, str]): + self.sr = sr self.log = log_func self.headers = http_headers @@ -622,10 +656,14 @@ class MultipartParser(object): r'^content-disposition:(?: *|.*; *)filename="(.*)"', re.IGNORECASE ) - self.boundary = None - self.gen = None + self.boundary = b"" + self.gen: Optional[ + Generator[ + tuple[str, Optional[str], Generator[bytes, None, None]], None, None + ] + ] = None - def _read_header(self): + def _read_header(self) -> tuple[str, Optional[str]]: """ returns [fieldname, filename] after eating a block of multipart headers while doing a decent job at dealing with the absolute mess that is @@ -641,7 +679,8 @@ class MultipartParser(object): # rfc-7578 overrides rfc-2388 so this is not-impl # (opera >=9 <11.10 is the only thing i've ever seen use it) raise Pebkac( - "you can't use that browser to upload multiple files at once" + 400, + "you can't use that browser to upload multiple files at once", ) continue @@ -655,12 +694,12 @@ class MultipartParser(object): raise Pebkac(400, "not form-data: {}".format(ln)) try: - field = self.re_cdisp_field.match(ln).group(1) + field = self.re_cdisp_field.match(ln).group(1) # type: ignore except: raise Pebkac(400, "missing field name: {}".format(ln)) try: - fn = self.re_cdisp_file.match(ln).group(1) + fn = self.re_cdisp_file.match(ln).group(1) # type: ignore except: # this is not a file upload, we're done return field, None @@ -687,11 +726,10 @@ class MultipartParser(object): esc = False for ch in fn: if esc: - if ch in ['"', "\\"]: - ret += '"' - else: - ret += esc + ch esc = False + if ch not in ['"', "\\"]: + ret += "\\" + ret += ch elif ch == "\\": esc = True elif ch == '"': @@ -699,9 +737,11 @@ class MultipartParser(object): else: ret += ch - return [field, ret] + return field, ret - def _read_data(self): + raise Pebkac(400, "server expected a multipart header but you never sent one") + + def _read_data(self) -> Generator[bytes, None, None]: blen = len(self.boundary) bufsz = 32 * 1024 while True: @@ -748,7 +788,9 @@ class MultipartParser(object): yield buf - def _run_gen(self): + def _run_gen( + self, + ) -> Generator[tuple[str, Optional[str], Generator[bytes, None, None]], None, None]: """ yields [fieldname, unsanitized_filename, fieldvalue] where fieldvalue yields chunks of data @@ -756,7 +798,7 @@ class MultipartParser(object): run = True while run: fieldname, filename = self._read_header() - yield [fieldname, filename, self._read_data()] + yield (fieldname, filename, self._read_data()) tail = self.sr.recv_ex(2, False) @@ -766,19 +808,19 @@ class MultipartParser(object): run = False if tail != b"\r\n": - m = "protocol error after field value: want b'\\r\\n', got {!r}" - raise Pebkac(400, m.format(tail)) + t = "protocol error after field value: want b'\\r\\n', got {!r}" + raise Pebkac(400, t.format(tail)) - def _read_value(self, iterator, max_len): + def _read_value(self, iterable: Iterable[bytes], max_len: int) -> bytes: ret = b"" - for buf in iterator: + for buf in iterable: ret += buf if len(ret) > max_len: raise Pebkac(400, "field length is too long") return ret - def parse(self): + def parse(self) -> None: # spec says there might be junk before the first boundary, # can't have the leading \r\n if that's not the case self.boundary = b"--" + get_boundary(self.headers).encode("utf-8") @@ -793,11 +835,12 @@ class MultipartParser(object): self.boundary = b"\r\n" + self.boundary self.gen = self._run_gen() - def require(self, field_name, max_len): + def require(self, field_name: str, max_len: int) -> str: """ returns the value of the next field in the multipart body, raises if the field name is not as expected """ + assert self.gen p_field, _, p_data = next(self.gen) if p_field != field_name: raise Pebkac( @@ -806,14 +849,15 @@ class MultipartParser(object): return self._read_value(p_data, max_len).decode("utf-8", "surrogateescape") - def drop(self): + def drop(self) -> None: """discards the remaining multipart body""" + assert self.gen for _, _, data in self.gen: for _ in data: pass -def get_boundary(headers): +def get_boundary(headers: dict[str, str]) -> str: # boundaries contain a-z A-Z 0-9 ' ( ) + _ , - . / : = ? # (whitespace allowed except as the last char) ptn = r"^multipart/form-data *; *(.*; *)?boundary=([^;]+)" @@ -825,14 +869,14 @@ def get_boundary(headers): return m.group(2) -def read_header(sr): +def read_header(sr: Unrecv) -> list[str]: ret = b"" while True: try: ret += sr.recv(1024) except: if not ret: - return None + return [] raise Pebkac( 400, @@ -853,7 +897,7 @@ def read_header(sr): return ret[:ofs].decode("utf-8", "surrogateescape").lstrip("\r\n").split("\r\n") -def gen_filekey(salt, fspath, fsize, inode): +def gen_filekey(salt: str, fspath: str, fsize: int, inode: int) -> str: return base64.urlsafe_b64encode( hashlib.sha512( "{} {} {} {}".format(salt, fspath, fsize, inode).encode("utf-8", "replace") @@ -861,7 +905,7 @@ def gen_filekey(salt, fspath, fsize, inode): ).decode("ascii") -def gencookie(k, v, dur): +def gencookie(k: str, v: str, dur: Optional[int]) -> str: v = v.replace(";", "") if dur: dt = datetime.utcfromtimestamp(time.time() + dur) @@ -872,7 +916,7 @@ def gencookie(k, v, dur): return "{}={}; Path=/; Expires={}; SameSite=Lax".format(k, v, exp) -def humansize(sz, terse=False): +def humansize(sz: float, terse: bool = False) -> str: for unit in ["B", "KiB", "MiB", "GiB", "TiB"]: if sz < 1024: break @@ -887,18 +931,18 @@ def humansize(sz, terse=False): return ret.replace("iB", "").replace(" ", "") -def unhumanize(sz): +def unhumanize(sz: str) -> int: try: - return float(sz) + return int(sz) except: pass - mul = sz[-1:].lower() - mul = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mul, 1) - return float(sz[:-1]) * mul + mc = sz[-1:].lower() + mi = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mc, 1) + return int(float(sz[:-1]) * mi) -def get_spd(nbyte, t0, t=None): +def get_spd(nbyte: int, t0: float, t: Optional[float] = None) -> str: if t is None: t = time.time() @@ -908,7 +952,7 @@ def get_spd(nbyte, t0, t=None): return "{} \033[0m{}/s\033[0m".format(s1, s2) -def s2hms(s, optional_h=False): +def s2hms(s: float, optional_h: bool = False) -> str: s = int(s) h, s = divmod(s, 3600) m, s = divmod(s, 60) @@ -918,7 +962,7 @@ def s2hms(s, optional_h=False): return "{}:{:02}:{:02}".format(h, m, s) -def uncyg(path): +def uncyg(path: str) -> str: if len(path) < 2 or not path.startswith("/"): return path @@ -928,8 +972,8 @@ def uncyg(path): return "{}:\\{}".format(path[1], path[3:]) -def undot(path): - ret = [] +def undot(path: str) -> str: + ret: list[str] = [] for node in path.split("/"): if node in ["", "."]: continue @@ -944,7 +988,7 @@ def undot(path): return "/".join(ret) -def sanitize_fn(fn, ok, bad): +def sanitize_fn(fn: str, ok: str, bad: list[str]) -> str: if "/" not in ok: fn = fn.replace("\\", "/").split("/")[-1] @@ -976,7 +1020,7 @@ def sanitize_fn(fn, ok, bad): return fn.strip() -def relchk(rp): +def relchk(rp: str) -> str: if ANYWIN: if "\n" in rp or "\r" in rp: return "x\nx" @@ -985,8 +1029,10 @@ def relchk(rp): if p != rp: return "[{}]".format(p) + return "" -def absreal(fpath): + +def absreal(fpath: str) -> str: try: return fsdec(os.path.abspath(os.path.realpath(fsenc(fpath)))) except: @@ -999,26 +1045,26 @@ def absreal(fpath): return os.path.abspath(os.path.realpath(fpath)) -def u8safe(txt): +def u8safe(txt: str) -> str: try: return txt.encode("utf-8", "xmlcharrefreplace").decode("utf-8", "replace") except: return txt.encode("utf-8", "replace").decode("utf-8", "replace") -def exclude_dotfiles(filepaths): +def exclude_dotfiles(filepaths: list[str]) -> list[str]: return [x for x in filepaths if not x.split("/")[-1].startswith(".")] -def http_ts(ts): +def http_ts(ts: int) -> str: file_dt = datetime.utcfromtimestamp(ts) return file_dt.strftime(HTTP_TS_FMT) -def html_escape(s, quote=False, crlf=False): +def html_escape(s: str, quot: bool = False, crlf: bool = False) -> str: """html.escape but also newlines""" s = s.replace("&", "&").replace("<", "<").replace(">", ">") - if quote: + if quot: s = s.replace('"', """).replace("'", "'") if crlf: s = s.replace("\r", " ").replace("\n", " ") @@ -1026,10 +1072,10 @@ def html_escape(s, quote=False, crlf=False): return s -def html_bescape(s, quote=False, crlf=False): +def html_bescape(s: bytes, quot: bool = False, crlf: bool = False) -> bytes: """html.escape but bytestrings""" s = s.replace(b"&", b"&").replace(b"<", b"<").replace(b">", b">") - if quote: + if quot: s = s.replace(b'"', b""").replace(b"'", b"'") if crlf: s = s.replace(b"\r", b" ").replace(b"\n", b" ") @@ -1037,18 +1083,20 @@ def html_bescape(s, quote=False, crlf=False): return s -def quotep(txt): +def quotep(txt: str) -> str: """url quoter which deals with bytes correctly""" btxt = w8enc(txt) quot1 = quote(btxt, safe=b"/") if not PY2: - quot1 = quot1.encode("ascii") + quot2 = quot1.encode("ascii") + else: + quot2 = quot1 - quot2 = quot1.replace(b" ", b"+") - return w8dec(quot2) + quot3 = quot2.replace(b" ", b"+") + return w8dec(quot3) -def unquotep(txt): +def unquotep(txt: str) -> str: """url unquoter which deals with bytes correctly""" btxt = w8enc(txt) # btxt = btxt.replace(b"+", b" ") @@ -1056,14 +1104,14 @@ def unquotep(txt): return w8dec(unq2) -def vsplit(vpath): +def vsplit(vpath: str) -> tuple[str, str]: if "/" not in vpath: return "", vpath - return vpath.rsplit("/", 1) + return vpath.rsplit("/", 1) # type: ignore -def w8dec(txt): +def w8dec(txt: bytes) -> str: """decodes filesystem-bytes to wtf8""" if PY2: return surrogateescape.decodefilename(txt) @@ -1071,7 +1119,7 @@ def w8dec(txt): return txt.decode(FS_ENCODING, "surrogateescape") -def w8enc(txt): +def w8enc(txt: str) -> bytes: """encodes wtf8 to filesystem-bytes""" if PY2: return surrogateescape.encodefilename(txt) @@ -1079,12 +1127,12 @@ def w8enc(txt): return txt.encode(FS_ENCODING, "surrogateescape") -def w8b64dec(txt): +def w8b64dec(txt: str) -> str: """decodes base64(filesystem-bytes) to wtf8""" return w8dec(base64.urlsafe_b64decode(txt.encode("ascii"))) -def w8b64enc(txt): +def w8b64enc(txt: str) -> str: """encodes wtf8 to base64(filesystem-bytes)""" return base64.urlsafe_b64encode(w8enc(txt)).decode("ascii") @@ -1102,8 +1150,8 @@ else: fsdec = w8dec -def s3enc(mem_cur, rd, fn): - ret = [] +def s3enc(mem_cur: "sqlite3.Cursor", rd: str, fn: str) -> tuple[str, str]: + ret: list[str] = [] for v in [rd, fn]: try: mem_cur.execute("select * from a where b = ?", (v,)) @@ -1112,10 +1160,10 @@ def s3enc(mem_cur, rd, fn): ret.append("//" + w8b64enc(v)) # self.log("mojien [{}] {}".format(v, ret[-1][2:])) - return tuple(ret) + return ret[0], ret[1] -def s3dec(rd, fn): +def s3dec(rd: str, fn: str) -> tuple[str, str]: ret = [] for v in [rd, fn]: if v.startswith("//"): @@ -1124,12 +1172,12 @@ def s3dec(rd, fn): else: ret.append(v) - return tuple(ret) + return ret[0], ret[1] -def atomic_move(src, dst): - src = fsenc(src) - dst = fsenc(dst) +def atomic_move(usrc: str, udst: str) -> None: + src = fsenc(usrc) + dst = fsenc(udst) if not PY2: os.replace(src, dst) else: @@ -1139,7 +1187,7 @@ def atomic_move(src, dst): os.rename(src, dst) -def read_socket(sr, total_size): +def read_socket(sr: Unrecv, total_size: int) -> Generator[bytes, None, None]: remains = total_size while remains > 0: bufsz = 32 * 1024 @@ -1149,14 +1197,14 @@ def read_socket(sr, total_size): try: buf = sr.recv(bufsz) except OSError: - m = "client d/c during binary post after {} bytes, {} bytes remaining" - raise Pebkac(400, m.format(total_size - remains, remains)) + t = "client d/c during binary post after {} bytes, {} bytes remaining" + raise Pebkac(400, t.format(total_size - remains, remains)) remains -= len(buf) yield buf -def read_socket_unbounded(sr): +def read_socket_unbounded(sr: Unrecv) -> Generator[bytes, None, None]: try: while True: yield sr.recv(32 * 1024) @@ -1164,7 +1212,9 @@ def read_socket_unbounded(sr): return -def read_socket_chunked(sr, log=None): +def read_socket_chunked( + sr: Unrecv, log: Optional[NamedLogger] = None +) -> Generator[bytes, None, None]: err = "upload aborted: expected chunk length, got [{}] |{}| instead" while True: buf = b"" @@ -1191,8 +1241,8 @@ def read_socket_chunked(sr, log=None): if x == b"\r\n": return - m = "protocol error after final chunk: want b'\\r\\n', got {!r}" - raise Pebkac(400, m.format(x)) + t = "protocol error after final chunk: want b'\\r\\n', got {!r}" + raise Pebkac(400, t.format(x)) if log: log("receiving {} byte chunk".format(chunklen)) @@ -1202,11 +1252,11 @@ def read_socket_chunked(sr, log=None): x = sr.recv_ex(2, False) if x != b"\r\n": - m = "protocol error in chunk separator: want b'\\r\\n', got {!r}" - raise Pebkac(400, m.format(x)) + t = "protocol error in chunk separator: want b'\\r\\n', got {!r}" + raise Pebkac(400, t.format(x)) -def yieldfile(fn): +def yieldfile(fn: str) -> Generator[bytes, None, None]: with open(fsenc(fn), "rb", 512 * 1024) as f: while True: buf = f.read(64 * 1024) @@ -1216,7 +1266,11 @@ def yieldfile(fn): yield buf -def hashcopy(fin, fout, slp=0): +def hashcopy( + fin: Union[typing.BinaryIO, Generator[bytes, None, None]], + fout: Union[typing.BinaryIO, typing.IO[Any]], + slp: int = 0, +) -> tuple[int, str, str]: hashobj = hashlib.sha512() tlen = 0 for buf in fin: @@ -1232,7 +1286,15 @@ def hashcopy(fin, fout, slp=0): return tlen, hashobj.hexdigest(), digest_b64 -def sendfile_py(log, lower, upper, f, s, bufsz, slp): +def sendfile_py( + log: NamedLogger, + lower: int, + upper: int, + f: typing.BinaryIO, + s: socket.socket, + bufsz: int, + slp: int, +) -> int: remains = upper - lower f.seek(lower) while remains > 0: @@ -1252,25 +1314,37 @@ def sendfile_py(log, lower, upper, f, s, bufsz, slp): return 0 -def sendfile_kern(log, lower, upper, f, s, bufsz, slp): +def sendfile_kern( + log: NamedLogger, + lower: int, + upper: int, + f: typing.BinaryIO, + s: socket.socket, + bufsz: int, + slp: int, +) -> int: out_fd = s.fileno() in_fd = f.fileno() ofs = lower - stuck = None + stuck = 0.0 while ofs < upper: stuck = stuck or time.time() try: req = min(2 ** 30, upper - ofs) select.select([], [out_fd], [], 10) n = os.sendfile(out_fd, in_fd, ofs, req) - stuck = None - except Exception as ex: + stuck = 0 + except OSError as ex: d = time.time() - stuck log("sendfile stuck for {:.3f} sec: {!r}".format(d, ex)) if d < 3600 and ex.errno == 11: # eagain continue n = 0 + except Exception as ex: + n = 0 + d = time.time() - stuck + log("sendfile failed after {:.3f} sec: {!r}".format(d, ex)) if n <= 0: return upper - ofs @@ -1281,7 +1355,9 @@ def sendfile_kern(log, lower, upper, f, s, bufsz, slp): return 0 -def statdir(logger, scandir, lstat, top): +def statdir( + logger: Optional[RootLogger], scandir: bool, lstat: bool, top: str +) -> Generator[tuple[str, os.stat_result], None, None]: if lstat and ANYWIN: lstat = False @@ -1295,30 +1371,42 @@ def statdir(logger, scandir, lstat, top): with os.scandir(btop) as dh: for fh in dh: try: - yield [fsdec(fh.name), fh.stat(follow_symlinks=not lstat)] + yield (fsdec(fh.name), fh.stat(follow_symlinks=not lstat)) except Exception as ex: + if not logger: + continue + logger(src, "[s] {} @ {}".format(repr(ex), fsdec(fh.path)), 6) else: src = "listdir" - fun = os.lstat if lstat else os.stat + fun: Any = os.lstat if lstat else os.stat for name in os.listdir(btop): abspath = os.path.join(btop, name) try: - yield [fsdec(name), fun(abspath)] + yield (fsdec(name), fun(abspath)) except Exception as ex: + if not logger: + continue + logger(src, "[s] {} @ {}".format(repr(ex), fsdec(abspath)), 6) except Exception as ex: - logger(src, "{} @ {}".format(repr(ex), top), 1) + t = "{} @ {}".format(repr(ex), top) + if logger: + logger(src, t, 1) + else: + print(t) -def rmdirs(logger, scandir, lstat, top, depth): +def rmdirs( + logger: RootLogger, scandir: bool, lstat: bool, top: str, depth: int +) -> tuple[list[str], list[str]]: if not os.path.exists(fsenc(top)) or not os.path.isdir(fsenc(top)): top = os.path.dirname(top) depth -= 1 - dirs = statdir(logger, scandir, lstat, top) - dirs = [x[0] for x in dirs if stat.S_ISDIR(x[1].st_mode)] + stats = statdir(logger, scandir, lstat, top) + dirs = [x[0] for x in stats if stat.S_ISDIR(x[1].st_mode)] dirs = [os.path.join(top, x) for x in dirs] ok = [] ng = [] @@ -1337,7 +1425,7 @@ def rmdirs(logger, scandir, lstat, top, depth): return ok, ng -def unescape_cookie(orig): +def unescape_cookie(orig: str) -> str: # mw=idk; doot=qwe%2Crty%3Basd+fgh%2Bjkl%25zxc%26vbn # qwe,rty;asd fgh+jkl%zxc&vbn ret = "" esc = "" @@ -1365,7 +1453,7 @@ def unescape_cookie(orig): return ret -def guess_mime(url, fallback="application/octet-stream"): +def guess_mime(url: str, fallback: str = "application/octet-stream") -> str: try: _, ext = url.rsplit(".", 1) except: @@ -1387,7 +1475,9 @@ def guess_mime(url, fallback="application/octet-stream"): return ret -def runcmd(argv, timeout=None, **ka): +def runcmd( + argv: Union[list[bytes], list[str]], timeout: Optional[int] = None, **ka: Any +) -> tuple[int, str, str]: p = sp.Popen(argv, stdout=sp.PIPE, stderr=sp.PIPE, **ka) if not timeout or PY2: stdout, stderr = p.communicate() @@ -1400,10 +1490,10 @@ def runcmd(argv, timeout=None, **ka): stdout = stdout.decode("utf-8", "replace") stderr = stderr.decode("utf-8", "replace") - return [p.returncode, stdout, stderr] + return p.returncode, stdout, stderr -def chkcmd(argv, **ka): +def chkcmd(argv: Union[list[bytes], list[str]], **ka: Any) -> tuple[str, str]: ok, sout, serr = runcmd(argv, **ka) if ok != 0: retchk(ok, argv, serr) @@ -1412,7 +1502,7 @@ def chkcmd(argv, **ka): return sout, serr -def mchkcmd(argv, timeout=10): +def mchkcmd(argv: Union[list[bytes], list[str]], timeout: int = 10) -> None: if PY2: with open(os.devnull, "wb") as f: rv = sp.call(argv, stdout=f, stderr=f) @@ -1423,7 +1513,14 @@ def mchkcmd(argv, timeout=10): raise sp.CalledProcessError(rv, (argv[0], b"...", argv[-1])) -def retchk(rc, cmd, serr, logger=None, color=None, verbose=False): +def retchk( + rc: int, + cmd: Union[list[bytes], list[str]], + serr: str, + logger: Optional[NamedLogger] = None, + color: Union[int, str] = 0, + verbose: bool = False, +) -> None: if rc < 0: rc = 128 - rc @@ -1446,33 +1543,33 @@ def retchk(rc, cmd, serr, logger=None, color=None, verbose=False): s = "invalid retcode" if s: - m = "{} <{}>".format(rc, s) + t = "{} <{}>".format(rc, s) else: - m = str(rc) + t = str(rc) try: - c = " ".join([fsdec(x) for x in cmd]) + c = " ".join([fsdec(x) for x in cmd]) # type: ignore except: c = str(cmd) - m = "error {} from [{}]".format(m, c) + t = "error {} from [{}]".format(t, c) if serr: - m += "\n" + serr + t += "\n" + serr if logger: - logger(m, color) + logger(t, color) else: - raise Exception(m) + raise Exception(t) -def gzip_orig_sz(fn): +def gzip_orig_sz(fn: str) -> int: with open(fsenc(fn), "rb") as f: f.seek(-4, 2) rv = f.read(4) - return sunpack(b"I", rv)[0] + return sunpack(b"I", rv)[0] # type: ignore -def py_desc(): +def py_desc() -> str: interp = platform.python_implementation() py_ver = ".".join([str(x) for x in sys.version_info]) ofs = py_ver.find(".final.") @@ -1487,15 +1584,15 @@ def py_desc(): host_os = platform.system() compiler = platform.python_compiler() - os_ver = re.search(r"([0-9]+\.[0-9\.]+)", platform.version()) - os_ver = os_ver.group(1) if os_ver else "" + m = re.search(r"([0-9]+\.[0-9\.]+)", platform.version()) + os_ver = m.group(1) if m else "" return "{:>9} v{} on {}{} {} [{}]".format( interp, py_ver, host_os, bitness, os_ver, compiler ) -def align_tab(lines): +def align_tab(lines: list[str]) -> list[str]: rows = [] ncols = 0 for ln in lines: @@ -1512,9 +1609,9 @@ def align_tab(lines): class Pebkac(Exception): - def __init__(self, code, msg=None): + def __init__(self, code: int, msg: Optional[str] = None) -> None: super(Pebkac, self).__init__(msg or HTTPCODE[code]) self.code = code - def __repr__(self): + def __repr__(self) -> str: return "Pebkac({}, {})".format(self.code, repr(self.args)) diff --git a/scripts/make-pypi-release.sh b/scripts/make-pypi-release.sh index 6a568051..7e723d4b 100755 --- a/scripts/make-pypi-release.sh +++ b/scripts/make-pypi-release.sh @@ -90,6 +90,15 @@ function have() { have setuptools have wheel have twine + +# remove type hints to support python < 3.9 +rm -rf build/pypi +mkdir -p build/pypi +cp -pR setup.py README.md LICENSE copyparty tests bin scripts/strip_hints build/pypi/ +cd build/pypi +tar --strip-components=2 -xf ../strip-hints-0.1.10.tar.gz strip-hints-0.1.10/src/strip_hints +python3 -c 'from strip_hints.a import uh; uh("copyparty")' + ./setup.py clean2 ./setup.py sdist bdist_wheel --universal diff --git a/scripts/make-sfx.sh b/scripts/make-sfx.sh index 8f9856f3..7402c517 100755 --- a/scripts/make-sfx.sh +++ b/scripts/make-sfx.sh @@ -76,7 +76,7 @@ while [ ! -z "$1" ]; do no-hl) no_hl=1 ; ;; no-dd) no_dd=1 ; ;; no-cm) no_cm=1 ; ;; - fast) zopf=100 ; ;; + fast) zopf= ; ;; lang) shift;langs="$1"; ;; *) help ; ;; esac @@ -106,7 +106,7 @@ tmpdir="$( [ $repack ] && { old="$tmpdir/pe-copyparty" echo "repack of files in $old" - cp -pR "$old/"*{dep-j2,dep-ftp,copyparty} . + cp -pR "$old/"*{j2,ftp,copyparty} . } [ $repack ] || { @@ -130,8 +130,8 @@ tmpdir="$( mv MarkupSafe-*/src/markupsafe . rm -rf MarkupSafe-* markupsafe/_speedups.c - mkdir dep-j2/ - mv {markupsafe,jinja2} dep-j2/ + mkdir j2/ + mv {markupsafe,jinja2} j2/ echo collecting pyftpdlib f="../build/pyftpdlib-1.5.6.tar.gz" @@ -143,8 +143,8 @@ tmpdir="$( mv pyftpdlib-release-*/pyftpdlib . rm -rf pyftpdlib-release-* pyftpdlib/test - mkdir dep-ftp/ - mv pyftpdlib dep-ftp/ + mkdir ftp/ + mv pyftpdlib ftp/ echo collecting asyncore, asynchat for n in asyncore.py asynchat.py; do @@ -154,6 +154,24 @@ tmpdir="$( wget -O$f "$url" || curl -L "$url" >$f) done + # enable this to dynamically remove type hints at startup, + # in case a future python version can use them for performance + true || ( + echo collecting strip-hints + f=../build/strip-hints-0.1.10.tar.gz + [ -e $f ] || + (url=https://files.pythonhosted.org/packages/9c/d4/312ddce71ee10f7e0ab762afc027e07a918f1c0e1be5b0069db5b0e7542d/strip-hints-0.1.10.tar.gz; + wget -O$f "$url" || curl -L "$url" >$f) + + tar -zxf $f + mv strip-hints-0.1.10/src/strip_hints . + rm -rf strip-hints-* strip_hints/import_hooks* + sed -ri 's/[a-z].* as import_hooks$/"""a"""/' strip_hints/*.py + + cp -pR ../scripts/strip_hints/ . + ) + cp -pR ../scripts/py2/ . + # msys2 tar is bad, make the best of it echo collecting source [ $clean ] && { @@ -170,6 +188,9 @@ tmpdir="$( for n in asyncore.py asynchat.py; do awk 'NR<4||NR>27;NR==4{print"# license: https://opensource.org/licenses/ISC\n"}' ../build/$n >copyparty/vend/$n done + + # remove type hints before build instead + (cd copyparty; python3 ../../scripts/strip_hints/a.py; rm uh) } ver= @@ -274,17 +295,23 @@ rm have tmv "$f" done -[ $repack ] || -find | grep -E '\.py$' | - grep -vE '__version__' | - tr '\n' '\0' | - xargs -0 "$pybin" ../scripts/uncomment.py +[ $repack ] || { + # uncomment + find | grep -E '\.py$' | + grep -vE '__version__' | + tr '\n' '\0' | + xargs -0 "$pybin" ../scripts/uncomment.py -f=dep-j2/jinja2/constants.py + # py2-compat + #find | grep -E '\.py$' | while IFS= read -r x; do + # sed -ri '/: TypeAlias = /d' "$x"; done +} + +f=j2/jinja2/constants.py awk '/^LOREM_IPSUM_WORDS/{o=1;print "LOREM_IPSUM_WORDS = u\"a\"";next} !o; /"""/{o=0}' <$f >t tmv "$f" -grep -rLE '^#[^a-z]*coding: utf-8' dep-j2 | +grep -rLE '^#[^a-z]*coding: utf-8' j2 | while IFS= read -r f; do (echo "# coding: utf-8"; cat "$f") >t tmv "$f" @@ -313,7 +340,7 @@ find | grep -E '\.(js|html)$' | while IFS= read -r f; do done gzres() { - command -v pigz && + command -v pigz && [ $zopf ] && pk="pigz -11 -I $zopf" || pk='gzip' @@ -354,7 +381,8 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l) } [ $use_zdir ] && { arcs=("$zdir"/arc.*) - arc="${arcs[$RANDOM % ${#arcs[@]} ] }" + n=$(( $RANDOM % ${#arcs[@]} )) + arc="${arcs[n]}" echo "using $arc" tar -xf "$arc" for f in copyparty/web/*.gz; do @@ -364,7 +392,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l) echo gen tarlist -for d in copyparty dep-j2 dep-ftp; do find $d -type f; done | +for d in copyparty j2 ftp py2; do find $d -type f; done | # strip_hints sed -r 's/(.*)\.(.*)/\2 \1/' | LC_ALL=C sort | sed -r 's/([^ ]*) (.*)/\2.\1/' | grep -vE '/list1?$' > list1 diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh index 48aaf075..1977a55a 100755 --- a/scripts/run-tests.sh +++ b/scripts/run-tests.sh @@ -1,13 +1,23 @@ #!/bin/bash set -ex +rm -rf unt +mkdir -p unt/srv +cp -pR copyparty tests unt/ +cd unt +python3 ../scripts/strip_hints/a.py + pids=() for py in python{2,3}; do + PYTHONPATH= + [ $py = python2 ] && PYTHONPATH=../scripts/py2 + export PYTHONPATH + nice $py -m unittest discover -s tests >/dev/null & pids+=($!) done -python3 scripts/test/smoketest.py & +python3 ../scripts/test/smoketest.py & pids+=($!) for pid in ${pids[@]}; do diff --git a/scripts/sfx.py b/scripts/sfx.py index 4723a0da..2cb20c8a 100644 --- a/scripts/sfx.py +++ b/scripts/sfx.py @@ -379,9 +379,20 @@ def run(tmp, j2, ftp): t.daemon = True t.start() - ld = (("", ""), (j2, "dep-j2"), (ftp, "dep-ftp")) + ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2")) ld = [os.path.join(tmp, b) for a, b in ld if not a] + # skip 1 + # enable this to dynamically remove type hints at startup, + # in case a future python version can use them for performance + if sys.version_info < (3, 10) and False: + sys.path.insert(0, ld[0]) + + from strip_hints.a import uh + + uh(tmp + "/copyparty") + # skip 0 + if any([re.match(r"^-.*j[0-9]", x) for x in sys.argv]): run_s(ld) else: diff --git a/scripts/sfx.sh b/scripts/sfx.sh index 1f53c8db..1496f970 100644 --- a/scripts/sfx.sh +++ b/scripts/sfx.sh @@ -47,7 +47,7 @@ grep -E '/(python|pypy)[0-9\.-]*$' >$dir/pys || true printf '\033[1;30mlooking for jinja2 in [%s]\033[0m\n' "$_py" >&2 $_py -c 'import jinja2' 2>/dev/null || continue printf '%s\n' "$_py" - mv $dir/{,x.}dep-j2 + mv $dir/{,x.}j2 break done)" diff --git a/scripts/strip_hints/a.py b/scripts/strip_hints/a.py new file mode 100644 index 00000000..06bf4f6b --- /dev/null +++ b/scripts/strip_hints/a.py @@ -0,0 +1,57 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import re +import os +import sys +from strip_hints import strip_file_to_string + + +# list unique types used in hints: +# rm -rf unt && cp -pR copyparty unt && (cd unt && python3 ../scripts/strip_hints/a.py) +# diff -wNarU1 copyparty unt | grep -E '^\-' | sed -r 's/[^][, ]+://g; s/[^][, ]+[[(]//g; s/[],()<>{} -]/\n/g' | grep -E .. | sort | uniq -c | sort -n + + +def pr(m): + sys.stderr.write(m) + sys.stderr.flush() + + +def uh(top): + if os.path.exists(top + "/uh"): + return + + libs = "typing|types|collections\.abc" + ptn = re.compile(r"^(\s*)(from (?:{0}) import |import (?:{0})\b).*".format(libs)) + + # pr("building support for your python ver") + pr("unhinting") + for (dp, _, fns) in os.walk(top): + for fn in fns: + if not fn.endswith(".py"): + continue + + pr(".") + fp = os.path.join(dp, fn) + cs = strip_file_to_string(fp, no_ast=True, to_empty=True) + + # remove expensive imports too + lns = [] + for ln in cs.split("\n"): + m = ptn.match(ln) + if m: + ln = m.group(1) + "raise Exception()" + + lns.append(ln) + + cs = "\n".join(lns) + with open(fp, "wb") as f: + f.write(cs.encode("utf-8")) + + pr("k\n\n") + with open(top + "/uh", "wb") as f: + f.write(b"a") + + +if __name__ == "__main__": + uh(".") diff --git a/scripts/test/race.py b/scripts/test/race.py index 09922e60..77ef0e46 100644 --- a/scripts/test/race.py +++ b/scripts/test/race.py @@ -58,13 +58,13 @@ class CState(threading.Thread): remotes.append("?") remotes_ok = False - m = [] + ta = [] for conn, remote in zip(self.cs, remotes): stage = len(conn.st) - m.append(f"\033[3{colors[stage]}m{remote}") + ta.append(f"\033[3{colors[stage]}m{remote}") - m = " ".join(m) - print(f"{m}\033[0m\n\033[A", end="") + t = " ".join(ta) + print(f"{t}\033[0m\n\033[A", end="") def allget(cs, urls): diff --git a/scripts/test/smoketest.py b/scripts/test/smoketest.py index 8508f127..a37fc44a 100644 --- a/scripts/test/smoketest.py +++ b/scripts/test/smoketest.py @@ -72,6 +72,8 @@ def tc1(vflags): for _ in range(10): try: os.mkdir(td) + if os.path.exists(td): + break except: time.sleep(0.1) # win10 diff --git a/tests/test_vfs.py b/tests/test_vfs.py index cb7f08fe..4818d180 100644 --- a/tests/test_vfs.py +++ b/tests/test_vfs.py @@ -85,7 +85,7 @@ class TestVFS(unittest.TestCase): pass def assertAxs(self, dct, lst): - t1 = list(sorted(dct.keys())) + t1 = list(sorted(dct)) t2 = list(sorted(lst)) self.assertEqual(t1, t2) @@ -208,10 +208,10 @@ class TestVFS(unittest.TestCase): self.assertEqual(n.realpath, os.path.join(td, "a")) self.assertAxs(n.axs.uread, ["*"]) self.assertAxs(n.axs.uwrite, []) - self.assertEqual(vfs.can_access("/", "*"), [False, False, False, False, False]) - self.assertEqual(vfs.can_access("/", "k"), [True, True, False, False, False]) - self.assertEqual(vfs.can_access("/a", "*"), [True, False, False, False, False]) - self.assertEqual(vfs.can_access("/a", "k"), [True, False, False, False, False]) + self.assertEqual(vfs.can_access("/", "*"), (False, False, False, False, False)) + self.assertEqual(vfs.can_access("/", "k"), (True, True, False, False, False)) + self.assertEqual(vfs.can_access("/a", "*"), (True, False, False, False, False)) + self.assertEqual(vfs.can_access("/a", "k"), (True, False, False, False, False)) # breadth-first construction vfs = AuthSrv( @@ -279,7 +279,7 @@ class TestVFS(unittest.TestCase): n = au.vfs # root was not defined, so PWD with no access to anyone self.assertEqual(n.vpath, "") - self.assertEqual(n.realpath, None) + self.assertEqual(n.realpath, "") self.assertAxs(n.axs.uread, []) self.assertAxs(n.axs.uwrite, []) self.assertEqual(len(n.nodes), 1) diff --git a/tests/util.py b/tests/util.py index 55ec58ea..92ecaa88 100644 --- a/tests/util.py +++ b/tests/util.py @@ -90,7 +90,10 @@ def get_ramdisk(): class NullBroker(object): - def put(*args): + def say(*args): + pass + + def ask(*args): pass