From 438384425a96f3f69b4620e25945fd241e617c4a Mon Sep 17 00:00:00 2001
From: ed
Date: Thu, 16 Jun 2022 01:07:15 +0200
Subject: [PATCH] add types, isort, errorhandling
---
README.md | 9 +-
bin/mtag/image-noexif.py | 1 -
bin/up2k.py | 22 +-
contrib/systemd/copyparty.service | 2 +-
copyparty/__init__.py | 31 +-
copyparty/__main__.py | 100 ++--
copyparty/authsrv.py | 574 +++++++++++---------
copyparty/bos/bos.py | 33 +-
copyparty/bos/path.py | 21 +-
copyparty/broker_mp.py | 62 ++-
copyparty/broker_mpw.py | 61 ++-
copyparty/broker_thr.py | 57 +-
copyparty/broker_util.py | 41 +-
copyparty/ftpd.py | 194 ++++---
copyparty/httpcli.py | 696 +++++++++++++++----------
copyparty/httpconn.py | 94 ++--
copyparty/httpsrv.py | 107 ++--
copyparty/ico.py | 24 +-
copyparty/mtag.py | 142 ++---
copyparty/star.py | 54 +-
copyparty/stolen/surrogateescape.py | 23 +-
copyparty/sutil.py | 23 +-
copyparty/svchub.py | 113 ++--
copyparty/szip.py | 77 ++-
copyparty/tcpsrv.py | 75 +--
copyparty/th_cli.py | 28 +-
copyparty/th_srv.py | 120 +++--
copyparty/u2idx.py | 89 ++--
copyparty/up2k.py | 781 ++++++++++++++++------------
copyparty/util.py | 513 ++++++++++--------
scripts/make-pypi-release.sh | 9 +
scripts/make-sfx.sh | 60 ++-
scripts/run-tests.sh | 12 +-
scripts/sfx.py | 13 +-
scripts/sfx.sh | 2 +-
scripts/strip_hints/a.py | 57 ++
scripts/test/race.py | 8 +-
scripts/test/smoketest.py | 2 +
tests/test_vfs.py | 12 +-
tests/util.py | 5 +-
40 files changed, 2597 insertions(+), 1750 deletions(-)
create mode 100644 scripts/strip_hints/a.py
diff --git a/README.md b/README.md
index 38c60727..6493f0d2 100644
--- a/README.md
+++ b/README.md
@@ -1200,15 +1200,18 @@ journalctl -aS '48 hour ago' -u copyparty | grep -C10 FILENAME | tee bug.log
## dev env setup
-mostly optional; if you need a working env for vscode or similar
+you need python 3.9 or newer due to type hints
+
+the rest is mostly optional; if you need a working env for vscode or similar
```sh
python3 -m venv .venv
. .venv/bin/activate
-pip install jinja2 # mandatory
+pip install jinja2 strip_hints # MANDATORY
pip install mutagen # audio metadata
+pip install pyftpdlib # ftp server
pip install Pillow pyheif-pillow-opener pillow-avif-plugin # thumbnails
-pip install black==21.12b0 bandit pylint flake8 # vscode tooling
+pip install black==21.12b0 click==8.0.2 bandit pylint flake8 isort mypy # vscode tooling
```
diff --git a/bin/mtag/image-noexif.py b/bin/mtag/image-noexif.py
index d5392b00..bc009c17 100644
--- a/bin/mtag/image-noexif.py
+++ b/bin/mtag/image-noexif.py
@@ -43,7 +43,6 @@ PS: this requires e2ts to be functional,
import os
import sys
-import time
import filecmp
import subprocess as sp
diff --git a/bin/up2k.py b/bin/up2k.py
index 8cf904c4..23d35ea3 100755
--- a/bin/up2k.py
+++ b/bin/up2k.py
@@ -77,15 +77,15 @@ class File(object):
self.up_b = 0 # type: int
self.up_c = 0 # type: int
- # m = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n"
- # eprint(m.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name))
+ # t = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n"
+ # eprint(t.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name))
class FileSlice(object):
"""file-like object providing a fixed window into a file"""
def __init__(self, file, cid):
- # type: (File, str) -> FileSlice
+ # type: (File, str) -> None
self.car, self.len = file.kchunks[cid]
self.cdr = self.car + self.len
@@ -216,8 +216,8 @@ class CTermsize(object):
eprint("\033[s\033[r\033[u")
else:
self.g = 1 + self.h - margin
- m = "{0}\033[{1}A".format("\n" * margin, margin)
- eprint("{0}\033[s\033[1;{1}r\033[u".format(m, self.g - 1))
+ t = "{0}\033[{1}A".format("\n" * margin, margin)
+ eprint("{0}\033[s\033[1;{1}r\033[u".format(t, self.g - 1))
ss = CTermsize()
@@ -597,8 +597,8 @@ class Ctl(object):
if "/" in name:
name = "\033[36m{0}\033[0m/{1}".format(*name.rsplit("/", 1))
- m = "{0:6.1f}% {1} {2}\033[K"
- txt += m.format(p, self.nfiles - f, name)
+ t = "{0:6.1f}% {1} {2}\033[K"
+ txt += t.format(p, self.nfiles - f, name)
txt += "\033[{0}H ".format(ss.g + 2)
else:
@@ -618,8 +618,8 @@ class Ctl(object):
nleft = self.nfiles - self.up_f
tail = "\033[K\033[u" if VT100 else "\r"
- m = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft)
- eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(m, tail))
+ t = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft)
+ eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(t, tail))
def cleanup_vt100(self):
ss.scroll_region(None)
@@ -721,8 +721,8 @@ class Ctl(object):
if search:
if hs:
for hit in hs:
- m = "found: {0}\n {1}{2}\n"
- print(m.format(upath, burl, hit["rp"]), end="")
+ t = "found: {0}\n {1}{2}\n"
+ print(t.format(upath, burl, hit["rp"]), end="")
else:
print("NOT found: {0}\n".format(upath), end="")
diff --git a/contrib/systemd/copyparty.service b/contrib/systemd/copyparty.service
index e8e02f90..fd3f9efc 100644
--- a/contrib/systemd/copyparty.service
+++ b/contrib/systemd/copyparty.service
@@ -4,7 +4,7 @@
# installation:
# cp -pv copyparty.service /etc/systemd/system
# restorecon -vr /etc/systemd/system/copyparty.service
-# firewall-cmd --permanent --add-port={80,443,3923}/tcp
+# firewall-cmd --permanent --add-port={80,443,3923}/tcp # --zone=libvirt
# firewall-cmd --reload
# systemctl daemon-reload && systemctl enable --now copyparty
#
diff --git a/copyparty/__init__.py b/copyparty/__init__.py
index 17513708..594363eb 100644
--- a/copyparty/__init__.py
+++ b/copyparty/__init__.py
@@ -1,21 +1,30 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import platform
-import time
-import sys
import os
+import platform
+import sys
+import time
+
+try:
+ from collections.abc import Callable
+
+ from typing import TYPE_CHECKING, Any
+except:
+ TYPE_CHECKING = False
PY2 = sys.version_info[0] == 2
if PY2:
sys.dont_write_bytecode = True
- unicode = unicode
+ unicode = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable
else:
unicode = str
-WINDOWS = False
-if platform.system() == "Windows":
- WINDOWS = [int(x) for x in platform.version().split(".")]
+WINDOWS: Any = (
+ [int(x) for x in platform.version().split(".")]
+ if platform.system() == "Windows"
+ else False
+)
VT100 = not WINDOWS or WINDOWS >= [10, 0, 14393]
# introduced in anniversary update
@@ -25,8 +34,8 @@ ANYWIN = WINDOWS or sys.platform in ["msys"]
MACOS = platform.system() == "Darwin"
-def get_unixdir():
- paths = [
+def get_unixdir() -> str:
+ paths: list[tuple[Callable[..., str], str]] = [
(os.environ.get, "XDG_CONFIG_HOME"),
(os.path.expanduser, "~/.config"),
(os.environ.get, "TMPDIR"),
@@ -43,7 +52,7 @@ def get_unixdir():
continue
p = os.path.normpath(p)
- chk(p)
+ chk(p) # type: ignore
p = os.path.join(p, "copyparty")
if not os.path.isdir(p):
os.mkdir(p)
@@ -56,7 +65,7 @@ def get_unixdir():
class EnvParams(object):
- def __init__(self):
+ def __init__(self) -> None:
self.t0 = time.time()
self.mod = os.path.dirname(os.path.realpath(__file__))
if self.mod.endswith("__init__"):
diff --git a/copyparty/__main__.py b/copyparty/__main__.py
index 87203f87..5412a6c2 100644
--- a/copyparty/__main__.py
+++ b/copyparty/__main__.py
@@ -8,35 +8,42 @@ __copyright__ = 2019
__license__ = "MIT"
__url__ = "https://github.com/9001/copyparty/"
-import re
-import os
-import sys
-import time
-import shutil
+import argparse
import filecmp
import locale
-import argparse
+import os
+import re
+import shutil
+import sys
import threading
+import time
import traceback
from textwrap import dedent
-from .__init__ import E, WINDOWS, ANYWIN, VT100, PY2, unicode
-from .__version__ import S_VERSION, S_BUILD_DT, CODENAME
-from .svchub import SvcHub
-from .util import py_desc, align_tab, IMPLICATIONS, ansi_re, min_ex
+from .__init__ import ANYWIN, PY2, VT100, WINDOWS, E, unicode
+from .__version__ import CODENAME, S_BUILD_DT, S_VERSION
from .authsrv import re_vol
+from .svchub import SvcHub
+from .util import IMPLICATIONS, align_tab, ansi_re, min_ex, py_desc
-HAVE_SSL = True
try:
+ from types import FrameType
+
+ from typing import Any, Optional
+except:
+ pass
+
+try:
+ HAVE_SSL = True
import ssl
except:
HAVE_SSL = False
-printed = []
+printed: list[str] = []
class RiceFormatter(argparse.HelpFormatter):
- def _get_help_string(self, action):
+ def _get_help_string(self, action: argparse.Action) -> str:
"""
same as ArgumentDefaultsHelpFormatter(HelpFormatter)
except the help += [...] line now has colors
@@ -45,27 +52,27 @@ class RiceFormatter(argparse.HelpFormatter):
if not VT100:
fmt = " (default: %(default)s)"
- ret = action.help
- if "%(default)" not in action.help:
+ ret = str(action.help)
+ if "%(default)" not in ret:
if action.default is not argparse.SUPPRESS:
defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE]
if action.option_strings or action.nargs in defaulting_nargs:
ret += fmt
return ret
- def _fill_text(self, text, width, indent):
+ def _fill_text(self, text: str, width: int, indent: str) -> str:
"""same as RawDescriptionHelpFormatter(HelpFormatter)"""
return "".join(indent + line + "\n" for line in text.splitlines())
class Dodge11874(RiceFormatter):
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["width"] = 9003
super(Dodge11874, self).__init__(*args, **kwargs)
-def lprint(*a, **ka):
- txt = " ".join(unicode(x) for x in a) + ka.get("end", "\n")
+def lprint(*a: Any, **ka: Any) -> None:
+ txt: str = " ".join(unicode(x) for x in a) + ka.get("end", "\n")
printed.append(txt)
if not VT100:
txt = ansi_re.sub("", txt)
@@ -73,11 +80,11 @@ def lprint(*a, **ka):
print(txt, **ka)
-def warn(msg):
+def warn(msg: str) -> None:
lprint("\033[1mwarning:\033[0;33m {}\033[0m\n".format(msg))
-def ensure_locale():
+def ensure_locale() -> None:
for x in [
"en_US.UTF-8",
"English_United States.UTF8",
@@ -91,7 +98,7 @@ def ensure_locale():
continue
-def ensure_cert():
+def ensure_cert() -> None:
"""
the default cert (and the entire TLS support) is only here to enable the
crypto.subtle javascript API, which is necessary due to the webkit guys
@@ -117,8 +124,8 @@ def ensure_cert():
# printf 'NO\n.\n.\n.\n.\ncopyparty-insecure\n.\n' | faketime '2000-01-01 00:00:00' openssl req -x509 -sha256 -newkey rsa:2048 -keyout insecure.pem -out insecure.pem -days $((($(printf %d 0x7fffffff)-$(date +%s --date=2000-01-01T00:00:00Z))/(60*60*24))) -nodes && ls -al insecure.pem && openssl x509 -in insecure.pem -text -noout
-def configure_ssl_ver(al):
- def terse_sslver(txt):
+def configure_ssl_ver(al: argparse.Namespace) -> None:
+ def terse_sslver(txt: str) -> str:
txt = txt.lower()
for c in ["_", "v", "."]:
txt = txt.replace(c, "")
@@ -133,8 +140,8 @@ def configure_ssl_ver(al):
flags = [k for k in ssl.__dict__ if ptn.match(k)]
# SSLv2 SSLv3 TLSv1 TLSv1_1 TLSv1_2 TLSv1_3
if "help" in sslver:
- avail = [terse_sslver(x[6:]) for x in flags]
- avail = " ".join(sorted(avail) + ["all"])
+ avail1 = [terse_sslver(x[6:]) for x in flags]
+ avail = " ".join(sorted(avail1) + ["all"])
lprint("\navailable ssl/tls versions:\n " + avail)
sys.exit(0)
@@ -160,7 +167,7 @@ def configure_ssl_ver(al):
# think i need that beer now
-def configure_ssl_ciphers(al):
+def configure_ssl_ciphers(al: argparse.Namespace) -> None:
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
if al.ssl_ver:
ctx.options &= ~al.ssl_flags_en
@@ -184,8 +191,8 @@ def configure_ssl_ciphers(al):
sys.exit(0)
-def args_from_cfg(cfg_path):
- ret = []
+def args_from_cfg(cfg_path: str) -> list[str]:
+ ret: list[str] = []
skip = False
with open(cfg_path, "rb") as f:
for ln in [x.decode("utf-8").strip() for x in f]:
@@ -210,29 +217,30 @@ def args_from_cfg(cfg_path):
return ret
-def sighandler(sig=None, frame=None):
+def sighandler(sig: Optional[int] = None, frame: Optional[FrameType] = None) -> None:
msg = [""] * 5
for th in threading.enumerate():
+ stk = sys._current_frames()[th.ident] # type: ignore
msg.append(str(th))
- msg.extend(traceback.format_stack(sys._current_frames()[th.ident]))
+ msg.extend(traceback.format_stack(stk))
msg.append("\n")
print("\n".join(msg))
-def disable_quickedit():
- import ctypes
+def disable_quickedit() -> None:
import atexit
+ import ctypes
from ctypes import wintypes
- def ecb(ok, fun, args):
+ def ecb(ok: bool, fun: Any, args: list[Any]) -> list[Any]:
if not ok:
- err = ctypes.get_last_error()
+ err: int = ctypes.get_last_error() # type: ignore
if err:
- raise ctypes.WinError(err)
+ raise ctypes.WinError(err) # type: ignore
return args
- k32 = ctypes.WinDLL("kernel32", use_last_error=True)
+ k32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore
if PY2:
wintypes.LPDWORD = ctypes.POINTER(wintypes.DWORD)
@@ -242,14 +250,14 @@ def disable_quickedit():
k32.GetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.LPDWORD)
k32.SetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.DWORD)
- def cmode(out, mode=None):
+ def cmode(out: bool, mode: Optional[int] = None) -> int:
h = k32.GetStdHandle(-11 if out else -10)
if mode:
- return k32.SetConsoleMode(h, mode)
+ return k32.SetConsoleMode(h, mode) # type: ignore
- mode = wintypes.DWORD()
- k32.GetConsoleMode(h, ctypes.byref(mode))
- return mode.value
+ cmode = wintypes.DWORD()
+ k32.GetConsoleMode(h, ctypes.byref(cmode))
+ return cmode.value
# disable quickedit
mode = orig_in = cmode(False)
@@ -268,7 +276,7 @@ def disable_quickedit():
cmode(True, mode | 4)
-def run_argparse(argv, formatter):
+def run_argparse(argv: list[str], formatter: Any) -> argparse.Namespace:
ap = argparse.ArgumentParser(
formatter_class=formatter,
prog="copyparty",
@@ -596,7 +604,7 @@ def run_argparse(argv, formatter):
return ret
-def main(argv=None):
+def main(argv: Optional[list[str]] = None) -> None:
time.strptime("19970815", "%Y%m%d") # python#7980
if WINDOWS:
os.system("rem") # enables colors
@@ -618,7 +626,7 @@ def main(argv=None):
supp = args_from_cfg(v)
argv.extend(supp)
- deprecated = []
+ deprecated: list[tuple[str, str]] = []
for dk, nk in deprecated:
try:
idx = argv.index(dk)
@@ -650,7 +658,7 @@ def main(argv=None):
if not VT100:
al.wintitle = ""
- nstrs = []
+ nstrs: list[str] = []
anymod = False
for ostr in al.v or []:
m = re_vol.match(ostr)
diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py
index 4ca406ee..92612529 100644
--- a/copyparty/authsrv.py
+++ b/copyparty/authsrv.py
@@ -1,44 +1,68 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import sys
-import stat
-import time
+import argparse
import base64
import hashlib
+import os
+import re
+import stat
+import sys
import threading
+import time
from datetime import datetime
-from .__init__ import ANYWIN, WINDOWS
+from .__init__ import ANYWIN, TYPE_CHECKING, WINDOWS
+from .bos import bos
from .util import (
IMPLICATIONS,
META_NOBOTS,
+ Pebkac,
+ absreal,
+ fsenc,
+ relchk,
+ statdir,
uncyg,
undot,
- relchk,
unhumanize,
- absreal,
- Pebkac,
- fsenc,
- statdir,
)
-from .bos import bos
+
+try:
+ from collections.abc import Iterable
+
+ import typing
+ from typing import Any, Generator, Optional, Union
+
+ from .util import RootLogger
+except:
+ pass
+
+if TYPE_CHECKING:
+ pass
+ # Vflags: TypeAlias = dict[str, str | bool | float | list[str]]
+ # Vflags: TypeAlias = dict[str, Any]
+ # Mflags: TypeAlias = dict[str, Vflags]
LEELOO_DALLAS = "leeloo_dallas"
class AXS(object):
- def __init__(self, uread=None, uwrite=None, umove=None, udel=None, uget=None):
- self.uread = {} if uread is None else {k: 1 for k in uread}
- self.uwrite = {} if uwrite is None else {k: 1 for k in uwrite}
- self.umove = {} if umove is None else {k: 1 for k in umove}
- self.udel = {} if udel is None else {k: 1 for k in udel}
- self.uget = {} if uget is None else {k: 1 for k in uget}
+ def __init__(
+ self,
+ uread: Optional[Union[list[str], set[str]]] = None,
+ uwrite: Optional[Union[list[str], set[str]]] = None,
+ umove: Optional[Union[list[str], set[str]]] = None,
+ udel: Optional[Union[list[str], set[str]]] = None,
+ uget: Optional[Union[list[str], set[str]]] = None,
+ ) -> None:
+ self.uread: set[str] = set(uread or [])
+ self.uwrite: set[str] = set(uwrite or [])
+ self.umove: set[str] = set(umove or [])
+ self.udel: set[str] = set(udel or [])
+ self.uget: set[str] = set(uget or [])
- def __repr__(self):
+ def __repr__(self) -> str:
return "AXS({})".format(
", ".join(
"{}={!r}".format(k, self.__dict__[k])
@@ -48,33 +72,33 @@ class AXS(object):
class Lim(object):
- def __init__(self):
- self.nups = {} # num tracker
- self.bups = {} # byte tracker list
- self.bupc = {} # byte tracker cache
+ def __init__(self) -> None:
+ self.nups: dict[str, list[float]] = {} # num tracker
+ self.bups: dict[str, list[tuple[float, int]]] = {} # byte tracker list
+ self.bupc: dict[str, int] = {} # byte tracker cache
self.nosub = False # disallow subdirectories
- self.smin = None # filesize min
- self.smax = None # filesize max
+ self.smin = -1 # filesize min
+ self.smax = -1 # filesize max
- self.bwin = None # bytes window
- self.bmax = None # bytes max
- self.nwin = None # num window
- self.nmax = None # num max
+ self.bwin = 0 # bytes window
+ self.bmax = 0 # bytes max
+ self.nwin = 0 # num window
+ self.nmax = 0 # num max
- self.rotn = None # rot num files
- self.rotl = None # rot depth
- self.rotf = None # rot datefmt
- self.rot_re = None # rotf check
+ self.rotn = 0 # rot num files
+ self.rotl = 0 # rot depth
+ self.rotf = "" # rot datefmt
+ self.rot_re = re.compile("") # rotf check
- def set_rotf(self, fmt):
+ def set_rotf(self, fmt: str) -> None:
self.rotf = fmt
r = re.escape(fmt).replace("%Y", "[0-9]{4}").replace("%j", "[0-9]{3}")
r = re.sub("%[mdHMSWU]", "[0-9]{2}", r)
self.rot_re = re.compile("(^|/)" + r + "$")
- def all(self, ip, rem, sz, abspath):
+ def all(self, ip: str, rem: str, sz: float, abspath: str) -> tuple[str, str]:
self.chk_nup(ip)
self.chk_bup(ip)
self.chk_rem(rem)
@@ -87,18 +111,18 @@ class Lim(object):
return ap2, ("{}/{}".format(rem, vp2) if rem else vp2)
- def chk_sz(self, sz):
- if self.smin is not None and sz < self.smin:
+ def chk_sz(self, sz: float) -> None:
+ if self.smin != -1 and sz < self.smin:
raise Pebkac(400, "file too small")
- if self.smax is not None and sz > self.smax:
+ if self.smax != -1 and sz > self.smax:
raise Pebkac(400, "file too big")
- def chk_rem(self, rem):
+ def chk_rem(self, rem: str) -> None:
if self.nosub and rem:
raise Pebkac(500, "no subdirectories allowed")
- def rot(self, path):
+ def rot(self, path: str) -> tuple[str, str]:
if not self.rotf and not self.rotn:
return path, ""
@@ -120,7 +144,7 @@ class Lim(object):
d = ret[len(path) :].strip("/\\").replace("\\", "/")
return ret, d
- def dive(self, path, lvs):
+ def dive(self, path: str, lvs: int) -> Optional[str]:
items = bos.listdir(path)
if not lvs:
@@ -155,14 +179,14 @@ class Lim(object):
return os.path.join(sub, ret)
- def nup(self, ip):
+ def nup(self, ip: str) -> None:
try:
self.nups[ip].append(time.time())
except:
self.nups[ip] = [time.time()]
- def bup(self, ip, nbytes):
- v = [time.time(), nbytes]
+ def bup(self, ip: str, nbytes: int) -> None:
+ v = (time.time(), nbytes)
try:
self.bups[ip].append(v)
self.bupc[ip] += nbytes
@@ -170,7 +194,7 @@ class Lim(object):
self.bups[ip] = [v]
self.bupc[ip] = nbytes
- def chk_nup(self, ip):
+ def chk_nup(self, ip: str) -> None:
if not self.nmax or ip not in self.nups:
return
@@ -182,7 +206,7 @@ class Lim(object):
if len(nups) >= self.nmax:
raise Pebkac(429, "too many uploads")
- def chk_bup(self, ip):
+ def chk_bup(self, ip: str) -> None:
if not self.bmax or ip not in self.bups:
return
@@ -200,35 +224,37 @@ class Lim(object):
class VFS(object):
"""single level in the virtual fs"""
- def __init__(self, log, realpath, vpath, axs, flags):
+ def __init__(
+ self,
+ log: Optional[RootLogger],
+ realpath: str,
+ vpath: str,
+ axs: AXS,
+ flags: dict[str, Any],
+ ) -> None:
self.log = log
self.realpath = realpath # absolute path on host filesystem
self.vpath = vpath # absolute path in the virtual filesystem
- self.axs = axs # type: AXS
+ self.axs = axs
self.flags = flags # config options
- self.nodes = {} # child nodes
- self.histtab = None # all realpath->histpath
- self.dbv = None # closest full/non-jump parent
- self.lim = None # type: Lim # upload limits; only set for dbv
+ self.nodes: dict[str, VFS] = {} # child nodes
+ self.histtab: dict[str, str] = {} # all realpath->histpath
+ self.dbv: Optional[VFS] = None # closest full/non-jump parent
+ self.lim: Optional[Lim] = None # upload limits; only set for dbv
+ self.aread: dict[str, list[str]] = {}
+ self.awrite: dict[str, list[str]] = {}
+ self.amove: dict[str, list[str]] = {}
+ self.adel: dict[str, list[str]] = {}
+ self.aget: dict[str, list[str]] = {}
if realpath:
self.histpath = os.path.join(realpath, ".hist") # db / thumbcache
self.all_vols = {vpath: self} # flattened recursive
- self.aread = {}
- self.awrite = {}
- self.amove = {}
- self.adel = {}
- self.aget = {}
else:
- self.histpath = None
- self.all_vols = None
- self.aread = None
- self.awrite = None
- self.amove = None
- self.adel = None
- self.aget = None
+ self.histpath = ""
+ self.all_vols = {}
- def __repr__(self):
+ def __repr__(self) -> str:
return "VFS({})".format(
", ".join(
"{}={!r}".format(k, self.__dict__[k])
@@ -236,14 +262,14 @@ class VFS(object):
)
)
- def get_all_vols(self, outdict):
+ def get_all_vols(self, outdict: dict[str, "VFS"]) -> None:
if self.realpath:
outdict[self.vpath] = self
for v in self.nodes.values():
v.get_all_vols(outdict)
- def add(self, src, dst):
+ def add(self, src: str, dst: str) -> "VFS":
"""get existing, or add new path to the vfs"""
assert not src.endswith("/") # nosec
assert not dst.endswith("/") # nosec
@@ -257,7 +283,7 @@ class VFS(object):
vn = VFS(
self.log,
- os.path.join(self.realpath, name) if self.realpath else None,
+ os.path.join(self.realpath, name) if self.realpath else "",
"{}/{}".format(self.vpath, name).lstrip("/"),
self.axs,
self._copy_flags(name),
@@ -277,7 +303,7 @@ class VFS(object):
self.nodes[dst] = vn
return vn
- def _copy_flags(self, name):
+ def _copy_flags(self, name: str) -> dict[str, Any]:
flags = {k: v for k, v in self.flags.items()}
hist = flags.get("hist")
if hist and hist != "-":
@@ -285,20 +311,20 @@ class VFS(object):
return flags
- def bubble_flags(self):
+ def bubble_flags(self) -> None:
if self.dbv:
for k, v in self.dbv.flags.items():
if k not in ["hist"]:
self.flags[k] = v
- for v in self.nodes.values():
- v.bubble_flags()
+ for n in self.nodes.values():
+ n.bubble_flags()
- def _find(self, vpath):
+ def _find(self, vpath: str) -> tuple["VFS", str]:
"""return [vfs,remainder]"""
vpath = undot(vpath)
if vpath == "":
- return [self, ""]
+ return self, ""
if "/" in vpath:
name, rem = vpath.split("/", 1)
@@ -309,66 +335,64 @@ class VFS(object):
if name in self.nodes:
return self.nodes[name]._find(rem)
- return [self, vpath]
+ return self, vpath
- def can_access(self, vpath, uname):
- # type: (str, str) -> tuple[bool, bool, bool, bool]
+ def can_access(self, vpath: str, uname: str) -> tuple[bool, bool, bool, bool, bool]:
"""can Read,Write,Move,Delete,Get"""
vn, _ = self._find(vpath)
c = vn.axs
- return [
+ return (
uname in c.uread or "*" in c.uread,
uname in c.uwrite or "*" in c.uwrite,
uname in c.umove or "*" in c.umove,
uname in c.udel or "*" in c.udel,
uname in c.uget or "*" in c.uget,
- ]
+ )
def get(
self,
- vpath,
- uname,
- will_read,
- will_write,
- will_move=False,
- will_del=False,
- will_get=False,
- ):
- # type: (str, str, bool, bool, bool, bool, bool) -> tuple[VFS, str]
+ vpath: str,
+ uname: str,
+ will_read: bool,
+ will_write: bool,
+ will_move: bool = False,
+ will_del: bool = False,
+ will_get: bool = False,
+ ) -> tuple["VFS", str]:
"""returns [vfsnode,fs_remainder] if user has the requested permissions"""
if ANYWIN:
mod = relchk(vpath)
if mod:
- self.log("vfs", "invalid relpath [{}]".format(vpath))
+ if self.log:
+ self.log("vfs", "invalid relpath [{}]".format(vpath))
raise Pebkac(404)
vn, rem = self._find(vpath)
- c = vn.axs
+ c: AXS = vn.axs
for req, d, msg in [
- [will_read, c.uread, "read"],
- [will_write, c.uwrite, "write"],
- [will_move, c.umove, "move"],
- [will_del, c.udel, "delete"],
- [will_get, c.uget, "get"],
+ (will_read, c.uread, "read"),
+ (will_write, c.uwrite, "write"),
+ (will_move, c.umove, "move"),
+ (will_del, c.udel, "delete"),
+ (will_get, c.uget, "get"),
]:
if req and (uname not in d and "*" not in d) and uname != LEELOO_DALLAS:
- m = "you don't have {}-access for this location"
- raise Pebkac(403, m.format(msg))
+ t = "you don't have {}-access for this location"
+ raise Pebkac(403, t.format(msg))
return vn, rem
- def get_dbv(self, vrem):
- # type: (str) -> tuple[VFS, str]
+ def get_dbv(self, vrem: str) -> tuple["VFS", str]:
dbv = self.dbv
if not dbv:
return self, vrem
- vrem = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem]
- vrem = "/".join([x for x in vrem if x])
+ tv = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem]
+ vrem = "/".join([x for x in tv if x])
return dbv, vrem
- def canonical(self, rem, resolve=True):
+ def canonical(self, rem: str, resolve: bool = True) -> str:
"""returns the canonical path (fully-resolved absolute fs path)"""
rp = self.realpath
if rem:
@@ -376,8 +400,14 @@ class VFS(object):
return absreal(rp) if resolve else rp
- def ls(self, rem, uname, scandir, permsets, lstat=False):
- # type: (str, str, bool, list[list[bool]], bool) -> tuple[str, str, dict[str, VFS]]
+ def ls(
+ self,
+ rem: str,
+ uname: str,
+ scandir: bool,
+ permsets: list[list[bool]],
+ lstat: bool = False,
+ ) -> tuple[str, list[tuple[str, os.stat_result]], dict[str, "VFS"]]:
"""return user-readable [fsdir,real,virt] items at vpath"""
virt_vis = {} # nodes readable by user
abspath = self.canonical(rem)
@@ -389,8 +419,8 @@ class VFS(object):
for name, vn2 in sorted(self.nodes.items()):
ok = False
- axs = vn2.axs
- axs = [axs.uread, axs.uwrite, axs.umove, axs.udel, axs.uget]
+ zx = vn2.axs
+ axs = [zx.uread, zx.uwrite, zx.umove, zx.udel, zx.uget]
for pset in permsets:
ok = True
for req, lst in zip(pset, axs):
@@ -409,9 +439,32 @@ class VFS(object):
elif "/.hist/th/" in p:
real = [x for x in real if not x[0].endswith("dir.txt")]
- return [abspath, real, virt_vis]
+ return abspath, real, virt_vis
- def walk(self, rel, rem, seen, uname, permsets, dots, scandir, lstat, subvols=True):
+ def walk(
+ self,
+ rel: str,
+ rem: str,
+ seen: list[str],
+ uname: str,
+ permsets: list[list[bool]],
+ dots: bool,
+ scandir: bool,
+ lstat: bool,
+ subvols: bool = True,
+ ) -> Generator[
+ tuple[
+ "VFS",
+ str,
+ str,
+ str,
+ list[tuple[str, os.stat_result]],
+ list[tuple[str, os.stat_result]],
+ dict[str, "VFS"],
+ ],
+ None,
+ None,
+ ]:
"""
recursively yields from ./rem;
rel is a unix-style user-defined vpath (not vfs-related)
@@ -425,8 +478,9 @@ class VFS(object):
and (not fsroot.startswith(seen[-1]) or fsroot == seen[-1])
and fsroot in seen
):
- m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}"
- self.log("vfs.walk", m.format(seen[-1], fsroot, self.vpath, rem), 3)
+ if self.log:
+ t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}"
+ self.log("vfs.walk", t.format(seen[-1], fsroot, self.vpath, rem), 3)
return
seen = seen[:] + [fsroot]
@@ -460,9 +514,9 @@ class VFS(object):
for x in vfs.walk(wrel, "", seen, uname, permsets, dots, scandir, lstat):
yield x
- def zipgen(self, vrem, flt, uname, dots, scandir):
- if flt:
- flt = {k: True for k in flt}
+ def zipgen(
+ self, vrem: str, flt: set[str], uname: str, dots: bool, scandir: bool
+ ) -> Generator[dict[str, Any], None, None]:
# if multiselect: add all items to archive root
# if single folder: the folder itself is the top-level item
@@ -473,33 +527,33 @@ class VFS(object):
if flt:
files = [x for x in files if x[0] in flt]
- rm = [x for x in rd if x[0] not in flt]
- [rd.remove(x) for x in rm]
+ rm1 = [x for x in rd if x[0] not in flt]
+ _ = [rd.remove(x) for x in rm1] # type: ignore
- rm = [x for x in vd.keys() if x not in flt]
- [vd.pop(x) for x in rm]
+ rm2 = [x for x in vd.keys() if x not in flt]
+ _ = [vd.pop(x) for x in rm2]
- flt = None
+ flt = set()
# print(repr([vpath, apath, [x[0] for x in files]]))
fnames = [n[0] for n in files]
vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames
apaths = [os.path.join(apath, n) for n in fnames]
- files = list(zip(vpaths, apaths, files))
+ ret = list(zip(vpaths, apaths, files))
if not dots:
# dotfile filtering based on vpath (intended visibility)
- files = [x for x in files if "/." not in "/" + x[0]]
+ ret = [x for x in ret if "/." not in "/" + x[0]]
- rm = [x for x in rd if x[0].startswith(".")]
- for x in rm:
- rd.remove(x)
+ zel = [ze for ze in rd if ze[0].startswith(".")]
+ for ze in zel:
+ rd.remove(ze)
- rm = [k for k in vd.keys() if k.startswith(".")]
- for x in rm:
- del vd[x]
+ zsl = [zs for zs in vd.keys() if zs.startswith(".")]
+ for zs in zsl:
+ del vd[zs]
- for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in files]:
+ for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in ret]:
yield f
@@ -512,7 +566,12 @@ else:
class AuthSrv(object):
"""verifies users against given paths"""
- def __init__(self, args, log_func, warn_anonwrite=True):
+ def __init__(
+ self,
+ args: argparse.Namespace,
+ log_func: Optional[RootLogger],
+ warn_anonwrite: bool = True,
+ ) -> None:
self.args = args
self.log_func = log_func
self.warn_anonwrite = warn_anonwrite
@@ -521,11 +580,11 @@ class AuthSrv(object):
self.mutex = threading.Lock()
self.reload()
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
if self.log_func:
self.log_func("auth", msg, c)
- def laggy_iter(self, iterable):
+ def laggy_iter(self, iterable: Iterable[Any]) -> Generator[Any, None, None]:
"""returns [value,isFinalValue]"""
it = iter(iterable)
prev = next(it)
@@ -535,26 +594,39 @@ class AuthSrv(object):
yield prev, True
- def _map_volume(self, src, dst, mount, daxs, mflags):
+ def _map_volume(
+ self,
+ src: str,
+ dst: str,
+ mount: dict[str, str],
+ daxs: dict[str, AXS],
+ mflags: dict[str, dict[str, Any]],
+ ) -> None:
if dst in mount:
- m = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]"
- self.log(m.format(dst, mount[dst], src), c=1)
+ t = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]"
+ self.log(t.format(dst, mount[dst], src), c=1)
raise Exception("invalid config")
if src in mount.values():
- m = "warning: filesystem-path [{}] mounted in multiple locations:"
- m = m.format(src)
+ t = "warning: filesystem-path [{}] mounted in multiple locations:"
+ t = t.format(src)
for v in [k for k, v in mount.items() if v == src] + [dst]:
- m += "\n /{}".format(v)
+ t += "\n /{}".format(v)
- self.log(m, c=3)
+ self.log(t, c=3)
mount[dst] = src
daxs[dst] = AXS()
mflags[dst] = {}
- def _parse_config_file(self, fd, acct, daxs, mflags, mount):
- # type: (any, str, dict[str, AXS], any, str) -> None
+ def _parse_config_file(
+ self,
+ fd: typing.BinaryIO,
+ acct: dict[str, str],
+ daxs: dict[str, AXS],
+ mflags: dict[str, dict[str, Any]],
+ mount: dict[str, str],
+ ) -> None:
skip = False
vol_src = None
vol_dst = None
@@ -601,23 +673,25 @@ class AuthSrv(object):
uname = "*"
if lvl == "a":
- m = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead"
- self.log(m, 1)
+ t = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead"
+ self.log(t, 1)
self._read_vol_str(lvl, uname, daxs[vol_dst], mflags[vol_dst])
- def _read_vol_str(self, lvl, uname, axs, flags):
- # type: (str, str, AXS, any) -> None
+ def _read_vol_str(
+ self, lvl: str, uname: str, axs: AXS, flags: dict[str, Any]
+ ) -> None:
if lvl.strip("crwmdg"):
raise Exception("invalid volume flag: {},{}".format(lvl, uname))
if lvl == "c":
+ cval: Union[bool, str] = True
try:
# volume flag with arguments, possibly with a preceding list of bools
uname, cval = uname.split("=", 1)
except:
# just one or more bools
- cval = True
+ pass
while "," in uname:
# one or more bools before the final flag; eat them
@@ -631,34 +705,38 @@ class AuthSrv(object):
uname = "*"
for un in uname.replace(",", " ").strip().split():
- if "r" in lvl:
- axs.uread[un] = 1
+ for ch, al in [
+ ("r", axs.uread),
+ ("w", axs.uwrite),
+ ("m", axs.umove),
+ ("d", axs.udel),
+ ("g", axs.uget),
+ ]:
+ if ch in lvl:
+ al.add(un)
- if "w" in lvl:
- axs.uwrite[un] = 1
-
- if "m" in lvl:
- axs.umove[un] = 1
-
- if "d" in lvl:
- axs.udel[un] = 1
-
- if "g" in lvl:
- axs.uget[un] = 1
-
- def _read_volflag(self, flags, name, value, is_list):
+ def _read_volflag(
+ self,
+ flags: dict[str, Any],
+ name: str,
+ value: Union[str, bool, list[str]],
+ is_list: bool,
+ ) -> None:
if name not in ["mtp"]:
flags[name] = value
return
- if not is_list:
- value = [value]
- elif not value:
+ vals = flags.get(name, [])
+ if not value:
return
+ elif is_list:
+ vals += value
+ else:
+ vals += [value]
- flags[name] = flags.get(name, []) + value
+ flags[name] = vals
- def reload(self):
+ def reload(self) -> None:
"""
construct a flat list of mountpoints and usernames
first from the commandline arguments
@@ -666,10 +744,10 @@ class AuthSrv(object):
before finally building the VFS
"""
- acct = {} # username:password
- daxs = {} # type: dict[str, AXS]
- mflags = {} # mountpoint:[flag]
- mount = {} # dst:src (mountpoint:realpath)
+ acct: dict[str, str] = {} # username:password
+ daxs: dict[str, AXS] = {}
+ mflags: dict[str, dict[str, Any]] = {} # mountpoint:flags
+ mount: dict[str, str] = {} # dst:src (mountpoint:realpath)
if self.args.a:
# list of username:password
@@ -678,8 +756,8 @@ class AuthSrv(object):
u, p = x.split(":", 1)
acct[u] = p
except:
- m = '\n invalid value "{}" for argument -a, must be username:password'
- raise Exception(m.format(x))
+ t = '\n invalid value "{}" for argument -a, must be username:password'
+ raise Exception(t.format(x))
if self.args.v:
# list of src:dst:permset:permset:...
@@ -708,8 +786,8 @@ class AuthSrv(object):
try:
self._parse_config_file(f, acct, daxs, mflags, mount)
except:
- m = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m"
- self.log(m.format(cfg_fn, self.line_ctr), 1)
+ t = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m"
+ self.log(t.format(cfg_fn, self.line_ctr), 1)
raise
# case-insensitive; normalize
@@ -726,7 +804,7 @@ class AuthSrv(object):
vfs = VFS(self.log_func, bos.path.abspath("."), "", axs, {})
elif "" not in mount:
# there's volumes but no root; make root inaccessible
- vfs = VFS(self.log_func, None, "", AXS(), {})
+ vfs = VFS(self.log_func, "", "", AXS(), {})
vfs.flags["d2d"] = True
maxdepth = 0
@@ -740,10 +818,10 @@ class AuthSrv(object):
vfs = VFS(self.log_func, mount[dst], dst, daxs[dst], mflags[dst])
continue
- v = vfs.add(mount[dst], dst)
- v.axs = daxs[dst]
- v.flags = mflags[dst]
- v.dbv = None
+ zv = vfs.add(mount[dst], dst)
+ zv.axs = daxs[dst]
+ zv.flags = mflags[dst]
+ zv.dbv = None
vfs.all_vols = {}
vfs.get_all_vols(vfs.all_vols)
@@ -751,11 +829,11 @@ class AuthSrv(object):
for perm in "read write move del get".split():
axs_key = "u" + perm
unames = ["*"] + list(acct.keys())
- umap = {x: [] for x in unames}
+ umap: dict[str, list[str]] = {x: [] for x in unames}
for usr in unames:
for vp, vol in vfs.all_vols.items():
- axs = getattr(vol.axs, axs_key)
- if usr in axs or "*" in axs:
+ zx = getattr(vol.axs, axs_key)
+ if usr in zx or "*" in zx:
umap[usr].append(vp)
umap[usr].sort()
setattr(vfs, "a" + perm, umap)
@@ -764,7 +842,7 @@ class AuthSrv(object):
missing_users = {}
for axs in daxs.values():
for d in [axs.uread, axs.uwrite, axs.umove, axs.udel, axs.uget]:
- for usr in d.keys():
+ for usr in d:
all_users[usr] = 1
if usr != "*" and usr not in acct:
missing_users[usr] = 1
@@ -783,8 +861,8 @@ class AuthSrv(object):
promote = []
demote = []
for vol in vfs.all_vols.values():
- hid = hashlib.sha512(fsenc(vol.realpath)).digest()
- hid = base64.b32encode(hid).decode("ascii").lower()
+ zb = hashlib.sha512(fsenc(vol.realpath)).digest()
+ hid = base64.b32encode(zb).decode("ascii").lower()
vflag = vol.flags.get("hist")
if vflag == "-":
pass
@@ -822,21 +900,21 @@ class AuthSrv(object):
demote.append(vol)
# discard jump-vols
- for v in demote:
- vfs.all_vols.pop(v.vpath)
+ for zv in demote:
+ vfs.all_vols.pop(zv.vpath)
if promote:
- msg = [
+ ta = [
"\n the following jump-volumes were generated to assist the vfs.\n As they contain a database (probably from v0.11.11 or older),\n they are promoted to full volumes:"
]
for vol in promote:
- msg.append(
+ ta.append(
" /{} ({}) ({})".format(vol.vpath, vol.realpath, vol.histpath)
)
- self.log("\n\n".join(msg) + "\n", c=3)
+ self.log("\n\n".join(ta) + "\n", c=3)
- vfs.histtab = {v.realpath: v.histpath for v in vfs.all_vols.values()}
+ vfs.histtab = {zv.realpath: zv.histpath for zv in vfs.all_vols.values()}
for vol in vfs.all_vols.values():
lim = Lim()
@@ -846,30 +924,30 @@ class AuthSrv(object):
use = True
lim.nosub = True
- v = vol.flags.get("sz")
- if v:
+ zs = vol.flags.get("sz")
+ if zs:
use = True
- lim.smin, lim.smax = [unhumanize(x) for x in v.split("-")]
+ lim.smin, lim.smax = [unhumanize(x) for x in zs.split("-")]
- v = vol.flags.get("rotn")
- if v:
+ zs = vol.flags.get("rotn")
+ if zs:
use = True
- lim.rotn, lim.rotl = [int(x) for x in v.split(",")]
+ lim.rotn, lim.rotl = [int(x) for x in zs.split(",")]
- v = vol.flags.get("rotf")
- if v:
+ zs = vol.flags.get("rotf")
+ if zs:
use = True
- lim.set_rotf(v)
+ lim.set_rotf(zs)
- v = vol.flags.get("maxn")
- if v:
+ zs = vol.flags.get("maxn")
+ if zs:
use = True
- lim.nmax, lim.nwin = [int(x) for x in v.split(",")]
+ lim.nmax, lim.nwin = [int(x) for x in zs.split(",")]
- v = vol.flags.get("maxb")
- if v:
+ zs = vol.flags.get("maxb")
+ if zs:
use = True
- lim.bmax, lim.bwin = [unhumanize(x) for x in v.split(",")]
+ lim.bmax, lim.bwin = [unhumanize(x) for x in zs.split(",")]
if use:
vol.lim = lim
@@ -1005,8 +1083,8 @@ class AuthSrv(object):
for mtp in local_only_mtp:
if mtp not in local_mte:
- m = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)'
- self.log(m.format(vol.vpath, mtp), 1)
+ t = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)'
+ self.log(t.format(vol.vpath, mtp), 1)
errors = True
tags = self.args.mtp or []
@@ -1014,8 +1092,8 @@ class AuthSrv(object):
tags = [y for x in tags for y in x.split(",")]
for mtp in tags:
if mtp not in all_mte:
- m = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)'
- self.log(m.format(mtp), 1)
+ t = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)'
+ self.log(t.format(mtp), 1)
errors = True
if errors:
@@ -1023,12 +1101,12 @@ class AuthSrv(object):
vfs.bubble_flags()
- m = "volumes and permissions:\n"
- for v in vfs.all_vols.values():
+ t = "volumes and permissions:\n"
+ for zv in vfs.all_vols.values():
if not self.warn_anonwrite:
break
- m += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(v.vpath, v.realpath)
+ t += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(zv.vpath, zv.realpath)
for txt, attr in [
[" read", "uread"],
[" write", "uwrite"],
@@ -1036,21 +1114,21 @@ class AuthSrv(object):
["delete", "udel"],
[" get", "uget"],
]:
- u = list(sorted(getattr(v.axs, attr).keys()))
+ u = list(sorted(getattr(zv.axs, attr)))
u = ", ".join("\033[35meverybody\033[0m" if x == "*" else x for x in u)
u = u if u else "\033[36m--none--\033[0m"
- m += "\n| {}: {}".format(txt, u)
- m += "\n"
+ t += "\n| {}: {}".format(txt, u)
+ t += "\n"
if self.warn_anonwrite and not self.args.no_voldump:
- self.log(m)
+ self.log(t)
try:
- v, _ = vfs.get("/", "*", False, True)
- if self.warn_anonwrite and os.getcwd() == v.realpath:
+ zv, _ = vfs.get("/", "*", False, True)
+ if self.warn_anonwrite and os.getcwd() == zv.realpath:
self.warn_anonwrite = False
- msg = "anyone can read/write the current directory: {}\n"
- self.log(msg.format(v.realpath), c=1)
+ t = "anyone can read/write the current directory: {}\n"
+ self.log(t.format(zv.realpath), c=1)
except Pebkac:
self.warn_anonwrite = True
@@ -1064,19 +1142,19 @@ class AuthSrv(object):
if pwds:
self.re_pwd = re.compile("=(" + "|".join(pwds) + ")([]&; ]|$)")
- def dbg_ls(self):
+ def dbg_ls(self) -> None:
users = self.args.ls
- vols = "*"
- flags = []
+ vol = "*"
+ flags: list[str] = []
try:
- users, vols = users.split(",", 1)
+ users, vol = users.split(",", 1)
except:
pass
try:
- vols, flags = vols.split(",", 1)
- flags = flags.split(",")
+ vol, zf = vol.split(",", 1)
+ flags = zf.split(",")
except:
pass
@@ -1089,23 +1167,23 @@ class AuthSrv(object):
if u not in self.acct and u != "*":
raise Exception("user not found: " + u)
- if vols == "*":
+ if vol == "*":
vols = ["/" + x for x in self.vfs.all_vols]
else:
- vols = [vols]
+ vols = [vol]
- for v in vols:
- if not v.startswith("/"):
+ for zs in vols:
+ if not zs.startswith("/"):
raise Exception("volumes must start with /")
- if v[1:] not in self.vfs.all_vols:
- raise Exception("volume not found: " + v)
+ if zs[1:] not in self.vfs.all_vols:
+ raise Exception("volume not found: " + zs)
- self.log({"users": users, "vols": vols, "flags": flags})
- m = "/{}: read({}) write({}) move({}) del({}) get({})"
- for k, v in self.vfs.all_vols.items():
- vc = v.axs
- self.log(m.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget))
+ self.log(str({"users": users, "vols": vols, "flags": flags}))
+ t = "/{}: read({}) write({}) move({}) del({}) get({})"
+ for k, zv in self.vfs.all_vols.items():
+ vc = zv.axs
+ self.log(t.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget))
flag_v = "v" in flags
flag_ln = "ln" in flags
@@ -1136,12 +1214,14 @@ class AuthSrv(object):
False,
False,
)
- for _, _, vpath, apath, files, dirs, _ in g:
- fnames = [n[0] for n in files]
- vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames
- vpaths = [vtop + x for x in vpaths]
+ for _, _, vpath, apath, files1, dirs, _ in g:
+ fnames = [n[0] for n in files1]
+ zsl = [vpath + "/" + n for n in fnames] if vpath else fnames
+ vpaths = [vtop + x for x in zsl]
apaths = [os.path.join(apath, n) for n in fnames]
- files = [[vpath + "/", apath + os.sep]] + list(zip(vpaths, apaths))
+ files = [(vpath + "/", apath + os.sep)] + list(
+ [(zs1, zs2) for zs1, zs2 in zip(vpaths, apaths)]
+ )
if flag_ln:
files = [x for x in files if not x[1].startswith(safeabs)]
@@ -1152,21 +1232,23 @@ class AuthSrv(object):
if not files:
continue
elif flag_v:
- msg = [""] + [
+ ta = [""] + [
'# user "{}", vpath "{}"\n{}'.format(u, vp, ap)
for vp, ap in files
]
else:
- msg = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])]
- msg += [x[1] for x in files]
+ ta = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])]
+ ta += [x[1] for x in files]
- self.log("\n".join(msg))
+ self.log("\n".join(ta))
if bads:
self.log("\n ".join(["found symlinks leaving volume:"] + bads))
if bads and flag_p:
- raise Exception("found symlink leaving volume, and strict is set")
+ raise Exception(
+ "\033[31m\n [--ls] found a safety issue and prevented startup:\n found symlinks leaving volume, and strict is set\n\033[0m"
+ )
if not flag_r:
sys.exit(0)
diff --git a/copyparty/bos/bos.py b/copyparty/bos/bos.py
index d5e003cf..617545af 100644
--- a/copyparty/bos/bos.py
+++ b/copyparty/bos/bos.py
@@ -2,23 +2,30 @@
from __future__ import print_function, unicode_literals
import os
-from ..util import fsenc, fsdec, SYMTIME
+
+from ..util import SYMTIME, fsdec, fsenc
from . import path
+try:
+ from typing import Optional
+except:
+ pass
+
+_ = (path,)
# grep -hRiE '(^|[^a-zA-Z_\.-])os\.' . | gsed -r 's/ /\n/g;s/\(/(\n/g' | grep -hRiE '(^|[^a-zA-Z_\.-])os\.' | sort | uniq -c
# printf 'os\.(%s)' "$(grep ^def bos/__init__.py | gsed -r 's/^def //;s/\(.*//' | tr '\n' '|' | gsed -r 's/.$//')"
-def chmod(p, mode):
+def chmod(p: str, mode: int) -> None:
return os.chmod(fsenc(p), mode)
-def listdir(p="."):
+def listdir(p: str = ".") -> list[str]:
return [fsdec(x) for x in os.listdir(fsenc(p))]
-def makedirs(name, mode=0o755, exist_ok=True):
+def makedirs(name: str, mode: int = 0o755, exist_ok: bool = True) -> None:
bname = fsenc(name)
try:
os.makedirs(bname, mode)
@@ -27,31 +34,33 @@ def makedirs(name, mode=0o755, exist_ok=True):
raise
-def mkdir(p, mode=0o755):
+def mkdir(p: str, mode: int = 0o755) -> None:
return os.mkdir(fsenc(p), mode)
-def rename(src, dst):
+def rename(src: str, dst: str) -> None:
return os.rename(fsenc(src), fsenc(dst))
-def replace(src, dst):
+def replace(src: str, dst: str) -> None:
return os.replace(fsenc(src), fsenc(dst))
-def rmdir(p):
+def rmdir(p: str) -> None:
return os.rmdir(fsenc(p))
-def stat(p):
+def stat(p: str) -> os.stat_result:
return os.stat(fsenc(p))
-def unlink(p):
+def unlink(p: str) -> None:
return os.unlink(fsenc(p))
-def utime(p, times=None, follow_symlinks=True):
+def utime(
+ p: str, times: Optional[tuple[float, float]] = None, follow_symlinks: bool = True
+) -> None:
if SYMTIME:
return os.utime(fsenc(p), times, follow_symlinks=follow_symlinks)
else:
@@ -60,7 +69,7 @@ def utime(p, times=None, follow_symlinks=True):
if hasattr(os, "lstat"):
- def lstat(p):
+ def lstat(p: str) -> os.stat_result:
return os.lstat(fsenc(p))
else:
diff --git a/copyparty/bos/path.py b/copyparty/bos/path.py
index 066453b0..c5769d84 100644
--- a/copyparty/bos/path.py
+++ b/copyparty/bos/path.py
@@ -2,43 +2,44 @@
from __future__ import print_function, unicode_literals
import os
-from ..util import fsenc, fsdec, SYMTIME
+
+from ..util import SYMTIME, fsdec, fsenc
-def abspath(p):
+def abspath(p: str) -> str:
return fsdec(os.path.abspath(fsenc(p)))
-def exists(p):
+def exists(p: str) -> bool:
return os.path.exists(fsenc(p))
-def getmtime(p, follow_symlinks=True):
+def getmtime(p: str, follow_symlinks: bool = True) -> float:
if not follow_symlinks and SYMTIME:
return os.lstat(fsenc(p)).st_mtime
else:
return os.path.getmtime(fsenc(p))
-def getsize(p):
+def getsize(p: str) -> int:
return os.path.getsize(fsenc(p))
-def isfile(p):
+def isfile(p: str) -> bool:
return os.path.isfile(fsenc(p))
-def isdir(p):
+def isdir(p: str) -> bool:
return os.path.isdir(fsenc(p))
-def islink(p):
+def islink(p: str) -> bool:
return os.path.islink(fsenc(p))
-def lexists(p):
+def lexists(p: str) -> bool:
return os.path.lexists(fsenc(p))
-def realpath(p):
+def realpath(p: str) -> str:
return fsdec(os.path.realpath(fsenc(p)))
diff --git a/copyparty/broker_mp.py b/copyparty/broker_mp.py
index df73658d..c7bfed70 100644
--- a/copyparty/broker_mp.py
+++ b/copyparty/broker_mp.py
@@ -1,37 +1,56 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import time
import threading
+import time
-from .broker_util import try_exec
+import queue
+
+from .__init__ import TYPE_CHECKING
from .broker_mpw import MpWorker
+from .broker_util import try_exec
from .util import mp
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
+try:
+ from typing import Any
+except:
+ pass
+
+
+class MProcess(mp.Process):
+ def __init__(
+ self,
+ q_pend: queue.Queue[tuple[int, str, list[Any]]],
+ q_yield: queue.Queue[tuple[int, str, list[Any]]],
+ target: Any,
+ args: Any,
+ ) -> None:
+ super(MProcess, self).__init__(target=target, args=args)
+ self.q_pend = q_pend
+ self.q_yield = q_yield
+
class BrokerMp(object):
"""external api; manages MpWorkers"""
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
self.log = hub.log
self.args = hub.args
self.procs = []
- self.retpend = {}
- self.retpend_mutex = threading.Lock()
self.mutex = threading.Lock()
self.num_workers = self.args.j or mp.cpu_count()
self.log("broker", "booting {} subprocesses".format(self.num_workers))
for n in range(1, self.num_workers + 1):
- q_pend = mp.Queue(1)
- q_yield = mp.Queue(64)
+ q_pend: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(1)
+ q_yield: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(64)
- proc = mp.Process(target=MpWorker, args=(q_pend, q_yield, self.args, n))
- proc.q_pend = q_pend
- proc.q_yield = q_yield
- proc.clients = {}
+ proc = MProcess(q_pend, q_yield, MpWorker, (q_pend, q_yield, self.args, n))
thr = threading.Thread(
target=self.collector, args=(proc,), name="mp-sink-{}".format(n)
@@ -42,11 +61,11 @@ class BrokerMp(object):
self.procs.append(proc)
proc.start()
- def shutdown(self):
+ def shutdown(self) -> None:
self.log("broker", "shutting down")
for n, proc in enumerate(self.procs):
thr = threading.Thread(
- target=proc.q_pend.put([0, "shutdown", []]),
+ target=proc.q_pend.put((0, "shutdown", [])),
name="mp-shutdown-{}-{}".format(n, len(self.procs)),
)
thr.start()
@@ -62,12 +81,12 @@ class BrokerMp(object):
procs.pop()
- def reload(self):
+ def reload(self) -> None:
self.log("broker", "reloading")
for _, proc in enumerate(self.procs):
- proc.q_pend.put([0, "reload", []])
+ proc.q_pend.put((0, "reload", []))
- def collector(self, proc):
+ def collector(self, proc: MProcess) -> None:
"""receive message from hub in other process"""
while True:
msg = proc.q_yield.get()
@@ -78,10 +97,7 @@ class BrokerMp(object):
elif dest == "retq":
# response from previous ipc call
- with self.retpend_mutex:
- retq = self.retpend.pop(retq_id)
-
- retq.put(args)
+ raise Exception("invalid broker_mp usage")
else:
# new ipc invoking managed service in hub
@@ -93,9 +109,9 @@ class BrokerMp(object):
rv = try_exec(retq_id, obj, *args)
if retq_id:
- proc.q_pend.put([retq_id, "retq", rv])
+ proc.q_pend.put((retq_id, "retq", rv))
- def put(self, want_retval, dest, *args):
+ def say(self, dest: str, *args: Any) -> None:
"""
send message to non-hub component in other process,
returns a Queue object which eventually contains the response if want_retval
@@ -103,7 +119,7 @@ class BrokerMp(object):
"""
if dest == "listen":
for p in self.procs:
- p.q_pend.put([0, dest, [args[0], len(self.procs)]])
+ p.q_pend.put((0, dest, [args[0], len(self.procs)]))
elif dest == "cb_httpsrv_up":
self.hub.cb_httpsrv_up()
diff --git a/copyparty/broker_mpw.py b/copyparty/broker_mpw.py
index c4a1054c..4dbec5b1 100644
--- a/copyparty/broker_mpw.py
+++ b/copyparty/broker_mpw.py
@@ -1,20 +1,38 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import sys
+import argparse
import signal
+import sys
import threading
-from .broker_util import ExceptionalQueue
+import queue
+
+from .authsrv import AuthSrv
+from .broker_util import BrokerCli, ExceptionalQueue
from .httpsrv import HttpSrv
from .util import FAKE_MP
-from .authsrv import AuthSrv
+
+try:
+ from types import FrameType
+
+ from typing import Any, Optional, Union
+except:
+ pass
-class MpWorker(object):
+class MpWorker(BrokerCli):
"""one single mp instance"""
- def __init__(self, q_pend, q_yield, args, n):
+ def __init__(
+ self,
+ q_pend: queue.Queue[tuple[int, str, list[Any]]],
+ q_yield: queue.Queue[tuple[int, str, list[Any]]],
+ args: argparse.Namespace,
+ n: int,
+ ) -> None:
+ super(MpWorker, self).__init__()
+
self.q_pend = q_pend
self.q_yield = q_yield
self.args = args
@@ -22,7 +40,7 @@ class MpWorker(object):
self.log = self._log_disabled if args.q and not args.lo else self._log_enabled
- self.retpend = {}
+ self.retpend: dict[int, Any] = {}
self.retpend_mutex = threading.Lock()
self.mutex = threading.Lock()
@@ -45,20 +63,20 @@ class MpWorker(object):
thr.start()
thr.join()
- def signal_handler(self, sig, frame):
+ def signal_handler(self, sig: Optional[int], frame: Optional[FrameType]) -> None:
# print('k')
pass
- def _log_enabled(self, src, msg, c=0):
- self.q_yield.put([0, "log", [src, msg, c]])
+ def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
+ self.q_yield.put((0, "log", [src, msg, c]))
- def _log_disabled(self, src, msg, c=0):
+ def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
pass
- def logw(self, msg, c=0):
+ def logw(self, msg: str, c: Union[int, str] = 0) -> None:
self.log("mp{}".format(self.n), msg, c)
- def main(self):
+ def main(self) -> None:
while True:
retq_id, dest, args = self.q_pend.get()
@@ -87,15 +105,14 @@ class MpWorker(object):
else:
raise Exception("what is " + str(dest))
- def put(self, want_retval, dest, *args):
- if want_retval:
- retq = ExceptionalQueue(1)
- retq_id = id(retq)
- with self.retpend_mutex:
- self.retpend[retq_id] = retq
- else:
- retq = None
- retq_id = 0
+ def ask(self, dest: str, *args: Any) -> ExceptionalQueue:
+ retq = ExceptionalQueue(1)
+ retq_id = id(retq)
+ with self.retpend_mutex:
+ self.retpend[retq_id] = retq
- self.q_yield.put([retq_id, dest, args])
+ self.q_yield.put((retq_id, dest, list(args)))
return retq
+
+ def say(self, dest: str, *args: Any) -> None:
+ self.q_yield.put((0, dest, list(args)))
diff --git a/copyparty/broker_thr.py b/copyparty/broker_thr.py
index 1c7a1abf..51c25d41 100644
--- a/copyparty/broker_thr.py
+++ b/copyparty/broker_thr.py
@@ -3,14 +3,25 @@ from __future__ import print_function, unicode_literals
import threading
+from .__init__ import TYPE_CHECKING
+from .broker_util import BrokerCli, ExceptionalQueue, try_exec
from .httpsrv import HttpSrv
-from .broker_util import ExceptionalQueue, try_exec
+
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
+try:
+ from typing import Any
+except:
+ pass
-class BrokerThr(object):
+class BrokerThr(BrokerCli):
"""external api; behaves like BrokerMP but using plain threads"""
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
+ super(BrokerThr, self).__init__()
+
self.hub = hub
self.log = hub.log
self.args = hub.args
@@ -23,29 +34,35 @@ class BrokerThr(object):
self.httpsrv = HttpSrv(self, None)
self.reload = self.noop
- def shutdown(self):
+ def shutdown(self) -> None:
# self.log("broker", "shutting down")
self.httpsrv.shutdown()
- def noop(self):
+ def noop(self) -> None:
pass
- def put(self, want_retval, dest, *args):
+ def ask(self, dest: str, *args: Any) -> ExceptionalQueue:
+
+ # new ipc invoking managed service in hub
+ obj = self.hub
+ for node in dest.split("."):
+ obj = getattr(obj, node)
+
+ rv = try_exec(True, obj, *args)
+
+ # pretend we're broker_mp
+ retq = ExceptionalQueue(1)
+ retq.put(rv)
+ return retq
+
+ def say(self, dest: str, *args: Any) -> None:
if dest == "listen":
self.httpsrv.listen(args[0], 1)
+ return
- else:
- # new ipc invoking managed service in hub
- obj = self.hub
- for node in dest.split("."):
- obj = getattr(obj, node)
+ # new ipc invoking managed service in hub
+ obj = self.hub
+ for node in dest.split("."):
+ obj = getattr(obj, node)
- # TODO will deadlock if dest performs another ipc
- rv = try_exec(want_retval, obj, *args)
- if not want_retval:
- return
-
- # pretend we're broker_mp
- retq = ExceptionalQueue(1)
- retq.put(rv)
- return retq
+ try_exec(False, obj, *args)
diff --git a/copyparty/broker_util.py b/copyparty/broker_util.py
index 896cba2c..b0d44575 100644
--- a/copyparty/broker_util.py
+++ b/copyparty/broker_util.py
@@ -1,14 +1,28 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-
+import argparse
import traceback
-from .util import Pebkac, Queue
+from queue import Queue
+
+from .__init__ import TYPE_CHECKING
+from .authsrv import AuthSrv
+from .util import Pebkac
+
+try:
+ from typing import Any, Optional, Union
+
+ from .util import RootLogger
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpsrv import HttpSrv
class ExceptionalQueue(Queue, object):
- def get(self, block=True, timeout=None):
+ def get(self, block: bool = True, timeout: Optional[float] = None) -> Any:
rv = super(ExceptionalQueue, self).get(block, timeout)
if isinstance(rv, list):
@@ -21,7 +35,26 @@ class ExceptionalQueue(Queue, object):
return rv
-def try_exec(want_retval, func, *args):
+class BrokerCli(object):
+ """
+ helps mypy understand httpsrv.broker but still fails a few levels deeper,
+ for example resolving httpconn.* in httpcli -- see lines tagged #mypy404
+ """
+
+ def __init__(self) -> None:
+ self.log: RootLogger = None
+ self.args: argparse.Namespace = None
+ self.asrv: AuthSrv = None
+ self.httpsrv: "HttpSrv" = None
+
+ def ask(self, dest: str, *args: Any) -> ExceptionalQueue:
+ return ExceptionalQueue(1)
+
+ def say(self, dest: str, *args: Any) -> None:
+ pass
+
+
+def try_exec(want_retval: Union[bool, int], func: Any, *args: list[Any]) -> Any:
try:
return func(*args)
diff --git a/copyparty/ftpd.py b/copyparty/ftpd.py
index 828c0db6..a47586de 100644
--- a/copyparty/ftpd.py
+++ b/copyparty/ftpd.py
@@ -1,23 +1,23 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import sys
-import stat
-import time
-import logging
import argparse
+import logging
+import os
+import stat
+import sys
import threading
+import time
-from pyftpdlib.authorizers import DummyAuthorizer, AuthenticationFailed
+from pyftpdlib.authorizers import AuthenticationFailed, DummyAuthorizer
from pyftpdlib.filesystems import AbstractedFS, FilesystemError
from pyftpdlib.handlers import FTPHandler
-from pyftpdlib.servers import FTPServer
from pyftpdlib.log import config_logging
+from pyftpdlib.servers import FTPServer
-from .__init__ import E, PY2
-from .util import Pebkac, fsenc, exclude_dotfiles
+from .__init__ import PY2, TYPE_CHECKING, E
from .bos import bos
+from .util import Pebkac, exclude_dotfiles, fsenc
try:
from pyftpdlib.ioloop import IOLoop
@@ -28,58 +28,63 @@ except ImportError:
from pyftpdlib.ioloop import IOLoop
-try:
- from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from .svchub import SvcHub
- if TYPE_CHECKING:
- from .svchub import SvcHub
-except ImportError:
+try:
+ import typing
+ from typing import Any, Optional
+except:
pass
class FtpAuth(DummyAuthorizer):
- def __init__(self):
+ def __init__(self, hub: "SvcHub") -> None:
super(FtpAuth, self).__init__()
- self.hub = None # type: SvcHub
+ self.hub = hub
- def validate_authentication(self, username, password, handler):
+ def validate_authentication(
+ self, username: str, password: str, handler: Any
+ ) -> None:
asrv = self.hub.asrv
if username == "anonymous":
password = ""
uname = "*"
if password:
- uname = asrv.iacct.get(password, None)
+ uname = asrv.iacct.get(password, "")
handler.username = uname
if password and not uname:
raise AuthenticationFailed("Authentication failed.")
- def get_home_dir(self, username):
+ def get_home_dir(self, username: str) -> str:
return "/"
- def has_user(self, username):
+ def has_user(self, username: str) -> bool:
asrv = self.hub.asrv
return username in asrv.acct
- def has_perm(self, username, perm, path=None):
+ def has_perm(self, username: str, perm: int, path: Optional[str] = None) -> bool:
return True # handled at filesystem layer
- def get_perms(self, username):
+ def get_perms(self, username: str) -> str:
return "elradfmwMT"
- def get_msg_login(self, username):
+ def get_msg_login(self, username: str) -> str:
return "sup {}".format(username)
- def get_msg_quit(self, username):
+ def get_msg_quit(self, username: str) -> str:
return "cya"
class FtpFs(AbstractedFS):
- def __init__(self, root, cmd_channel):
+ def __init__(
+ self, root: str, cmd_channel: Any
+ ) -> None: # pylint: disable=super-init-not-called
self.h = self.cmd_channel = cmd_channel # type: FTPHandler
- self.hub = cmd_channel.hub # type: SvcHub
+ self.hub: "SvcHub" = cmd_channel.hub
self.args = cmd_channel.args
self.uname = self.hub.asrv.iacct.get(cmd_channel.password, "*")
@@ -90,7 +95,14 @@ class FtpFs(AbstractedFS):
self.listdirinfo = self.listdir
self.chdir(".")
- def v2a(self, vpath, r=False, w=False, m=False, d=False):
+ def v2a(
+ self,
+ vpath: str,
+ r: bool = False,
+ w: bool = False,
+ m: bool = False,
+ d: bool = False,
+ ) -> str:
try:
vpath = vpath.replace("\\", "/").lstrip("/")
vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, r, w, m, d)
@@ -101,25 +113,32 @@ class FtpFs(AbstractedFS):
except Pebkac as ex:
raise FilesystemError(str(ex))
- def rv2a(self, vpath, r=False, w=False, m=False, d=False):
+ def rv2a(
+ self,
+ vpath: str,
+ r: bool = False,
+ w: bool = False,
+ m: bool = False,
+ d: bool = False,
+ ) -> str:
return self.v2a(os.path.join(self.cwd, vpath), r, w, m, d)
- def ftp2fs(self, ftppath):
+ def ftp2fs(self, ftppath: str) -> str:
# return self.v2a(ftppath)
return ftppath # self.cwd must be vpath
- def fs2ftp(self, fspath):
+ def fs2ftp(self, fspath: str) -> str:
# raise NotImplementedError()
return fspath
- def validpath(self, path):
+ def validpath(self, path: str) -> bool:
if "/.hist/" in path:
if "/up2k." in path or path.endswith("/dir.txt"):
raise FilesystemError("access to this file is forbidden")
return True
- def open(self, filename, mode):
+ def open(self, filename: str, mode: str) -> typing.IO[Any]:
r = "r" in mode
w = "w" in mode or "a" in mode or "+" in mode
@@ -130,24 +149,24 @@ class FtpFs(AbstractedFS):
self.validpath(ap)
return open(fsenc(ap), mode)
- def chdir(self, path):
+ def chdir(self, path: str) -> None:
self.cwd = join(self.cwd, path)
x = self.hub.asrv.vfs.can_access(self.cwd.lstrip("/"), self.h.username)
self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x
- def mkdir(self, path):
+ def mkdir(self, path: str) -> None:
ap = self.rv2a(path, w=True)
bos.mkdir(ap)
- def listdir(self, path):
+ def listdir(self, path: str) -> list[str]:
vpath = join(self.cwd, path).lstrip("/")
try:
vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, True, False)
- fsroot, vfs_ls, vfs_virt = vfs.ls(
+ fsroot, vfs_ls1, vfs_virt = vfs.ls(
rem, self.uname, not self.args.no_scandir, [[True], [False, True]]
)
- vfs_ls = [x[0] for x in vfs_ls]
+ vfs_ls = [x[0] for x in vfs_ls1]
vfs_ls.extend(vfs_virt.keys())
if not self.args.ed:
@@ -164,11 +183,11 @@ class FtpFs(AbstractedFS):
r = {x.split("/")[0]: 1 for x in self.hub.asrv.vfs.all_vols.keys()}
return list(sorted(list(r.keys())))
- def rmdir(self, path):
+ def rmdir(self, path: str) -> None:
ap = self.rv2a(path, d=True)
bos.rmdir(ap)
- def remove(self, path):
+ def remove(self, path: str) -> None:
if self.args.no_del:
raise FilesystemError("the delete feature is disabled in server config")
@@ -178,13 +197,13 @@ class FtpFs(AbstractedFS):
except Exception as ex:
raise FilesystemError(str(ex))
- def rename(self, src, dst):
+ def rename(self, src: str, dst: str) -> None:
if not self.can_move:
raise FilesystemError("not allowed for user " + self.h.username)
if self.args.no_mv:
- m = "the rename/move feature is disabled in server config"
- raise FilesystemError(m)
+ t = "the rename/move feature is disabled in server config"
+ raise FilesystemError(t)
svp = join(self.cwd, src).lstrip("/")
dvp = join(self.cwd, dst).lstrip("/")
@@ -193,10 +212,10 @@ class FtpFs(AbstractedFS):
except Exception as ex:
raise FilesystemError(str(ex))
- def chmod(self, path, mode):
+ def chmod(self, path: str, mode: str) -> None:
pass
- def stat(self, path):
+ def stat(self, path: str) -> os.stat_result:
try:
ap = self.rv2a(path, r=True)
return bos.stat(ap)
@@ -208,59 +227,59 @@ class FtpFs(AbstractedFS):
return st
- def utime(self, path, timeval):
+ def utime(self, path: str, timeval: float) -> None:
ap = self.rv2a(path, w=True)
return bos.utime(ap, (timeval, timeval))
- def lstat(self, path):
+ def lstat(self, path: str) -> os.stat_result:
ap = self.rv2a(path)
return bos.lstat(ap)
- def isfile(self, path):
+ def isfile(self, path: str) -> bool:
st = self.stat(path)
return stat.S_ISREG(st.st_mode)
- def islink(self, path):
+ def islink(self, path: str) -> bool:
ap = self.rv2a(path)
return bos.path.islink(ap)
- def isdir(self, path):
+ def isdir(self, path: str) -> bool:
try:
st = self.stat(path)
return stat.S_ISDIR(st.st_mode)
except:
return True
- def getsize(self, path):
+ def getsize(self, path: str) -> int:
ap = self.rv2a(path)
return bos.path.getsize(ap)
- def getmtime(self, path):
+ def getmtime(self, path: str) -> float:
ap = self.rv2a(path)
return bos.path.getmtime(ap)
- def realpath(self, path):
+ def realpath(self, path: str) -> str:
return path
- def lexists(self, path):
+ def lexists(self, path: str) -> bool:
ap = self.rv2a(path)
return bos.path.lexists(ap)
- def get_user_by_uid(self, uid):
+ def get_user_by_uid(self, uid: int) -> str:
return "root"
- def get_group_by_uid(self, gid):
+ def get_group_by_uid(self, gid: int) -> str:
return "root"
class FtpHandler(FTPHandler):
abstracted_fs = FtpFs
- hub = None # type: SvcHub
- args = None # type: argparse.Namespace
+ hub: "SvcHub" = None
+ args: argparse.Namespace = None
- def __init__(self, conn, server, ioloop=None):
- self.hub = FtpHandler.hub # type: SvcHub
- self.args = FtpHandler.args # type: argparse.Namespace
+ def __init__(self, conn: Any, server: Any, ioloop: Any = None) -> None:
+ self.hub: "SvcHub" = FtpHandler.hub
+ self.args: argparse.Namespace = FtpHandler.args
if PY2:
FTPHandler.__init__(self, conn, server, ioloop)
@@ -268,9 +287,10 @@ class FtpHandler(FTPHandler):
super(FtpHandler, self).__init__(conn, server, ioloop)
# abspath->vpath mapping to resolve log_transfer paths
- self.vfs_map = {}
+ self.vfs_map: dict[str, str] = {}
- def ftp_STOR(self, file, mode="w"):
+ def ftp_STOR(self, file: str, mode: str = "w") -> Any:
+ # Optional[str]
vp = join(self.fs.cwd, file).lstrip("/")
ap = self.fs.v2a(vp)
self.vfs_map[ap] = vp
@@ -279,7 +299,16 @@ class FtpHandler(FTPHandler):
# print("ftp_STOR: {} {} OK".format(vp, mode))
return ret
- def log_transfer(self, cmd, filename, receive, completed, elapsed, bytes):
+ def log_transfer(
+ self,
+ cmd: str,
+ filename: bytes,
+ receive: bool,
+ completed: bool,
+ elapsed: float,
+ bytes: int,
+ ) -> Any:
+ # None
ap = filename.decode("utf-8", "replace")
vp = self.vfs_map.pop(ap, None)
# print("xfer_end: {} => {}".format(ap, vp))
@@ -312,7 +341,7 @@ except:
class Ftpd(object):
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
self.args = hub.args
@@ -321,24 +350,23 @@ class Ftpd(object):
hs.append([FtpHandler, self.args.ftp])
if self.args.ftps:
try:
- h = SftpHandler
+ h1 = SftpHandler
except:
- m = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n"
- print(m.format(sys.executable))
+ t = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n"
+ print(t.format(sys.executable))
sys.exit(1)
- h.certfile = os.path.join(E.cfg, "cert.pem")
- h.tls_control_required = True
- h.tls_data_required = True
+ h1.certfile = os.path.join(E.cfg, "cert.pem")
+ h1.tls_control_required = True
+ h1.tls_data_required = True
- hs.append([h, self.args.ftps])
+ hs.append([h1, self.args.ftps])
- for h in hs:
- h, lp = h
- h.hub = hub
- h.args = hub.args
- h.authorizer = FtpAuth()
- h.authorizer.hub = hub
+ for h_lp in hs:
+ h2, lp = h_lp
+ h2.hub = hub
+ h2.args = hub.args
+ h2.authorizer = FtpAuth(hub)
if self.args.ftp_pr:
p1, p2 = [int(x) for x in self.args.ftp_pr.split("-")]
@@ -350,10 +378,10 @@ class Ftpd(object):
else:
p1 += d + 1
- h.passive_ports = list(range(p1, p2 + 1))
+ h2.passive_ports = list(range(p1, p2 + 1))
if self.args.ftp_nat:
- h.masquerade_address = self.args.ftp_nat
+ h2.masquerade_address = self.args.ftp_nat
if self.args.ftp_dbg:
config_logging(level=logging.DEBUG)
@@ -363,11 +391,11 @@ class Ftpd(object):
for h, lp in hs:
FTPServer((ip, int(lp)), h, ioloop)
- t = threading.Thread(target=ioloop.loop)
- t.daemon = True
- t.start()
+ thr = threading.Thread(target=ioloop.loop)
+ thr.daemon = True
+ thr.start()
-def join(p1, p2):
+def join(p1: str, p2: str) -> str:
w = os.path.join(p1, p2.replace("\\", "/"))
return os.path.normpath(w).replace("\\", "/")
diff --git a/copyparty/httpcli.py b/copyparty/httpcli.py
index 5368bf62..cfc2bbf7 100644
--- a/copyparty/httpcli.py
+++ b/copyparty/httpcli.py
@@ -1,18 +1,23 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import stat
-import gzip
-import time
-import copy
-import json
+import argparse # typechk
import base64
-import string
+import calendar
+import copy
+import gzip
+import json
+import os
+import re
import socket
+import stat
+import string
+import threading # typechk
+import time
from datetime import datetime
from operator import itemgetter
-import calendar
+
+import jinja2 # typechk
try:
import lzma
@@ -24,13 +29,62 @@ try:
except:
pass
-from .__init__ import E, PY2, WINDOWS, ANYWIN, unicode
-from .util import * # noqa # pylint: disable=unused-wildcard-import
+from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS, E, unicode
+from .authsrv import VFS # typechk
from .bos import bos
-from .authsrv import AuthSrv
-from .szip import StreamZip
from .star import StreamTar
+from .sutil import StreamArc # typechk
+from .szip import StreamZip
+from .util import (
+ HTTP_TS_FMT,
+ HTTPCODE,
+ META_NOBOTS,
+ MultipartParser,
+ Pebkac,
+ UnrecvEOF,
+ alltrace,
+ exclude_dotfiles,
+ fsenc,
+ gen_filekey,
+ gencookie,
+ get_spd,
+ guess_mime,
+ gzip_orig_sz,
+ hashcopy,
+ html_bescape,
+ html_escape,
+ http_ts,
+ humansize,
+ min_ex,
+ quotep,
+ read_header,
+ read_socket,
+ read_socket_chunked,
+ read_socket_unbounded,
+ relchk,
+ ren_open,
+ s3enc,
+ sanitize_fn,
+ sendfile_kern,
+ sendfile_py,
+ undot,
+ unescape_cookie,
+ unquote,
+ unquotep,
+ vol_san,
+ vsplit,
+ yieldfile,
+)
+try:
+ from typing import Any, Generator, Match, Optional, Pattern, Type, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpconn import HttpConn
+
+_ = (argparse, threading)
NO_CACHE = {"Cache-Control": "no-cache"}
@@ -40,27 +94,60 @@ class HttpCli(object):
Spawned by HttpConn to process one http transaction
"""
- def __init__(self, conn):
+ def __init__(self, conn: "HttpConn") -> None:
+ assert conn.sr
+
self.t0 = time.time()
self.conn = conn
- self.mutex = conn.mutex
- self.s = conn.s # type: socket
- self.sr = conn.sr # type: Unrecv
+ self.mutex = conn.mutex # mypy404
+ self.s = conn.s
+ self.sr = conn.sr
self.ip = conn.addr[0]
- self.addr = conn.addr # type: tuple[str, int]
- self.args = conn.args
- self.asrv = conn.asrv # type: AuthSrv
- self.ico = conn.ico
- self.thumbcli = conn.thumbcli
- self.u2fh = conn.u2fh
- self.log_func = conn.log_func
- self.log_src = conn.log_src
- self.tls = hasattr(self.s, "cipher")
+ self.addr: tuple[str, int] = conn.addr
+ self.args = conn.args # mypy404
+ self.asrv = conn.asrv # mypy404
+ self.ico = conn.ico # mypy404
+ self.thumbcli = conn.thumbcli # mypy404
+ self.u2fh = conn.u2fh # mypy404
+ self.log_func = conn.log_func # mypy404
+ self.log_src = conn.log_src # mypy404
+ self.tls: bool = hasattr(self.s, "cipher")
+
+ # placeholders; assigned by run()
+ self.keepalive = False
+ self.is_https = False
+ self.headers: dict[str, str] = {}
+ self.mode = " "
+ self.req = " "
+ self.http_ver = " "
+ self.ua = " "
+ self.is_rclone = False
+ self.is_ancient = False
+ self.dip = " "
+ self.ouparam: dict[str, str] = {}
+ self.uparam: dict[str, str] = {}
+ self.cookies: dict[str, str] = {}
+ self.vpath = " "
+ self.uname = " "
+ self.rvol = [" "]
+ self.wvol = [" "]
+ self.mvol = [" "]
+ self.dvol = [" "]
+ self.gvol = [" "]
+ self.do_log = True
+ self.can_read = False
+ self.can_write = False
+ self.can_move = False
+ self.can_delete = False
+ self.can_get = False
+ # post
+ self.parser: Optional[MultipartParser] = None
+ # end placeholders
self.bufsz = 1024 * 32
- self.hint = None
+ self.hint = ""
self.trailing_slash = True
- self.out_headerlist = []
+ self.out_headerlist: list[tuple[str, str]] = []
self.out_headers = {
"Access-Control-Allow-Origin": "*",
"Cache-Control": "no-store; max-age=0",
@@ -71,44 +158,44 @@ class HttpCli(object):
self.out_headers["X-Robots-Tag"] = "noindex, nofollow"
self.html_head = h
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
ptn = self.asrv.re_pwd
if ptn and ptn.search(msg):
msg = ptn.sub(self.unpwd, msg)
self.log_func(self.log_src, msg, c)
- def unpwd(self, m):
+ def unpwd(self, m: Match[str]) -> str:
a, b = m.groups()
return "=\033[7m {} \033[27m{}".format(self.asrv.iacct[a], b)
- def _check_nonfatal(self, ex, post):
+ def _check_nonfatal(self, ex: Pebkac, post: bool) -> bool:
if post:
return ex.code < 300
return ex.code < 400 or ex.code in [404, 429]
- def _assert_safe_rem(self, rem):
+ def _assert_safe_rem(self, rem: str) -> None:
# sanity check to prevent any disasters
if rem.startswith("/") or rem.startswith("../") or "/../" in rem:
raise Exception("that was close")
- def j2(self, name, **ka):
+ def j2s(self, name: str, **ka: Any) -> str:
tpl = self.conn.hsrv.j2[name]
- if ka:
- ka["ts"] = self.conn.hsrv.cachebuster()
- ka["svcname"] = self.args.doctitle
- ka["html_head"] = self.html_head
- return tpl.render(**ka)
+ ka["ts"] = self.conn.hsrv.cachebuster()
+ ka["svcname"] = self.args.doctitle
+ ka["html_head"] = self.html_head
+ return tpl.render(**ka) # type: ignore
- return tpl
+ def j2j(self, name: str) -> jinja2.Template:
+ return self.conn.hsrv.j2[name]
- def run(self):
+ def run(self) -> bool:
"""returns true if connection can be reused"""
self.keepalive = False
self.is_https = False
self.headers = {}
- self.hint = None
+ self.hint = ""
try:
headerlines = read_header(self.sr)
if not headerlines:
@@ -125,8 +212,8 @@ class HttpCli(object):
# normalize incoming headers to lowercase;
# outgoing headers however are Correct-Case
for header_line in headerlines[1:]:
- k, v = header_line.split(":", 1)
- self.headers[k.lower()] = v.strip()
+ k, zs = header_line.split(":", 1)
+ self.headers[k.lower()] = zs.strip()
except:
msg = " ]\n#[ ".join(headerlines)
raise Pebkac(400, "bad headers:\n#[ " + msg + " ]")
@@ -147,23 +234,26 @@ class HttpCli(object):
self.is_rclone = self.ua.startswith("rclone/")
self.is_ancient = self.ua.startswith("Mozilla/4.")
- v = self.headers.get("connection", "").lower()
- self.keepalive = not v.startswith("close") and self.http_ver != "HTTP/1.0"
- self.is_https = (self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls)
+ zs = self.headers.get("connection", "").lower()
+ self.keepalive = not zs.startswith("close") and self.http_ver != "HTTP/1.0"
+ self.is_https = (
+ self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls
+ )
n = self.args.rproxy
if n:
- v = self.headers.get("x-forwarded-for")
- if v and self.conn.addr[0] in ["127.0.0.1", "::1"]:
+ zso = self.headers.get("x-forwarded-for")
+ if zso and self.conn.addr[0] in ["127.0.0.1", "::1"]:
if n > 0:
n -= 1
- vs = v.split(",")
+ zsl = zso.split(",")
try:
- self.ip = vs[n].strip()
+ self.ip = zsl[n].strip()
except:
- self.ip = vs[0].strip()
- self.log("rproxy={} oob x-fwd {}".format(self.args.rproxy, v), c=3)
+ self.ip = zsl[0].strip()
+ t = "rproxy={} oob x-fwd {}"
+ self.log(t.format(self.args.rproxy, zso), c=3)
self.log_src = self.conn.set_rproxy(self.ip)
@@ -175,9 +265,9 @@ class HttpCli(object):
keys = list(sorted(self.headers.keys()))
for k in keys:
- v = self.headers.get(k)
- if v is not None:
- self.log("[H] {}: \033[33m[{}]".format(k, v), 6)
+ zso = self.headers.get(k)
+ if zso is not None:
+ self.log("[H] {}: \033[33m[{}]".format(k, zso), 6)
if "&" in self.req and "?" not in self.req:
self.hint = "did you mean '?' instead of '&'"
@@ -193,20 +283,22 @@ class HttpCli(object):
vpath = undot(vpath)
for k in arglist.split("&"):
if "=" in k:
- k, v = k.split("=", 1)
- uparam[k.lower()] = v.strip()
+ k, zs = k.split("=", 1)
+ uparam[k.lower()] = zs.strip()
else:
- uparam[k.lower()] = False
+ uparam[k.lower()] = ""
- self.ouparam = {k: v for k, v in uparam.items()}
+ self.ouparam = {k: zs for k, zs in uparam.items()}
- cookies = self.headers.get("cookie") or {}
- if cookies:
- cookies = [x.split("=", 1) for x in cookies.split(";") if "=" in x]
- cookies = {k.strip(): unescape_cookie(v) for k, v in cookies}
+ zso = self.headers.get("cookie")
+ if zso:
+ zsll = [x.split("=", 1) for x in zso.split(";") if "=" in x]
+ cookies = {k.strip(): unescape_cookie(zs) for k, zs in zsll}
for kc, ku in [["cppwd", "pw"], ["b", "b"]]:
if kc in cookies and ku not in uparam:
uparam[ku] = cookies[kc]
+ else:
+ cookies = {}
if len(uparam) > 10 or len(cookies) > 50:
raise Pebkac(400, "u wot m8")
@@ -223,22 +315,22 @@ class HttpCli(object):
self.log("invalid relpath [{}]".format(self.vpath))
return self.tx_404() and self.keepalive
- pwd = None
- ba = self.headers.get("authorization")
- if ba:
+ pwd = ""
+ zso = self.headers.get("authorization")
+ if zso:
try:
- ba = ba.split(" ")[1].encode("ascii")
- ba = base64.b64decode(ba).decode("utf-8")
+ zb = zso.split(" ")[1].encode("ascii")
+ zs = base64.b64decode(zb).decode("utf-8")
# try "pwd", "x:pwd", "pwd:x"
- for ba in [ba] + ba.split(":", 1)[::-1]:
- if self.asrv.iacct.get(ba):
- pwd = ba
+ for zs in [zs] + zs.split(":", 1)[::-1]:
+ if self.asrv.iacct.get(zs):
+ pwd = zs
break
except:
pass
pwd = uparam.get("pw") or pwd
- self.uname = self.asrv.iacct.get(pwd, "*")
+ self.uname = self.asrv.iacct.get(pwd) or "*"
self.rvol = self.asrv.vfs.aread[self.uname]
self.wvol = self.asrv.vfs.awrite[self.uname]
self.mvol = self.asrv.vfs.amove[self.uname]
@@ -249,12 +341,13 @@ class HttpCli(object):
self.out_headerlist.append(("Set-Cookie", self.get_pwd_cookie(pwd)[0]))
if self.is_rclone:
- uparam["raw"] = False
- uparam["dots"] = False
- uparam["b"] = False
- cookies["b"] = False
+ uparam["raw"] = ""
+ uparam["dots"] = ""
+ uparam["b"] = ""
+ cookies["b"] = ""
- self.do_log = not self.conn.lf_url or not self.conn.lf_url.search(self.req)
+ ptn: Optional[Pattern[str]] = self.conn.lf_url # mypy404
+ self.do_log = not ptn or not ptn.search(self.req)
x = self.asrv.vfs.can_access(self.vpath, self.uname)
self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x
@@ -272,9 +365,10 @@ class HttpCli(object):
raise Pebkac(400, 'invalid HTTP mode "{0}"'.format(self.mode))
except Exception as ex:
- pex = ex
if not hasattr(ex, "code"):
pex = Pebkac(500)
+ else:
+ pex = ex # type: ignore
try:
post = self.mode in ["POST", "PUT"] or "content-length" in self.headers
@@ -294,7 +388,7 @@ class HttpCli(object):
except Pebkac:
return False
- def permit_caching(self):
+ def permit_caching(self) -> None:
cache = self.uparam.get("cache")
if cache is None:
self.out_headers.update(NO_CACHE)
@@ -303,11 +397,17 @@ class HttpCli(object):
n = "604800" if cache == "i" else cache or "69"
self.out_headers["Cache-Control"] = "max-age=" + n
- def k304(self):
+ def k304(self) -> bool:
k304 = self.cookies.get("k304")
return k304 == "y" or ("; Trident/" in self.ua and not k304)
- def send_headers(self, length, status=200, mime=None, headers=None):
+ def send_headers(
+ self,
+ length: Optional[int],
+ status: int = 200,
+ mime: Optional[str] = None,
+ headers: Optional[dict[str, str]] = None,
+ ) -> None:
response = ["{} {} {}".format(self.http_ver, status, HTTPCODE[status])]
if length is not None:
@@ -327,10 +427,11 @@ class HttpCli(object):
if not mime:
mime = self.out_headers.get("Content-Type", "text/html; charset=utf-8")
+ assert mime
self.out_headers["Content-Type"] = mime
- for k, v in list(self.out_headers.items()) + self.out_headerlist:
- response.append("{}: {}".format(k, v))
+ for k, zs in list(self.out_headers.items()) + self.out_headerlist:
+ response.append("{}: {}".format(k, zs))
try:
# best practice to separate headers and body into different packets
@@ -338,11 +439,19 @@ class HttpCli(object):
except:
raise Pebkac(400, "client d/c while replying headers")
- def reply(self, body, status=200, mime=None, headers=None, volsan=False):
+ def reply(
+ self,
+ body: bytes,
+ status: int = 200,
+ mime: Optional[str] = None,
+ headers: Optional[dict[str, str]] = None,
+ volsan: bool = False,
+ ) -> bytes:
# TODO something to reply with user-supplied values safely
if volsan:
- body = vol_san(self.asrv.vfs.all_vols.values(), body)
+ vols = list(self.asrv.vfs.all_vols.values())
+ body = vol_san(vols, body)
self.send_headers(len(body), status, mime, headers)
@@ -354,17 +463,19 @@ class HttpCli(object):
return body
- def loud_reply(self, body, *args, **kwargs):
+ def loud_reply(self, body: str, *args: Any, **kwargs: Any) -> None:
if not kwargs.get("mime"):
kwargs["mime"] = "text/plain; charset=utf-8"
self.log(body.rstrip())
self.reply(body.encode("utf-8") + b"\r\n", *list(args), **kwargs)
- def urlq(self, add, rm):
+ def urlq(self, add: dict[str, str], rm: list[str]) -> str:
"""
generates url query based on uparam (b, pw, all others)
removing anything in rm, adding pairs in add
+
+ also list faster than set until ~20 items
"""
if self.is_rclone:
@@ -372,28 +483,28 @@ class HttpCli(object):
cmap = {"pw": "cppwd"}
kv = {
- k: v
- for k, v in self.uparam.items()
- if k not in rm and self.cookies.get(cmap.get(k, k)) != v
+ k: zs
+ for k, zs in self.uparam.items()
+ if k not in rm and self.cookies.get(cmap.get(k, k)) != zs
}
kv.update(add)
if not kv:
return ""
- r = ["{}={}".format(k, quotep(v)) if v else k for k, v in kv.items()]
+ r = ["{}={}".format(k, quotep(zs)) if zs else k for k, zs in kv.items()]
return "?" + "&".join(r)
def redirect(
self,
- vpath,
- suf="",
- msg="aight",
- flavor="go to",
- click=True,
- status=200,
- use302=False,
- ):
- html = self.j2(
+ vpath: str,
+ suf: str = "",
+ msg: str = "aight",
+ flavor: str = "go to",
+ click: bool = True,
+ status: int = 200,
+ use302: bool = False,
+ ) -> bool:
+ html = self.j2s(
"msg",
h2='{} /{}'.format(
quotep(vpath) + suf, flavor, html_escape(vpath, crlf=True) + suf
@@ -407,7 +518,9 @@ class HttpCli(object):
else:
self.reply(html, status=status)
- def handle_get(self):
+ return True
+
+ def handle_get(self) -> bool:
if self.do_log:
logmsg = "{:4} {}".format(self.mode, self.req)
@@ -434,13 +547,13 @@ class HttpCli(object):
self.log("inaccessible: [{}]".format(self.vpath))
return self.tx_404(True)
- self.uparam["h"] = False
+ self.uparam["h"] = ""
if "tree" in self.uparam:
return self.tx_tree()
if "delete" in self.uparam:
- return self.handle_rm()
+ return self.handle_rm([])
if "move" in self.uparam:
return self.handle_mv()
@@ -486,7 +599,7 @@ class HttpCli(object):
return self.tx_browser()
- def handle_options(self):
+ def handle_options(self) -> bool:
if self.do_log:
self.log("OPTIONS " + self.req)
@@ -501,7 +614,7 @@ class HttpCli(object):
)
return True
- def handle_put(self):
+ def handle_put(self) -> bool:
self.log("PUT " + self.req)
if self.headers.get("expect", "").lower() == "100-continue":
@@ -512,7 +625,7 @@ class HttpCli(object):
return self.handle_stash()
- def handle_post(self):
+ def handle_post(self) -> bool:
self.log("POST " + self.req)
if self.headers.get("expect", "").lower() == "100-continue":
@@ -549,16 +662,16 @@ class HttpCli(object):
reader, _ = self.get_body_reader()
for buf in reader:
orig = buf.decode("utf-8", "replace")
- m = "urlform_raw {} @ {}\n {}\n"
- self.log(m.format(len(orig), self.vpath, orig))
+ t = "urlform_raw {} @ {}\n {}\n"
+ self.log(t.format(len(orig), self.vpath, orig))
try:
- plain = unquote(buf.replace(b"+", b" "))
- plain = plain.decode("utf-8", "replace")
+ zb = unquote(buf.replace(b"+", b" "))
+ plain = zb.decode("utf-8", "replace")
if buf.startswith(b"msg="):
plain = plain[4:]
- m = "urlform_dec {} @ {}\n {}\n"
- self.log(m.format(len(plain), self.vpath, plain))
+ t = "urlform_dec {} @ {}\n {}\n"
+ self.log(t.format(len(plain), self.vpath, plain))
except Exception as ex:
self.log(repr(ex))
@@ -569,7 +682,7 @@ class HttpCli(object):
raise Pebkac(405, "don't know how to handle POST({})".format(ctype))
- def get_body_reader(self):
+ def get_body_reader(self) -> tuple[Generator[bytes, None, None], int]:
if "chunked" in self.headers.get("transfer-encoding", "").lower():
return read_socket_chunked(self.sr), -1
@@ -580,7 +693,8 @@ class HttpCli(object):
else:
return read_socket(self.sr, remains), remains
- def dump_to_file(self):
+ def dump_to_file(self) -> tuple[int, str, str, int, str, str]:
+ # post_sz, sha_hex, sha_b64, remains, path, url
reader, remains = self.get_body_reader()
vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True)
lim = vfs.get_dbv(rem)[0].lim
@@ -595,7 +709,7 @@ class HttpCli(object):
bos.makedirs(fdir)
- open_ka = {"fun": open}
+ open_ka: dict[str, Any] = {"fun": open}
open_a = ["wb", 512 * 1024]
# user-request || config-force
@@ -607,7 +721,7 @@ class HttpCli(object):
):
fb = {"gz": 9, "xz": 0} # default/fallback level
lv = {} # selected level
- alg = None # selected algo (gz=preferred)
+ alg = "" # selected algo (gz=preferred)
# user-prefs first
if "gz" in self.uparam or "pk" in self.uparam: # def.pk
@@ -615,8 +729,8 @@ class HttpCli(object):
if "xz" in self.uparam:
alg = "xz"
if alg:
- v = self.uparam.get(alg)
- lv[alg] = fb[alg] if v is None else int(v)
+ zso = self.uparam.get(alg)
+ lv[alg] = fb[alg] if zso is None else int(zso)
if alg not in vfs.flags:
alg = "gz" if "gz" in vfs.flags else "xz"
@@ -633,7 +747,7 @@ class HttpCli(object):
except:
pass
- lv[alg] = lv.get(alg) or fb.get(alg)
+ lv[alg] = lv.get(alg) or fb.get(alg) or 0
self.log("compressing with {} level {}".format(alg, lv.get(alg)))
if alg == "gz":
@@ -656,8 +770,8 @@ class HttpCli(object):
if not fn:
fn = "put" + suffix
- with ren_open(fn, *open_a, **params) as f:
- f, fn = f["orz"]
+ with ren_open(fn, *open_a, **params) as zfw:
+ f, fn = zfw["orz"]
path = os.path.join(fdir, fn)
post_sz, sha_hex, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp)
@@ -674,8 +788,7 @@ class HttpCli(object):
return post_sz, sha_hex, sha_b64, remains, path, ""
vfs, rem = vfs.get_dbv(rem)
- self.conn.hsrv.broker.put(
- False,
+ self.conn.hsrv.broker.say(
"up2k.hash_file",
vfs.realpath,
vfs.flags,
@@ -705,16 +818,16 @@ class HttpCli(object):
return post_sz, sha_hex, sha_b64, remains, path, url
- def handle_stash(self):
+ def handle_stash(self) -> bool:
post_sz, sha_hex, sha_b64, remains, path, url = self.dump_to_file()
spd = self._spd(post_sz)
- m = "{} wrote {}/{} bytes to {} # {}"
- self.log(m.format(spd, post_sz, remains, path, sha_b64[:28])) # 21
- m = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url)
- self.reply(m.encode("utf-8"))
+ t = "{} wrote {}/{} bytes to {} # {}"
+ self.log(t.format(spd, post_sz, remains, path, sha_b64[:28])) # 21
+ t = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url)
+ self.reply(t.encode("utf-8"))
return True
- def _spd(self, nbytes, add=True):
+ def _spd(self, nbytes: int, add: bool = True) -> str:
if add:
self.conn.nbyte += nbytes
@@ -722,7 +835,7 @@ class HttpCli(object):
spd2 = get_spd(self.conn.nbyte, self.conn.t0)
return "{} {} n{}".format(spd1, spd2, self.conn.nreq)
- def handle_post_multipart(self):
+ def handle_post_multipart(self) -> bool:
self.parser = MultipartParser(self.log, self.sr, self.headers)
self.parser.parse()
@@ -749,7 +862,8 @@ class HttpCli(object):
raise Pebkac(422, 'invalid action "{}"'.format(act))
- def handle_zip_post(self):
+ def handle_zip_post(self) -> bool:
+ assert self.parser
for k in ["zip", "tar"]:
v = self.uparam.get(k)
if v is not None:
@@ -759,17 +873,17 @@ class HttpCli(object):
raise Pebkac(422, "need zip or tar keyword")
vn, rem = self.asrv.vfs.get(self.vpath, self.uname, True, False)
- items = self.parser.require("files", 1024 * 1024)
- if not items:
+ zs = self.parser.require("files", 1024 * 1024)
+ if not zs:
raise Pebkac(422, "need files list")
- items = items.replace("\r", "").split("\n")
+ items = zs.replace("\r", "").split("\n")
items = [unquotep(x) for x in items if items]
self.parser.drop()
return self.tx_zip(k, v, vn, rem, items, self.args.ed)
- def handle_post_json(self):
+ def handle_post_json(self) -> bool:
try:
remains = int(self.headers["content-length"])
except:
@@ -836,14 +950,14 @@ class HttpCli(object):
except:
raise Pebkac(500, min_ex())
- x = self.conn.hsrv.broker.put(True, "up2k.handle_json", body)
+ x = self.conn.hsrv.broker.ask("up2k.handle_json", body)
ret = x.get()
ret = json.dumps(ret)
self.log(ret)
self.reply(ret.encode("utf-8"), mime="application/json")
return True
- def handle_search(self, body):
+ def handle_search(self, body: dict[str, Any]) -> bool:
idx = self.conn.get_u2idx()
if not hasattr(idx, "p_end"):
raise Pebkac(500, "sqlite3 is not available on the server; cannot search")
@@ -857,15 +971,15 @@ class HttpCli(object):
continue
seen[vfs] = True
- vols.append([vfs.vpath, vfs.realpath, vfs.flags])
+ vols.append((vfs.vpath, vfs.realpath, vfs.flags))
t0 = time.time()
if idx.p_end:
penalty = 0.7
t_idle = t0 - idx.p_end
if idx.p_dur > 0.7 and t_idle < penalty:
- m = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}"
- raise Pebkac(429, m.format(penalty, idx.p_dur, t_idle))
+ t = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}"
+ raise Pebkac(429, t.format(penalty, idx.p_dur, t_idle))
if "srch" in body:
# search by up2k hashlist
@@ -873,8 +987,8 @@ class HttpCli(object):
vbody["hash"] = len(vbody["hash"])
self.log("qj: " + repr(vbody))
hits = idx.fsearch(vols, body)
- msg = repr(hits)
- taglist = {}
+ msg: Any = repr(hits)
+ taglist: list[str] = []
else:
# search by query params
q = body["q"]
@@ -900,7 +1014,7 @@ class HttpCli(object):
self.reply(r, mime="application/json")
return True
- def handle_post_binary(self):
+ def handle_post_binary(self) -> bool:
try:
remains = int(self.headers["content-length"])
except:
@@ -915,7 +1029,7 @@ class HttpCli(object):
vfs, _ = self.asrv.vfs.get(self.vpath, self.uname, False, True)
ptop = (vfs.dbv or vfs).realpath
- x = self.conn.hsrv.broker.put(True, "up2k.handle_chunk", ptop, wark, chash)
+ x = self.conn.hsrv.broker.ask("up2k.handle_chunk", ptop, wark, chash)
response = x.get()
chunksize, cstart, path, lastmod = response
@@ -946,8 +1060,8 @@ class HttpCli(object):
post_sz, _, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp)
if sha_b64 != chash:
- m = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}"
- raise Pebkac(400, m.format(post_sz, chash, sha_b64))
+ t = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}"
+ raise Pebkac(400, t.format(post_sz, chash, sha_b64))
if len(cstart) > 1 and path != os.devnull:
self.log(
@@ -974,15 +1088,15 @@ class HttpCli(object):
with self.mutex:
self.u2fh.put(path, f)
finally:
- x = self.conn.hsrv.broker.put(True, "up2k.release_chunk", ptop, wark, chash)
+ x = self.conn.hsrv.broker.ask("up2k.release_chunk", ptop, wark, chash)
x.get() # block client until released
- x = self.conn.hsrv.broker.put(True, "up2k.confirm_chunk", ptop, wark, chash)
- x = x.get()
+ x = self.conn.hsrv.broker.ask("up2k.confirm_chunk", ptop, wark, chash)
+ ztis = x.get()
try:
- num_left, fin_path = x
+ num_left, fin_path = ztis
except:
- self.loud_reply(x, status=500)
+ self.loud_reply(ztis, status=500)
return False
if not num_left and fpool:
@@ -991,7 +1105,7 @@ class HttpCli(object):
# windows cant rename open files
if ANYWIN and path != fin_path and not self.args.nw:
- self.conn.hsrv.broker.put(True, "up2k.finish_upload", ptop, wark).get()
+ self.conn.hsrv.broker.ask("up2k.finish_upload", ptop, wark).get()
if not ANYWIN and not num_left:
times = (int(time.time()), int(lastmod))
@@ -1006,7 +1120,8 @@ class HttpCli(object):
self.reply(b"thank")
return True
- def handle_login(self):
+ def handle_login(self) -> bool:
+ assert self.parser
pwd = self.parser.require("cppwd", 64)
self.parser.drop()
@@ -1021,11 +1136,11 @@ class HttpCli(object):
dst += quotep(self.vpath)
ck, msg = self.get_pwd_cookie(pwd)
- html = self.j2("msg", h1=msg, h2='ack', redir=dst)
+ html = self.j2s("msg", h1=msg, h2='ack', redir=dst)
self.reply(html.encode("utf-8"), headers={"Set-Cookie": ck})
return True
- def get_pwd_cookie(self, pwd):
+ def get_pwd_cookie(self, pwd: str) -> tuple[str, str]:
if pwd in self.asrv.iacct:
msg = "login ok"
dur = int(60 * 60 * self.args.logout)
@@ -1038,9 +1153,10 @@ class HttpCli(object):
if self.is_ancient:
r = r.rsplit(" ", 1)[0]
- return [r, msg]
+ return r, msg
- def handle_mkdir(self):
+ def handle_mkdir(self) -> bool:
+ assert self.parser
new_dir = self.parser.require("name", 512)
self.parser.drop()
@@ -1075,7 +1191,8 @@ class HttpCli(object):
self.redirect(vpath)
return True
- def handle_new_md(self):
+ def handle_new_md(self) -> bool:
+ assert self.parser
new_file = self.parser.require("name", 512)
self.parser.drop()
@@ -1102,7 +1219,8 @@ class HttpCli(object):
self.redirect(vpath, "?edit")
return True
- def handle_plain_upload(self):
+ def handle_plain_upload(self) -> bool:
+ assert self.parser
nullwrite = self.args.nw
vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True)
self._assert_safe_rem(rem)
@@ -1116,17 +1234,21 @@ class HttpCli(object):
if not nullwrite:
bos.makedirs(fdir_base)
- files = []
+ files: list[tuple[int, str, str, str, str, str]] = []
+ # sz, sha_hex, sha_b64, p_file, fname, abspath
errmsg = ""
t0 = time.time()
try:
+ assert self.parser.gen
for nfile, (p_field, p_file, p_data) in enumerate(self.parser.gen):
if not p_file:
self.log("discarding incoming file without filename")
# fallthrough
fdir = fdir_base
- fname = sanitize_fn(p_file, "", [".prologue.html", ".epilogue.html"])
+ fname = sanitize_fn(
+ p_file or "", "", [".prologue.html", ".epilogue.html"]
+ )
if p_file and not nullwrite:
if not bos.path.isdir(fdir):
raise Pebkac(404, "that folder does not exist")
@@ -1143,8 +1265,8 @@ class HttpCli(object):
lim.chk_nup(self.ip)
try:
- with ren_open(fname, "wb", 512 * 1024, **open_args) as f:
- f, fname = f["orz"]
+ with ren_open(fname, "wb", 512 * 1024, **open_args) as zfw:
+ f, fname = zfw["orz"]
abspath = os.path.join(fdir, fname)
self.log("writing to {}".format(abspath))
sz, sha_hex, sha_b64 = hashcopy(p_data, f, self.args.s_wr_slp)
@@ -1158,12 +1280,14 @@ class HttpCli(object):
lim.chk_sz(sz)
except:
bos.unlink(abspath)
+ fname = os.devnull
raise
- files.append([sz, sha_hex, sha_b64, p_file, fname, abspath])
+ files.append(
+ (sz, sha_hex, sha_b64, p_file or "(discarded)", fname, abspath)
+ )
dbv, vrem = vfs.get_dbv(rem)
- self.conn.hsrv.broker.put(
- False,
+ self.conn.hsrv.broker.say(
"up2k.hash_file",
dbv.realpath,
dbv.flags,
@@ -1192,7 +1316,7 @@ class HttpCli(object):
except Pebkac as ex:
errmsg = vol_san(
- self.asrv.vfs.all_vols.values(), unicode(ex).encode("utf-8")
+ list(self.asrv.vfs.all_vols.values()), unicode(ex).encode("utf-8")
).decode("utf-8")
td = max(0.1, time.time() - t0)
@@ -1205,7 +1329,12 @@ class HttpCli(object):
status = "ERROR"
msg = "{} // {} bytes // {:.3f} MiB/s\n".format(status, sz_total, spd)
- jmsg = {"status": status, "sz": sz_total, "mbps": round(spd, 3), "files": []}
+ jmsg: dict[str, Any] = {
+ "status": status,
+ "sz": sz_total,
+ "mbps": round(spd, 3),
+ "files": [],
+ }
if errmsg:
msg += errmsg + "\n"
@@ -1260,17 +1389,17 @@ class HttpCli(object):
ft = "{}\n{}\n{}\n".format(ft, msg.rstrip(), errmsg)
f.write(ft.encode("utf-8"))
- status = 400 if errmsg else 200
+ sc = 400 if errmsg else 200
if "j" in self.uparam:
jtxt = json.dumps(jmsg, indent=2, sort_keys=True).encode("utf-8", "replace")
- self.reply(jtxt, mime="application/json", status=status)
+ self.reply(jtxt, mime="application/json", status=sc)
else:
self.redirect(
self.vpath,
msg=msg,
flavor="return to",
click=False,
- status=status,
+ status=sc,
)
if errmsg:
@@ -1279,7 +1408,8 @@ class HttpCli(object):
self.parser.drop()
return True
- def handle_text_upload(self):
+ def handle_text_upload(self) -> bool:
+ assert self.parser
try:
cli_lastmod3 = int(self.parser.require("lastmod", 16))
except:
@@ -1314,7 +1444,8 @@ class HttpCli(object):
self.reply(response.encode("utf-8"))
return True
- srv_lastmod = srv_lastmod3 = -1
+ srv_lastmod = -1.0
+ srv_lastmod3 = -1
try:
st = bos.stat(fp)
srv_lastmod = st.st_mtime
@@ -1329,7 +1460,7 @@ class HttpCli(object):
if not same_lastmod:
# some filesystems/transports limit precision to 1sec, hopefully floored
same_lastmod = (
- srv_lastmod == int(srv_lastmod)
+ srv_lastmod == int(cli_lastmod3 / 1000)
and cli_lastmod3 > srv_lastmod3
and cli_lastmod3 - srv_lastmod3 < 1000
)
@@ -1360,6 +1491,7 @@ class HttpCli(object):
pass
bos.rename(fp, os.path.join(mdir, ".hist", mfile2))
+ assert self.parser.gen
p_field, _, p_data = next(self.parser.gen)
if p_field != "body":
raise Pebkac(400, "expected body, got {}".format(p_field))
@@ -1388,7 +1520,7 @@ class HttpCli(object):
self.reply(response.encode("utf-8"))
return True
- def _chk_lastmod(self, file_ts):
+ def _chk_lastmod(self, file_ts: int) -> tuple[str, bool]:
file_lastmod = http_ts(file_ts)
cli_lastmod = self.headers.get("if-modified-since")
if cli_lastmod:
@@ -1408,7 +1540,7 @@ class HttpCli(object):
return file_lastmod, True
- def tx_file(self, req_path):
+ def tx_file(self, req_path: str) -> bool:
status = 200
logmsg = "{:4} {} ".format("", self.req)
logtail = ""
@@ -1417,7 +1549,7 @@ class HttpCli(object):
# if request is for foo.js, check if we have foo.js.{gz,br}
file_ts = 0
- editions = {}
+ editions: dict[str, tuple[str, int]] = {}
for ext in ["", ".gz", ".br"]:
try:
fs_path = req_path + ext
@@ -1425,8 +1557,8 @@ class HttpCli(object):
if stat.S_ISDIR(st.st_mode):
continue
- file_ts = max(file_ts, st.st_mtime)
- editions[ext or "plain"] = [fs_path, st.st_size]
+ file_ts = max(file_ts, int(st.st_mtime))
+ editions[ext or "plain"] = (fs_path, st.st_size)
except:
pass
if not self.vpath.startswith(".cpr/"):
@@ -1526,8 +1658,8 @@ class HttpCli(object):
use_sendfile = False
if decompress:
- open_func = gzip.open
- open_args = [fsenc(fs_path), "rb"]
+ open_func: Any = gzip.open
+ open_args: list[Any] = [fsenc(fs_path), "rb"]
# Content-Length := original file size
upper = gzip_orig_sz(fs_path)
else:
@@ -1551,7 +1683,7 @@ class HttpCli(object):
if "txt" in self.uparam:
mime = "text/plain; charset={}".format(self.uparam["txt"] or "utf-8")
elif "mime" in self.uparam:
- mime = self.uparam.get("mime")
+ mime = str(self.uparam.get("mime"))
else:
mime = guess_mime(req_path)
@@ -1583,19 +1715,18 @@ class HttpCli(object):
return ret
- def tx_zip(self, fmt, uarg, vn, rem, items, dots):
+ def tx_zip(
+ self, fmt: str, uarg: str, vn: VFS, rem: str, items: list[str], dots: bool
+ ) -> bool:
if self.args.no_zip:
raise Pebkac(400, "not enabled")
logmsg = "{:4} {} ".format("", self.req)
self.keepalive = False
- if not uarg:
- uarg = ""
-
if fmt == "tar":
mime = "application/x-tar"
- packer = StreamTar
+ packer: Type[StreamArc] = StreamTar
else:
mime = "application/zip"
packer = StreamZip
@@ -1609,24 +1740,25 @@ class HttpCli(object):
safe = (string.ascii_letters + string.digits).replace("%", "")
afn = "".join([x if x in safe.replace('"', "") else "_" for x in fn])
bascii = unicode(safe).encode("utf-8")
- ufn = fn.encode("utf-8", "xmlcharrefreplace")
- if PY2:
- ufn = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in ufn]
- else:
- ufn = [
+ zb = fn.encode("utf-8", "xmlcharrefreplace")
+ if not PY2:
+ zbl = [
chr(x).encode("utf-8")
if x in bascii
else "%{:02x}".format(x).encode("ascii")
- for x in ufn
+ for x in zb
]
- ufn = b"".join(ufn).decode("ascii")
+ else:
+ zbl = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in zb]
+
+ ufn = b"".join(zbl).decode("ascii")
cdis = "attachment; filename=\"{}.{}\"; filename*=UTF-8''{}.{}"
cdis = cdis.format(afn, fmt, ufn, fmt)
self.log(cdis)
self.send_headers(None, mime=mime, headers={"Content-Disposition": cdis})
- fgen = vn.zipgen(rem, items, self.uname, dots, not self.args.no_scandir)
+ fgen = vn.zipgen(rem, set(items), self.uname, dots, not self.args.no_scandir)
# for f in fgen: print(repr({k: f[k] for k in ["vp", "ap"]}))
bgen = packer(self.log, fgen, utf8="utf" in uarg, pre_crc="crc" in uarg)
bsent = 0
@@ -1645,7 +1777,7 @@ class HttpCli(object):
self.log("{}, {}".format(logmsg, spd))
return True
- def tx_ico(self, ext, exact=False):
+ def tx_ico(self, ext: str, exact: bool = False) -> bool:
self.permit_caching()
if ext.endswith("/"):
ext = "folder"
@@ -1674,7 +1806,7 @@ class HttpCli(object):
self.reply(ico, mime=mime, headers={"Last-Modified": lm})
return True
- def tx_md(self, fs_path):
+ def tx_md(self, fs_path: str) -> bool:
logmsg = "{:4} {} ".format("", self.req)
if not self.can_write:
@@ -1683,7 +1815,7 @@ class HttpCli(object):
tpl = "mde" if "edit2" in self.uparam else "md"
html_path = os.path.join(E.mod, "web", "{}.html".format(tpl))
- template = self.j2(tpl)
+ template = self.j2j(tpl)
st = bos.stat(fs_path)
ts_md = st.st_mtime
@@ -1694,7 +1826,7 @@ class HttpCli(object):
sz_md = 0
for buf in yieldfile(fs_path):
sz_md += len(buf)
- for c, v in [[b"&", 4], [b"<", 3], [b">", 3]]:
+ for c, v in [(b"&", 4), (b"<", 3), (b">", 3)]:
sz_md += (len(buf) - len(buf.replace(c, b""))) * v
file_ts = max(ts_md, ts_html, E.t0)
@@ -1720,8 +1852,8 @@ class HttpCli(object):
"md": boundary,
"arg_base": arg_base,
}
- html = template.render(**targs).encode("utf-8", "replace")
- html = html.split(boundary.encode("utf-8"))
+ zs = template.render(**targs).encode("utf-8", "replace")
+ html = zs.split(boundary.encode("utf-8"))
if len(html) != 2:
raise Exception("boundary appears in " + html_path)
@@ -1750,7 +1882,7 @@ class HttpCli(object):
return True
- def tx_mounts(self):
+ def tx_mounts(self) -> bool:
suf = self.urlq({}, ["h"])
avol = [x for x in self.wvol if x in self.rvol]
rvol, wvol, avol = [
@@ -1759,7 +1891,7 @@ class HttpCli(object):
]
if avol and not self.args.no_rescan:
- x = self.conn.hsrv.broker.put(True, "up2k.get_state")
+ x = self.conn.hsrv.broker.ask("up2k.get_state")
vs = json.loads(x.get())
vstate = {("/" + k).rstrip("/") + "/": v for k, v in vs["volstate"].items()}
else:
@@ -1787,11 +1919,11 @@ class HttpCli(object):
for v in wvol:
txt += "\n " + v
- txt = txt.encode("utf-8", "replace") + b"\n"
- self.reply(txt, mime="text/plain; charset=utf-8")
+ zb = txt.encode("utf-8", "replace") + b"\n"
+ self.reply(zb, mime="text/plain; charset=utf-8")
return True
- html = self.j2(
+ html = self.j2s(
"splash",
this=self,
qvpath=quotep(self.vpath),
@@ -1809,38 +1941,41 @@ class HttpCli(object):
self.reply(html.encode("utf-8"))
return True
- def set_k304(self):
+ def set_k304(self) -> bool:
ck = gencookie("k304", self.uparam["k304"], 60 * 60 * 24 * 365)
self.out_headerlist.append(("Set-Cookie", ck))
self.redirect("", "?h#cc")
+ return True
- def set_am_js(self):
+ def set_am_js(self) -> bool:
v = "n" if self.uparam["am_js"] == "n" else "y"
ck = gencookie("js", v, 60 * 60 * 24 * 365)
self.out_headerlist.append(("Set-Cookie", ck))
self.reply(b"promoted\n")
+ return True
- def set_cfg_reset(self):
+ def set_cfg_reset(self) -> bool:
for k in ("k304", "js", "cppwd"):
self.out_headerlist.append(("Set-Cookie", gencookie(k, "x", None)))
self.redirect("", "?h#cc")
+ return True
- def tx_404(self, is_403=False):
+ def tx_404(self, is_403: bool = False) -> bool:
rc = 404
if self.args.vague_403:
- m = '404 not found ┐( ´ -`)┌
or maybe you don\'t have access -- try logging in or go home
' + t = 'or maybe you don\'t have access -- try logging in or go home
' elif is_403: - m = 'you\'ll have to log in or go home
' + t = 'you\'ll have to log in or go home
' rc = 403 else: - m = '{}\n{}".format(time.time(), html_escape(alltrace())) self.reply(ret.encode("utf-8")) + return True - def tx_tree(self): + def tx_tree(self) -> bool: top = self.uparam["tree"] or "" dst = self.vpath if top in [".", ".."]: @@ -1898,12 +2034,12 @@ class HttpCli(object): dst = dst[len(top) + 1 :] ret = self.gen_tree(top, dst) - ret = json.dumps(ret) - self.reply(ret.encode("utf-8"), mime="application/json") + zs = json.dumps(ret) + self.reply(zs.encode("utf-8"), mime="application/json") return True - def gen_tree(self, top, target): - ret = {} + def gen_tree(self, top: str, target: str) -> dict[str, Any]: + ret: dict[str, Any] = {} excl = None if target: excl, target = (target.split("/", 1) + [""])[:2] @@ -1921,26 +2057,26 @@ class HttpCli(object): for v in self.rvol: d1, d2 = v.rsplit("/", 1) if "/" in v else ["", v] if d1 == top: - vfs_virt[d2] = 0 + vfs_virt[d2] = vn # typechk, value never read dirs = [] - vfs_ls = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)] + dirnames = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)] if not self.args.ed or "dots" not in self.uparam: - vfs_ls = exclude_dotfiles(vfs_ls) + dirnames = exclude_dotfiles(dirnames) - for fn in [x for x in vfs_ls if x != excl]: + for fn in [x for x in dirnames if x != excl]: dirs.append(quotep(fn)) - for x in vfs_virt.keys(): + for x in vfs_virt: if x != excl: dirs.append(x) ret["a"] = dirs return ret - def tx_ups(self): + def tx_ups(self) -> bool: if not self.args.unpost: raise Pebkac(400, "the unpost feature is disabled in server config") @@ -1952,7 +2088,7 @@ class HttpCli(object): lm = "ups [{}]".format(filt) self.log(lm) - ret = [] + ret: list[dict[str, Any]] = [] t0 = time.time() lim = time.time() - self.args.unpost for vol in self.asrv.vfs.all_vols.values(): @@ -1968,17 +2104,18 @@ class HttpCli(object): ret.append({"vp": quotep(vp), "sz": sz, "at": at}) if len(ret) > 3000: - ret.sort(key=lambda x: x["at"], reverse=True) + ret.sort(key=lambda x: x["at"], 
reverse=True) # type: ignore ret = ret[:2000] - ret.sort(key=lambda x: x["at"], reverse=True) + ret.sort(key=lambda x: x["at"], reverse=True) # type: ignore ret = ret[:2000] jtxt = json.dumps(ret, indent=2, sort_keys=True).encode("utf-8", "replace") self.log("{} #{} {:.2f}sec".format(lm, len(ret), time.time() - t0)) self.reply(jtxt, mime="application/json") + return True - def handle_rm(self, req=None): + def handle_rm(self, req: list[str]) -> bool: if not req and not self.can_delete: raise Pebkac(403, "not allowed for user " + self.uname) @@ -1988,10 +2125,11 @@ class HttpCli(object): if not req: req = [self.vpath] - x = self.conn.hsrv.broker.put(True, "up2k.handle_rm", self.uname, self.ip, req) + x = self.conn.hsrv.broker.ask("up2k.handle_rm", self.uname, self.ip, req) self.loud_reply(x.get()) + return True - def handle_mv(self): + def handle_mv(self) -> bool: if not self.can_move: raise Pebkac(403, "not allowed for user " + self.uname) @@ -2006,12 +2144,11 @@ class HttpCli(object): # x-www-form-urlencoded (url query part) uses # either + or %20 for 0x20 so handle both dst = unquotep(dst.replace("+", " ")) - x = self.conn.hsrv.broker.put( - True, "up2k.handle_mv", self.uname, self.vpath, dst - ) + x = self.conn.hsrv.broker.ask("up2k.handle_mv", self.uname, self.vpath, dst) self.loud_reply(x.get()) + return True - def tx_ls(self, ls): + def tx_ls(self, ls: dict[str, Any]) -> bool: dirs = ls["dirs"] files = ls["files"] arg = self.uparam["ls"] @@ -2055,17 +2192,17 @@ class HttpCli(object): x["name"] = n fmt = fmt.format(len(nfmt.format(biggest))) - ret = [ + retl = [ "# {}: {}".format(x, ls[x]) for x in ["acct", "perms", "srvinf"] if x in ls ] - ret += [ + retl += [ fmt.format(x["dt"], x["sz"], x["name"]) for y in [dirs, files] for x in y ] - ret = "\n".join(ret) + ret = "\n".join(retl) mime = "text/plain; charset=utf-8" else: [x.pop(k) for k in ["name", "dt"] for y in [dirs, files] for x in y] @@ -2076,7 +2213,7 @@ class HttpCli(object): 
self.reply(ret.encode("utf-8", "replace") + b"\n", mime=mime) return True - def tx_browser(self): + def tx_browser(self) -> bool: vpath = "" vpnodes = [["", "/"]] if self.vpath: @@ -2164,7 +2301,7 @@ class HttpCli(object): if WINDOWS: try: bfree = ctypes.c_ulonglong(0) - ctypes.windll.kernel32.GetDiskFreeSpaceExW( + ctypes.windll.kernel32.GetDiskFreeSpaceExW( # type: ignore ctypes.c_wchar_p(abspath), None, None, ctypes.pointer(bfree) ) srv_info.append(humansize(bfree.value) + " free") @@ -2179,7 +2316,7 @@ class HttpCli(object): except: pass - srv_info = " // ".join(srv_info) + srv_infot = " // ".join(srv_info) perms = [] if self.can_read: @@ -2223,7 +2360,7 @@ class HttpCli(object): "dirs": [], "files": [], "taglist": [], - "srvinf": srv_info, + "srvinf": srv_infot, "acct": self.uname, "idx": ("e2d" in vn.flags), "perms": perms, @@ -2251,7 +2388,7 @@ class HttpCli(object): "logues": logues, "readme": readme, "title": html_escape(self.vpath, crlf=True), - "srv_info": srv_info, + "srv_info": srv_infot, "lang": self.args.lang, "dtheme": self.args.theme, "themes": self.args.themes, @@ -2267,7 +2404,7 @@ class HttpCli(object): if "zip" in self.uparam or "tar" in self.uparam: raise Pebkac(403) - html = self.j2(tpl, **j2a) + html = self.j2s(tpl, **j2a) self.reply(html.encode("utf-8", "replace")) return True @@ -2280,11 +2417,12 @@ class HttpCli(object): rem, self.uname, not self.args.no_scandir, [[True], [False, True]] ) stats = {k: v for k, v in vfs_ls} - vfs_ls = [x[0] for x in vfs_ls] - vfs_ls.extend(vfs_virt.keys()) + ls_names = [x[0] for x in vfs_ls] + ls_names.extend(list(vfs_virt.keys())) # check for old versions of files, - hist = {} # [num-backups, most-recent, hist-path] + # [num-backups, most-recent, hist-path] + hist: dict[str, tuple[int, float, str]] = {} histdir = os.path.join(fsroot, ".hist") ptn = re.compile(r"(.*)\.([0-9]+\.[0-9]{3})(\.[^\.]+)$") try: @@ -2294,14 +2432,14 @@ class HttpCli(object): continue fn = m.group(1) + m.group(3) - n, ts, _ = 
hist.get(fn, [0, 0, ""]) - hist[fn] = [n + 1, max(ts, float(m.group(2))), hfn] + n, ts, _ = hist.get(fn, (0, 0, "")) + hist[fn] = (n + 1, max(ts, float(m.group(2))), hfn) except: pass # show dotfiles if permitted and requested if not self.args.ed or "dots" not in self.uparam: - vfs_ls = exclude_dotfiles(vfs_ls) + ls_names = exclude_dotfiles(ls_names) icur = None if "e2t" in vn.flags: @@ -2312,7 +2450,7 @@ class HttpCli(object): dirs = [] files = [] - for fn in vfs_ls: + for fn in ls_names: base = "" href = fn if not is_ls and not is_js and not self.trailing_slash and vpath: @@ -2339,14 +2477,14 @@ class HttpCli(object): margin = 'zip'.format(quotep(href)) elif fn in hist: margin = '#{}'.format( - base, html_escape(hist[fn][2], quote=True, crlf=True), hist[fn][0] + base, html_escape(hist[fn][2], quot=True, crlf=True), hist[fn][0] ) else: margin = "-" sz = inf.st_size - dt = datetime.utcfromtimestamp(inf.st_mtime) - dt = dt.strftime("%Y-%m-%d %H:%M:%S") + zd = datetime.utcfromtimestamp(inf.st_mtime) + dt = zd.strftime("%Y-%m-%d %H:%M:%S") try: ext = "---" if is_dir else fn.rsplit(".", 1)[1] @@ -2380,11 +2518,11 @@ class HttpCli(object): files.append(item) item["rd"] = rem - taglist = {} - for f in files: - fn = f["name"] - rd = f["rd"] - del f["rd"] + tagset: set[str] = set() + for fe in files: + fn = fe["name"] + rd = fe["rd"] + del fe["rd"] if not icur: break @@ -2403,12 +2541,12 @@ class HttpCli(object): args = s3enc(idx.mem_cur, rd, fn) r = icur.execute(q, args).fetchone() except: - m = "tag list error, {}/{}\n{}" - self.log(m.format(rd, fn, min_ex())) + t = "tag list error, {}/{}\n{}" + self.log(t.format(rd, fn, min_ex())) break - tags = {} - f["tags"] = tags + tags: dict[str, Any] = {} + fe["tags"] = tags if not r: continue @@ -2417,17 +2555,19 @@ class HttpCli(object): q = "select k, v from mt where w = ? 
and +k != 'x'" try: for k, v in icur.execute(q, (w,)): - taglist[k] = True + tagset.add(k) tags[k] = v except: - m = "tag read error, {}/{} [{}]:\n{}" - self.log(m.format(rd, fn, w, min_ex())) + t = "tag read error, {}/{} [{}]:\n{}" + self.log(t.format(rd, fn, w, min_ex())) break if icur: - taglist = [k for k in vn.flags.get("mte", "").split(",") if k in taglist] - for f in dirs: - f["tags"] = {} + taglist = [k for k in vn.flags.get("mte", "").split(",") if k in tagset] + for fe in dirs: + fe["tags"] = {} + else: + taglist = list(tagset) if is_ls: ls_ret["dirs"] = dirs @@ -2480,6 +2620,6 @@ class HttpCli(object): if self.args.css_browser: j2a["css"] = self.args.css_browser - html = self.j2(tpl, **j2a) + html = self.j2s(tpl, **j2a) self.reply(html.encode("utf-8", "replace")) return True diff --git a/copyparty/httpconn.py b/copyparty/httpconn.py index 85f5f701..067e2d31 100644 --- a/copyparty/httpconn.py +++ b/copyparty/httpconn.py @@ -1,25 +1,36 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re +import argparse # typechk import os -import time +import re import socket +import threading # typechk +import time -HAVE_SSL = True try: + HAVE_SSL = True import ssl except: HAVE_SSL = False -from .__init__ import E -from .util import Unrecv +from . 
import util as Util +from .__init__ import TYPE_CHECKING, E +from .authsrv import AuthSrv # typechk from .httpcli import HttpCli -from .u2idx import U2idx +from .ico import Ico +from .mtag import HAVE_FFMPEG from .th_cli import ThumbCli from .th_srv import HAVE_PIL, HAVE_VIPS -from .mtag import HAVE_FFMPEG -from .ico import Ico +from .u2idx import U2idx + +try: + from typing import Optional, Pattern, Union +except: + pass + +if TYPE_CHECKING: + from .httpsrv import HttpSrv class HttpConn(object): @@ -28,32 +39,37 @@ class HttpConn(object): creates an HttpCli for each request (Connection: Keep-Alive) """ - def __init__(self, sck, addr, hsrv): + def __init__( + self, sck: socket.socket, addr: tuple[str, int], hsrv: "HttpSrv" + ) -> None: self.s = sck - self.sr = None # Type: Unrecv + self.sr: Optional[Util._Unrecv] = None self.addr = addr self.hsrv = hsrv - self.mutex = hsrv.mutex - self.args = hsrv.args - self.asrv = hsrv.asrv + self.mutex: threading.Lock = hsrv.mutex # mypy404 + self.args: argparse.Namespace = hsrv.args # mypy404 + self.asrv: AuthSrv = hsrv.asrv # mypy404 self.cert_path = hsrv.cert_path - self.u2fh = hsrv.u2fh + self.u2fh: Util.FHC = hsrv.u2fh # mypy404 enth = (HAVE_PIL or HAVE_VIPS or HAVE_FFMPEG) and not self.args.no_thumb - self.thumbcli = ThumbCli(hsrv) if enth else None - self.ico = Ico(self.args) + self.thumbcli: Optional[ThumbCli] = ThumbCli(hsrv) if enth else None # mypy404 + self.ico: Ico = Ico(self.args) # mypy404 - self.t0 = time.time() + self.t0: float = time.time() # mypy404 self.stopping = False - self.nreq = 0 - self.nbyte = 0 - self.u2idx = None - self.log_func = hsrv.log - self.lf_url = re.compile(self.args.lf_url) if self.args.lf_url else None + self.nreq: int = 0 # mypy404 + self.nbyte: int = 0 # mypy404 + self.u2idx: Optional[U2idx] = None + self.log_func: Util.RootLogger = hsrv.log # mypy404 + self.log_src: str = "httpconn" # mypy404 + self.lf_url: Optional[Pattern[str]] = ( + re.compile(self.args.lf_url) if self.args.lf_url 
else None + ) # mypy404 self.set_rproxy() - def shutdown(self): + def shutdown(self) -> None: self.stopping = True try: self.s.shutdown(socket.SHUT_RDWR) @@ -61,7 +77,7 @@ class HttpConn(object): except: pass - def set_rproxy(self, ip=None): + def set_rproxy(self, ip: Optional[str] = None) -> str: if ip is None: color = 36 ip = self.addr[0] @@ -74,35 +90,35 @@ class HttpConn(object): self.log_src = "{} \033[{}m{}".format(ip, color, self.addr[1]).ljust(26) return self.log_src - def respath(self, res_name): + def respath(self, res_name: str) -> str: return os.path.join(E.mod, "web", res_name) - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func(self.log_src, msg, c) - def get_u2idx(self): + def get_u2idx(self) -> U2idx: if not self.u2idx: self.u2idx = U2idx(self) return self.u2idx - def _detect_https(self): + def _detect_https(self) -> bool: method = None if self.cert_path: try: method = self.s.recv(4, socket.MSG_PEEK) except socket.timeout: - return + return False except AttributeError: # jython does not support msg_peek; forget about https method = self.s.recv(4) - self.sr = Unrecv(self.s, self.log) + self.sr = Util.Unrecv(self.s, self.log) self.sr.buf = method # jython used to do this, they stopped since it's broken # but reimplementing sendall is out of scope for now if not getattr(self.s, "sendall", None): - self.s.sendall = self.s.send + self.s.sendall = self.s.send # type: ignore if len(method) != 4: err = "need at least 4 bytes in the first packet; got {}".format( @@ -112,17 +128,18 @@ class HttpConn(object): self.log(err) self.s.send(b"HTTP/1.1 400 Bad Request\r\n\r\n" + err.encode("utf-8")) - return + return False return method not in [None, b"GET ", b"HEAD", b"POST", b"PUT ", b"OPTI"] - def run(self): + def run(self) -> None: self.sr = None if self.args.https_only: is_https = True elif self.args.http_only or not HAVE_SSL: is_https = False else: + # raise Exception("asdf") is_https = self._detect_https() if 
is_https: @@ -151,14 +168,15 @@ class HttpConn(object): self.s = ctx.wrap_socket(self.s, server_side=True) msg = [ "\033[1;3{:d}m{}".format(c, s) - for c, s in zip([0, 5, 0], self.s.cipher()) + for c, s in zip([0, 5, 0], self.s.cipher()) # type: ignore ] self.log(" ".join(msg) + "\033[0m") if self.args.ssl_dbg and hasattr(self.s, "shared_ciphers"): - overlap = [y[::-1] for y in self.s.shared_ciphers()] - lines = [str(x) for x in (["TLS cipher overlap:"] + overlap)] - self.log("\n".join(lines)) + ciphers = self.s.shared_ciphers() + assert ciphers + overlap = [str(y[::-1]) for y in ciphers] + self.log("TLS cipher overlap:" + "\n".join(overlap)) for k, v in [ ["compression", self.s.compression()], ["ALPN proto", self.s.selected_alpn_protocol()], @@ -183,7 +201,7 @@ class HttpConn(object): return if not self.sr: - self.sr = Unrecv(self.s, self.log) + self.sr = Util.Unrecv(self.s, self.log) while not self.stopping: self.nreq += 1 diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py index 04d2146c..fd449b5f 100644 --- a/copyparty/httpsrv.py +++ b/copyparty/httpsrv.py @@ -1,13 +1,15 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import sys -import time -import math import base64 +import math +import os import socket +import sys import threading +import time + +import queue try: import jinja2 @@ -26,15 +28,18 @@ except ImportError: ) sys.exit(1) -from .__init__ import E, PY2, MACOS -from .util import FHC, spack, min_ex, start_stackmon, start_log_thrs +from .__init__ import MACOS, TYPE_CHECKING, E from .bos import bos from .httpconn import HttpConn +from .util import FHC, min_ex, spack, start_log_thrs, start_stackmon -if PY2: - import Queue as queue -else: - import queue +if TYPE_CHECKING: + from .broker_util import BrokerCli + +try: + from typing import Any, Optional +except: + pass class HttpSrv(object): @@ -43,7 +48,7 @@ class HttpSrv(object): relying on MpSrv for performance (HttpSrv is just plain threads) """ - def 
__init__(self, broker, nid): + def __init__(self, broker: "BrokerCli", nid: Optional[int]) -> None: self.broker = broker self.nid = nid self.args = broker.args @@ -58,17 +63,19 @@ class HttpSrv(object): self.tp_nthr = 0 # actual self.tp_ncli = 0 # fading - self.tp_time = None # latest worker collect - self.tp_q = None if self.args.no_htp else queue.LifoQueue() - self.t_periodic = None + self.tp_time = 0.0 # latest worker collect + self.tp_q: Optional[queue.LifoQueue[Any]] = ( + None if self.args.no_htp else queue.LifoQueue() + ) + self.t_periodic: Optional[threading.Thread] = None self.u2fh = FHC() - self.srvs = [] + self.srvs: list[socket.socket] = [] self.ncli = 0 # exact - self.clients = {} # laggy + self.clients: set[HttpConn] = set() # laggy self.nclimax = 0 - self.cb_ts = 0 - self.cb_v = 0 + self.cb_ts = 0.0 + self.cb_v = "" env = jinja2.Environment() env.loader = jinja2.FileSystemLoader(os.path.join(E.mod, "web")) @@ -82,7 +89,7 @@ class HttpSrv(object): if bos.path.exists(cert_path): self.cert_path = cert_path else: - self.cert_path = None + self.cert_path = "" if self.tp_q: self.start_threads(4) @@ -94,19 +101,19 @@ class HttpSrv(object): if self.args.log_thrs: start_log_thrs(self.log, self.args.log_thrs, nid) - self.th_cfg = {} # type: dict[str, Any] + self.th_cfg: dict[str, Any] = {} t = threading.Thread(target=self.post_init) t.daemon = True t.start() - def post_init(self): + def post_init(self) -> None: try: - x = self.broker.put(True, "thumbsrv.getcfg") + x = self.broker.ask("thumbsrv.getcfg") self.th_cfg = x.get() except: pass - def start_threads(self, n): + def start_threads(self, n: int) -> None: self.tp_nthr += n if self.args.log_htp: self.log(self.name, "workers += {} = {}".format(n, self.tp_nthr), 6) @@ -119,15 +126,16 @@ class HttpSrv(object): thr.daemon = True thr.start() - def stop_threads(self, n): + def stop_threads(self, n: int) -> None: self.tp_nthr -= n if self.args.log_htp: self.log(self.name, "workers -= {} = {}".format(n, 
self.tp_nthr), 6) + assert self.tp_q for _ in range(n): self.tp_q.put(None) - def periodic(self): + def periodic(self) -> None: while True: time.sleep(2 if self.tp_ncli or self.ncli else 10) with self.mutex: @@ -141,7 +149,7 @@ class HttpSrv(object): self.t_periodic = None return - def listen(self, sck, nlisteners): + def listen(self, sck: socket.socket, nlisteners: int) -> None: ip, port = sck.getsockname() self.srvs.append(sck) self.nclimax = math.ceil(self.args.nc * 1.0 / nlisteners) @@ -153,15 +161,15 @@ class HttpSrv(object): t.daemon = True t.start() - def thr_listen(self, srv_sck): + def thr_listen(self, srv_sck: socket.socket) -> None: """listens on a shared tcp server""" ip, port = srv_sck.getsockname() fno = srv_sck.fileno() msg = "subscribed @ {}:{} f{}".format(ip, port, fno) self.log(self.name, msg) - def fun(): - self.broker.put(False, "cb_httpsrv_up") + def fun() -> None: + self.broker.say("cb_httpsrv_up") threading.Thread(target=fun).start() @@ -185,21 +193,21 @@ class HttpSrv(object): continue if self.args.log_conn: - m = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format( + t = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format( "-" * 3, ip, port % 8, port ) - self.log("%s %s" % addr, m, c="1;30") + self.log("%s %s" % addr, t, c="1;30") self.accept(sck, addr) - def accept(self, sck, addr): + def accept(self, sck: socket.socket, addr: tuple[str, int]) -> None: """takes an incoming tcp connection and creates a thread to handle it""" now = time.time() if now - (self.tp_time or now) > 300: - m = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}" - self.log(self.name, m.format(self.tp_time, now, self.tp_nthr, self.ncli), 1) - self.tp_time = None + t = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}" + self.log(self.name, t.format(self.tp_time, now, self.tp_nthr, self.ncli), 1) + self.tp_time = 0 self.tp_q = None with self.mutex: @@ -209,10 +217,10 @@ class HttpSrv(object): if self.nid: name += "-{}".format(self.nid) - t = 
threading.Thread(target=self.periodic, name=name) - self.t_periodic = t - t.daemon = True - t.start() + thr = threading.Thread(target=self.periodic, name=name) + self.t_periodic = thr + thr.daemon = True + thr.start() if self.tp_q: self.tp_time = self.tp_time or now @@ -224,8 +232,8 @@ class HttpSrv(object): return if not self.args.no_htp: - m = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n" - self.log(self.name, m, 1) + t = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n" + self.log(self.name, t, 1) thr = threading.Thread( target=self.thr_client, @@ -235,14 +243,15 @@ class HttpSrv(object): thr.daemon = True thr.start() - def thr_poolw(self): + def thr_poolw(self) -> None: + assert self.tp_q while True: task = self.tp_q.get() if not task: break with self.mutex: - self.tp_time = None + self.tp_time = 0 try: sck, addr = task @@ -255,7 +264,7 @@ class HttpSrv(object): except: self.log(self.name, "thr_client: " + min_ex(), 3) - def shutdown(self): + def shutdown(self) -> None: self.stopping = True for srv in self.srvs: try: @@ -263,7 +272,7 @@ class HttpSrv(object): except: pass - clients = list(self.clients.keys()) + clients = list(self.clients) for cli in clients: try: cli.shutdown() @@ -279,13 +288,13 @@ class HttpSrv(object): self.log(self.name, "ok bye") - def thr_client(self, sck, addr): + def thr_client(self, sck: socket.socket, addr: tuple[str, int]) -> None: """thread managing one tcp client""" sck.settimeout(120) cli = HttpConn(sck, addr, self) with self.mutex: - self.clients[cli] = 0 + self.clients.add(cli) fno = sck.fileno() try: @@ -328,10 +337,10 @@ class HttpSrv(object): raise finally: with self.mutex: - del self.clients[cli] + self.clients.remove(cli) self.ncli -= 1 - def cachebuster(self): + def cachebuster(self) -> str: if time.time() - self.cb_ts < 1: 
return self.cb_v diff --git a/copyparty/ico.py b/copyparty/ico.py index 58076c89..f403e4b5 100644 --- a/copyparty/ico.py +++ b/copyparty/ico.py @@ -1,28 +1,28 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import hashlib +import argparse # typechk import colorsys +import hashlib from .__init__ import PY2 class Ico(object): - def __init__(self, args): + def __init__(self, args: argparse.Namespace) -> None: self.args = args - def get(self, ext, as_thumb): + def get(self, ext: str, as_thumb: bool) -> tuple[str, bytes]: """placeholder to make thumbnails not break""" - h = hashlib.md5(ext.encode("utf-8")).digest()[:2] + zb = hashlib.md5(ext.encode("utf-8")).digest()[:2] if PY2: - h = [ord(x) for x in h] + zb = [ord(x) for x in zb] - c1 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 0.3) - c2 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 1) - c = list(c1) + list(c2) - c = [int(x * 255) for x in c] - c = "".join(["{:02x}".format(x) for x in c]) + c1 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 0.3) + c2 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 1) + ci = [int(x * 255) for x in list(c1) + list(c2)] + c = "".join(["{:02x}".format(x) for x in ci]) h = 30 if not self.args.th_no_crop and as_thumb: @@ -37,6 +37,6 @@ class Ico(object): fill="#{}" font-family="monospace" font-size="14px" style="letter-spacing:.5px">{} """ - svg = svg.format(h, c[:6], c[6:], ext).encode("utf-8") + svg = svg.format(h, c[:6], c[6:], ext) - return ["image/svg+xml", svg] + return "image/svg+xml", svg.encode("utf-8") diff --git a/copyparty/mtag.py b/copyparty/mtag.py index 60a7584f..dd1e2ec3 100644 --- a/copyparty/mtag.py +++ b/copyparty/mtag.py @@ -1,18 +1,26 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import sys +import argparse import json +import os import shutil import subprocess as sp +import sys from .__init__ import PY2, WINDOWS, unicode -from .util import fsenc, uncyg, runcmd, retchk, REKOBO_LKEY from .bos import bos +from .util import 
REKOBO_LKEY, fsenc, retchk, runcmd, uncyg + +try: + from typing import Any, Union + + from .util import RootLogger +except: + pass -def have_ff(cmd): +def have_ff(cmd: str) -> bool: if PY2: print("# checking {}".format(cmd)) cmd = (cmd + " -version").encode("ascii").split(b" ") @@ -30,7 +38,7 @@ HAVE_FFPROBE = have_ff("ffprobe") class MParser(object): - def __init__(self, cmdline): + def __init__(self, cmdline: str) -> None: self.tag, args = cmdline.split("=", 1) self.tags = self.tag.split(",") @@ -73,7 +81,9 @@ class MParser(object): raise Exception() -def ffprobe(abspath, timeout=10): +def ffprobe( + abspath: str, timeout: int = 10 +) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]: cmd = [ b"ffprobe", b"-hide_banner", @@ -87,15 +97,15 @@ def ffprobe(abspath, timeout=10): return parse_ffprobe(so) -def parse_ffprobe(txt): +def parse_ffprobe(txt: str) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]: """ffprobe -show_format -show_streams""" streams = [] fmt = {} g = {} for ln in [x.rstrip("\r") for x in txt.split("\n")]: try: - k, v = ln.split("=", 1) - g[k] = v + sk, sv = ln.split("=", 1) + g[sk] = sv continue except: pass @@ -109,8 +119,8 @@ def parse_ffprobe(txt): fmt = g streams = [fmt] + streams - ret = {} # processed - md = {} # raw tags + ret: dict[str, Any] = {} # processed + md: dict[str, list[Any]] = {} # raw tags is_audio = fmt.get("format_name") in ["mp3", "ogg", "flac", "wav"] if fmt.get("filename", "").split(".")[-1].lower() in ["m4a", "aac"]: @@ -161,43 +171,43 @@ def parse_ffprobe(txt): kvm = [["duration", ".dur"], ["bit_rate", ".q"]] for sk, rk in kvm: - v = strm.get(sk) - if v is None: + v1 = strm.get(sk) + if v1 is None: continue if rk.startswith("."): try: - v = float(v) + zf = float(v1) v2 = ret.get(rk) - if v2 is None or v > v2: - ret[rk] = v + if v2 is None or zf > v2: + ret[rk] = zf except: # sqlite doesnt care but the code below does - if v not in ["N/A"]: - ret[rk] = v + if v1 not in ["N/A"]: + ret[rk] = v1 else: - 
ret[rk] = v + ret[rk] = v1 if ret.get("vc") == "ansi": # shellscript return {}, {} for strm in streams: - for k, v in strm.items(): - if not k.startswith("TAG:"): + for sk, sv in strm.items(): + if not sk.startswith("TAG:"): continue - k = k[4:].strip() - v = v.strip() - if k and v and k not in md: - md[k] = [v] + sk = sk[4:].strip() + sv = sv.strip() + if sk and sv and sk not in md: + md[sk] = [sv] - for k in [".q", ".vq", ".aq"]: - if k in ret: - ret[k] /= 1000 # bit_rate=320000 + for sk in [".q", ".vq", ".aq"]: + if sk in ret: + ret[sk] /= 1000 # bit_rate=320000 - for k in [".q", ".vq", ".aq", ".resw", ".resh"]: - if k in ret: - ret[k] = int(ret[k]) + for sk in [".q", ".vq", ".aq", ".resw", ".resh"]: + if sk in ret: + ret[sk] = int(ret[sk]) if ".fps" in ret: fps = ret[".fps"] @@ -219,13 +229,13 @@ def parse_ffprobe(txt): if ".resw" in ret and ".resh" in ret: ret["res"] = "{}x{}".format(ret[".resw"], ret[".resh"]) - ret = {k: [0, v] for k, v in ret.items()} + zd = {k: (0, v) for k, v in ret.items()} - return ret, md + return zd, md class MTag(object): - def __init__(self, log_func, args): + def __init__(self, log_func: RootLogger, args: argparse.Namespace) -> None: self.log_func = log_func self.args = args self.usable = True @@ -242,7 +252,7 @@ class MTag(object): if self.backend == "mutagen": self.get = self.get_mutagen try: - import mutagen + import mutagen # noqa: F401 # pylint: disable=unused-import,import-outside-toplevel except: self.log("could not load Mutagen, trying FFprobe instead", c=3) self.backend = "ffprobe" @@ -339,31 +349,33 @@ class MTag(object): } # self.get = self.compare - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("mtag", msg, c) - def normalize_tags(self, ret, md): - for k, v in dict(md).items(): - if not v: + def normalize_tags( + self, parser_output: dict[str, tuple[int, Any]], md: dict[str, list[Any]] + ) -> dict[str, Union[str, float]]: + for sk, tv in dict(md).items(): + if not tv: 
continue - k = k.lower().split("::")[0].strip() - mk = self.rmap.get(k) - if not mk: + sk = sk.lower().split("::")[0].strip() + key_mapping = self.rmap.get(sk) + if not key_mapping: continue - pref, mk = mk - if mk not in ret or ret[mk][0] > pref: - ret[mk] = [pref, v[0]] + priority, alias = key_mapping + if alias not in parser_output or parser_output[alias][0] > priority: + parser_output[alias] = (priority, tv[0]) - # take first value - ret = {k: unicode(v[1]).strip() for k, v in ret.items()} + # take first value (lowest priority / most preferred) + ret = {sk: unicode(tv[1]).strip() for sk, tv in parser_output.items()} # track 3/7 => track 3 - for k, v in ret.items(): - if k[0] == ".": - v = v.split("/")[0].strip().lstrip("0") - ret[k] = v or 0 + for sk, tv in ret.items(): + if sk[0] == ".": + sv = str(tv).split("/")[0].strip().lstrip("0") + ret[sk] = sv or 0 # normalize key notation to rkeobo okey = ret.get("key") @@ -373,7 +385,7 @@ class MTag(object): return ret - def compare(self, abspath): + def compare(self, abspath: str) -> dict[str, Union[str, float]]: if abspath.endswith(".au"): return {} @@ -411,7 +423,7 @@ class MTag(object): return r1 - def get_mutagen(self, abspath): + def get_mutagen(self, abspath: str) -> dict[str, Union[str, float]]: if not bos.path.isfile(abspath): return {} @@ -425,7 +437,7 @@ class MTag(object): return self.get_ffprobe(abspath) if self.can_ffprobe else {} sz = bos.path.getsize(abspath) - ret = {".q": [0, int((sz / md.info.length) / 128)]} + ret = {".q": (0, int((sz / md.info.length) / 128))} for attr, k, norm in [ ["codec", "ac", unicode], @@ -456,24 +468,24 @@ class MTag(object): if k == "ac" and v.startswith("mp4a.40."): v = "aac" - ret[k] = [0, norm(v)] + ret[k] = (0, norm(v)) return self.normalize_tags(ret, md) - def get_ffprobe(self, abspath): + def get_ffprobe(self, abspath: str) -> dict[str, Union[str, float]]: if not bos.path.isfile(abspath): return {} ret, md = ffprobe(abspath) return self.normalize_tags(ret, md) - def 
get_bin(self, parsers, abspath): + def get_bin(self, parsers: dict[str, MParser], abspath: str) -> dict[str, Any]: if not bos.path.isfile(abspath): return {} pypath = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - pypath = [str(pypath)] + [str(x) for x in sys.path if x] - pypath = str(os.pathsep.join(pypath)) + zsl = [str(pypath)] + [str(x) for x in sys.path if x] + pypath = str(os.pathsep.join(zsl)) env = os.environ.copy() env["PYTHONPATH"] = pypath @@ -491,9 +503,9 @@ class MTag(object): else: cmd = ["nice"] + cmd - cmd = [fsenc(x) for x in cmd] - rc, v, err = runcmd(cmd, **args) - retchk(rc, cmd, err, self.log, 5, self.args.mtag_v) + bcmd = [fsenc(x) for x in cmd] + rc, v, err = runcmd(bcmd, **args) # type: ignore + retchk(rc, bcmd, err, self.log, 5, self.args.mtag_v) v = v.strip() if not v: continue @@ -501,10 +513,10 @@ class MTag(object): if "," not in tagname: ret[tagname] = v else: - v = json.loads(v) + zj = json.loads(v) for tag in tagname.split(","): - if tag and tag in v: - ret[tag] = v[tag] + if tag and tag in zj: + ret[tag] = zj[tag] except: pass diff --git a/copyparty/star.py b/copyparty/star.py index 804ad724..21c2703c 100644 --- a/copyparty/star.py +++ b/copyparty/star.py @@ -4,20 +4,29 @@ from __future__ import print_function, unicode_literals import tarfile import threading -from .sutil import errdesc -from .util import Queue, fsenc, min_ex +from queue import Queue + from .bos import bos +from .sutil import StreamArc, errdesc +from .util import fsenc, min_ex + +try: + from typing import Any, Generator, Optional + + from .util import NamedLogger +except: + pass -class QFile(object): +class QFile(object): # inherit io.StringIO for painful typing """file-like object which buffers writes into a queue""" - def __init__(self): - self.q = Queue(64) - self.bq = [] + def __init__(self) -> None: + self.q: Queue[Optional[bytes]] = Queue(64) + self.bq: list[bytes] = [] self.nq = 0 - def write(self, buf): + def write(self, buf: Optional[bytes]) 
-> None: if buf is None or self.nq >= 240 * 1024: self.q.put(b"".join(self.bq)) self.bq = [] @@ -30,27 +39,32 @@ class QFile(object): self.nq += len(buf) -class StreamTar(object): +class StreamTar(StreamArc): """construct in-memory tar file from the given path""" - def __init__(self, log, fgen, **kwargs): + def __init__( + self, + log: NamedLogger, + fgen: Generator[dict[str, Any], None, None], + **kwargs: Any + ): + super(StreamTar, self).__init__(log, fgen) + self.ci = 0 self.co = 0 self.qfile = QFile() - self.log = log - self.fgen = fgen - self.errf = None + self.errf: dict[str, Any] = {} # python 3.8 changed to PAX_FORMAT as default, # waste of space and don't care about the new features fmt = tarfile.GNU_FORMAT - self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt) + self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt) # type: ignore w = threading.Thread(target=self._gen, name="star-gen") w.daemon = True w.start() - def gen(self): + def gen(self) -> Generator[Optional[bytes], None, None]: while True: buf = self.qfile.q.get() if not buf: @@ -63,7 +77,7 @@ class StreamTar(object): if self.errf: bos.unlink(self.errf["ap"]) - def ser(self, f): + def ser(self, f: dict[str, Any]) -> None: name = f["vp"] src = f["ap"] fsi = f["st"] @@ -76,21 +90,21 @@ class StreamTar(object): inf.gid = 0 self.ci += inf.size - with open(fsenc(src), "rb", 512 * 1024) as f: - self.tar.addfile(inf, f) + with open(fsenc(src), "rb", 512 * 1024) as fo: + self.tar.addfile(inf, fo) - def _gen(self): + def _gen(self) -> None: errors = [] for f in self.fgen: if "err" in f: - errors.append([f["vp"], f["err"]]) + errors.append((f["vp"], f["err"])) continue try: self.ser(f) except: ex = min_ex(5, True).replace("\n", "\n-- ") - errors.append([f["vp"], ex]) + errors.append((f["vp"], ex)) if errors: self.errf, txt = errdesc(errors) diff --git a/copyparty/stolen/surrogateescape.py b/copyparty/stolen/surrogateescape.py index 4b06ed28..b1ff8886 100644 --- 
a/copyparty/stolen/surrogateescape.py +++ b/copyparty/stolen/surrogateescape.py @@ -12,23 +12,28 @@ Original source: misc/python/surrogateescape.py in https://bitbucket.org/haypo/m # This code is released under the Python license and the BSD 2-clause license -import platform import codecs +import platform import sys PY3 = sys.version_info[0] > 2 WINDOWS = platform.system() == "Windows" FS_ERRORS = "surrogateescape" +try: + from typing import Any +except: + pass -def u(text): + +def u(text: Any) -> str: if PY3: return text else: return text.decode("unicode_escape") -def b(data): +def b(data: Any) -> bytes: if PY3: return data.encode("latin1") else: @@ -43,7 +48,7 @@ else: bytes_chr = chr -def surrogateescape_handler(exc): +def surrogateescape_handler(exc: Any) -> tuple[str, int]: """ Pure Python implementation of the PEP 383: the "surrogateescape" error handler of Python 3. Undecodable bytes will be replaced by a Unicode @@ -74,7 +79,7 @@ class NotASurrogateError(Exception): pass -def replace_surrogate_encode(mystring): +def replace_surrogate_encode(mystring: str) -> str: """ Returns a (unicode) string, not the more logical bytes, because the codecs register_error functionality expects this. 
@@ -100,7 +105,7 @@ def replace_surrogate_encode(mystring): return str().join(decoded) -def replace_surrogate_decode(mybytes): +def replace_surrogate_decode(mybytes: bytes) -> str: """ Returns a (unicode) string """ @@ -121,7 +126,7 @@ def replace_surrogate_decode(mybytes): return str().join(decoded) -def encodefilename(fn): +def encodefilename(fn: str) -> bytes: if FS_ENCODING == "ascii": # ASCII encoder of Python 2 expects that the error handler returns a # Unicode string encodable to ASCII, whereas our surrogateescape error @@ -161,7 +166,7 @@ def encodefilename(fn): return fn.encode(FS_ENCODING, FS_ERRORS) -def decodefilename(fn): +def decodefilename(fn: bytes) -> str: return fn.decode(FS_ENCODING, FS_ERRORS) @@ -181,7 +186,7 @@ if WINDOWS and not PY3: FS_ENCODING = codecs.lookup(FS_ENCODING).name -def register_surrogateescape(): +def register_surrogateescape() -> None: """ Registers the surrogateescape error handler on Python 2 (only) """ diff --git a/copyparty/sutil.py b/copyparty/sutil.py index 22de0066..506e389f 100644 --- a/copyparty/sutil.py +++ b/copyparty/sutil.py @@ -6,8 +6,29 @@ from datetime import datetime from .bos import bos +try: + from typing import Any, Generator, Optional -def errdesc(errors): + from .util import NamedLogger +except: + pass + + +class StreamArc(object): + def __init__( + self, + log: NamedLogger, + fgen: Generator[dict[str, Any], None, None], + **kwargs: Any + ): + self.log = log + self.fgen = fgen + + def gen(self) -> Generator[Optional[bytes], None, None]: + pass + + +def errdesc(errors: list[tuple[str, str]]) -> tuple[dict[str, Any], list[str]]: report = ["copyparty failed to add the following files to the archive:", ""] for fn, err in errors: diff --git a/copyparty/svchub.py b/copyparty/svchub.py index 44680513..e9b5aadf 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -1,41 +1,51 @@ # coding: utf-8 from __future__ import print_function, unicode_literals +import argparse +import calendar import os -import sys 
-import time import shlex -import string import signal import socket +import string +import sys import threading +import time from datetime import datetime, timedelta -import calendar -from .__init__ import E, PY2, WINDOWS, ANYWIN, MACOS, VT100, unicode -from .util import mp, start_log_thrs, start_stackmon, min_ex, ansi_re +try: + from types import FrameType + + import typing + from typing import Optional, Union +except: + pass + +from .__init__ import ANYWIN, MACOS, PY2, VT100, WINDOWS, E, unicode from .authsrv import AuthSrv -from .tcpsrv import TcpSrv -from .up2k import Up2k -from .th_srv import ThumbSrv, HAVE_PIL, HAVE_VIPS, HAVE_WEBP from .mtag import HAVE_FFMPEG, HAVE_FFPROBE +from .tcpsrv import TcpSrv +from .th_srv import HAVE_PIL, HAVE_VIPS, HAVE_WEBP, ThumbSrv +from .up2k import Up2k +from .util import ansi_re, min_ex, mp, start_log_thrs, start_stackmon class SvcHub(object): """ Hosts all services which cannot be parallelized due to reliance on monolithic resources. Creates a Broker which does most of the heavy stuff; hosted services can use this to perform work: - hub.broker.put(want_reply, destination, args_list). + hub.broker.<say|ask>(destination, args_list). Either BrokerThr (plain threads) or BrokerMP (multiprocessing) is used depending on configuration. Nothing is returned synchronously; if you want any value returned from the call, put() can return a queue (if want_reply=True) which has a blocking get() with the response. 
""" - def __init__(self, args, argv, printed): + def __init__(self, args: argparse.Namespace, argv: list[str], printed: str) -> None: self.args = args self.argv = argv - self.logf = None + self.logf: Optional[typing.TextIO] = None + self.logf_base_fn = "" self.stop_req = False self.reload_req = False self.stopping = False @@ -59,16 +69,16 @@ class SvcHub(object): if not args.use_fpool and args.j != 1: args.no_fpool = True - m = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems" - self.log("root", m.format(args.j)) + t = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems" + self.log("root", t.format(args.j)) if not args.no_fpool and args.j != 1: - m = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior" + t = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior" if ANYWIN: - m = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead' + t = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead' args.no_fpool = True - self.log("root", m, c=3) + self.log("root", t, c=3) bri = "zy"[args.theme % 2 :][:1] ch = "abcdefghijklmnopqrstuvwx"[int(args.theme / 2)] @@ -96,8 +106,8 @@ class SvcHub(object): self.args.th_dec = list(decs.keys()) self.thumbsrv = None if not args.no_thumb: - m = "decoder preference: {}".format(", ".join(self.args.th_dec)) - self.log("thumb", m) + t = "decoder preference: {}".format(", ".join(self.args.th_dec)) + self.log("thumb", t) if "pil" in self.args.th_dec and not HAVE_WEBP: msg = "disabling webp 
thumbnails because either libwebp is not available or your Pillow is too old" @@ -131,11 +141,11 @@ class SvcHub(object): if self.check_mp_enable(): from .broker_mp import BrokerMp as Broker else: - from .broker_thr import BrokerThr as Broker + from .broker_thr import BrokerThr as Broker # type: ignore self.broker = Broker(self) - def thr_httpsrv_up(self): + def thr_httpsrv_up(self) -> None: time.sleep(1 if self.args.ign_ebind_all else 5) expected = self.broker.num_workers * self.tcpsrv.nsrv failed = expected - self.httpsrv_up @@ -145,20 +155,20 @@ class SvcHub(object): if self.args.ign_ebind_all: if not self.tcpsrv.srv: for _ in range(self.broker.num_workers): - self.broker.put(False, "cb_httpsrv_up") + self.broker.say("cb_httpsrv_up") return if self.args.ign_ebind and self.tcpsrv.srv: return - m = "{}/{} workers failed to start" - m = m.format(failed, expected) - self.log("root", m, 1) + t = "{}/{} workers failed to start" + t = t.format(failed, expected) + self.log("root", t, 1) self.retcode = 1 os.kill(os.getpid(), signal.SIGTERM) - def cb_httpsrv_up(self): + def cb_httpsrv_up(self) -> None: self.httpsrv_up += 1 if self.httpsrv_up != self.broker.num_workers: return @@ -171,9 +181,9 @@ class SvcHub(object): thr.daemon = True thr.start() - def _logname(self): + def _logname(self) -> str: dt = datetime.utcnow() - fn = self.args.lo + fn = str(self.args.lo) for fs in "YmdHMS": fs = "%" + fs if fs in fn: @@ -181,7 +191,7 @@ class SvcHub(object): return fn - def _setup_logfile(self, printed): + def _setup_logfile(self, printed: str) -> None: base_fn = fn = sel_fn = self._logname() if fn != self.args.lo: ctr = 0 @@ -203,8 +213,6 @@ class SvcHub(object): lh = codecs.open(fn, "w", encoding="utf-8", errors="replace") - lh.base_fn = base_fn - argv = [sys.executable] + self.argv if hasattr(shlex, "quote"): argv = [shlex.quote(x) for x in argv] @@ -215,9 +223,10 @@ class SvcHub(object): printed += msg lh.write("t0: {:.3f}\nargv: {}\n\n{}".format(E.t0, " ".join(argv), 
printed)) self.logf = lh + self.logf_base_fn = base_fn print(msg, end="") - def run(self): + def run(self) -> None: self.tcpsrv.run() thr = threading.Thread(target=self.thr_httpsrv_up) @@ -252,7 +261,7 @@ class SvcHub(object): else: self.stop_thr() - def reload(self): + def reload(self) -> str: if self.reloading: return "cannot reload; already in progress" @@ -262,7 +271,7 @@ class SvcHub(object): t.start() return "reload initiated" - def _reload(self): + def _reload(self) -> None: self.log("root", "reload scheduled") with self.up2k.mutex: self.asrv.reload() @@ -271,7 +280,7 @@ class SvcHub(object): self.reloading = False - def stop_thr(self): + def stop_thr(self) -> None: while not self.stop_req: with self.stop_cond: self.stop_cond.wait(9001) @@ -282,7 +291,7 @@ class SvcHub(object): self.shutdown() - def signal_handler(self, sig, frame): + def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None: if self.stopping: return @@ -294,7 +303,7 @@ class SvcHub(object): with self.stop_cond: self.stop_cond.notify_all() - def shutdown(self): + def shutdown(self) -> None: if self.stopping: return @@ -337,7 +346,7 @@ class SvcHub(object): sys.exit(ret) - def _log_disabled(self, src, msg, c=0): + def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: if not self.logf: return @@ -349,8 +358,8 @@ class SvcHub(object): if now >= self.next_day: self._set_next_day() - def _set_next_day(self): - if self.next_day and self.logf and self.logf.base_fn != self._logname(): + def _set_next_day(self) -> None: + if self.next_day and self.logf and self.logf_base_fn != self._logname(): self.logf.close() self._setup_logfile("") @@ -364,7 +373,7 @@ class SvcHub(object): dt = dt.replace(hour=0, minute=0, second=0) self.next_day = calendar.timegm(dt.utctimetuple()) - def _log_enabled(self, src, msg, c=0): + def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: """handles logging from all components""" with self.log_mutex: now = time.time() 
@@ -401,7 +410,7 @@ class SvcHub(object): if self.logf: self.logf.write(msg) - def check_mp_support(self): + def check_mp_support(self) -> str: vmin = sys.version_info[1] if WINDOWS: msg = "need python 3.3 or newer for multiprocessing;" @@ -415,16 +424,16 @@ class SvcHub(object): return msg try: - x = mp.Queue(1) - x.put(["foo", "bar"]) + x: mp.Queue[tuple[str, str]] = mp.Queue(1) + x.put(("foo", "bar")) if x.get()[0] != "foo": raise Exception() except: return "multiprocessing is not supported on your platform;" - return None + return "" - def check_mp_enable(self): + def check_mp_enable(self) -> bool: if self.args.j == 1: return False @@ -447,18 +456,18 @@ class SvcHub(object): self.log("svchub", "cannot efficiently use multiple CPU cores") return False - def sd_notify(self): + def sd_notify(self) -> None: try: - addr = os.getenv("NOTIFY_SOCKET") - if not addr: + zb = os.getenv("NOTIFY_SOCKET") + if not zb: return - addr = unicode(addr) + addr = unicode(zb) if addr.startswith("@"): addr = "\0" + addr[1:] - m = "".join(x for x in addr if x in string.printable) - self.log("sd_notify", m) + t = "".join(x for x in addr if x in string.printable) + self.log("sd_notify", t) sck = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) sck.connect(addr) diff --git a/copyparty/szip.py b/copyparty/szip.py index 177ff3ea..178f8a61 100644 --- a/copyparty/szip.py +++ b/copyparty/szip.py @@ -1,16 +1,23 @@ # coding: utf-8 from __future__ import print_function, unicode_literals +import calendar import time import zlib -import calendar -from .sutil import errdesc -from .util import yieldfile, sanitize_fn, spack, sunpack, min_ex from .bos import bos +from .sutil import StreamArc, errdesc +from .util import min_ex, sanitize_fn, spack, sunpack, yieldfile + +try: + from typing import Any, Generator, Optional + + from .util import NamedLogger +except: + pass -def dostime2unix(buf): +def dostime2unix(buf: bytes) -> int: t, d = sunpack(b" bytes: tt = time.gmtime(ts + 1) dy, dm, dd, th, tm, ts = 
list(tt)[:6] @@ -41,14 +48,22 @@ def unixtime2dos(ts): return b"\x00\x00\x21\x00" -def gen_fdesc(sz, crc32, z64): +def gen_fdesc(sz: int, crc32: int, z64: bool) -> bytes: ret = b"\x50\x4b\x07\x08" fmt = b" bytes: """ does regular file headers and the central directory meme if h_pos is set @@ -67,8 +82,8 @@ def gen_hdr(h_pos, fn, sz, lastmod, utf8, crc32, pre_crc): # confusingly this doesn't bump if h_pos req_ver = b"\x2d\x00" if z64 else b"\x0a\x00" - if crc32: - crc32 = spack(b" tuple[bytes, bool]: """ summary of all file headers, usually the zipfile footer unless something clamps @@ -154,10 +171,12 @@ def gen_ecdr(items, cdir_pos, cdir_end): # 2b comment length ret += b"\x00\x00" - return [ret, need_64] + return ret, need_64 -def gen_ecdr64(items, cdir_pos, cdir_end): +def gen_ecdr64( + items: list[tuple[str, int, int, int, int]], cdir_pos: int, cdir_end: int +) -> bytes: """ z64 end of central directory added when numfiles or a headerptr clamps @@ -181,7 +200,7 @@ def gen_ecdr64(items, cdir_pos, cdir_end): return ret -def gen_ecdr64_loc(ecdr64_pos): +def gen_ecdr64_loc(ecdr64_pos: int) -> bytes: """ z64 end of central directory locator points to ecdr64 @@ -196,21 +215,27 @@ def gen_ecdr64_loc(ecdr64_pos): return ret -class StreamZip(object): - def __init__(self, log, fgen, utf8=False, pre_crc=False): - self.log = log - self.fgen = fgen +class StreamZip(StreamArc): + def __init__( + self, + log: NamedLogger, + fgen: Generator[dict[str, Any], None, None], + utf8: bool = False, + pre_crc: bool = False, + ) -> None: + super(StreamZip, self).__init__(log, fgen) + self.utf8 = utf8 self.pre_crc = pre_crc self.pos = 0 - self.items = [] + self.items: list[tuple[str, int, int, int, int]] = [] - def _ct(self, buf): + def _ct(self, buf: bytes) -> bytes: self.pos += len(buf) return buf - def ser(self, f): + def ser(self, f: dict[str, Any]) -> Generator[bytes, None, None]: name = f["vp"] src = f["ap"] st = f["st"] @@ -218,9 +243,8 @@ class StreamZip(object): sz = st.st_size 
ts = st.st_mtime - crc = None + crc = 0 if self.pre_crc: - crc = 0 for buf in yieldfile(src): crc = zlib.crc32(buf, crc) @@ -230,7 +254,6 @@ class StreamZip(object): buf = gen_hdr(None, name, sz, ts, self.utf8, crc, self.pre_crc) yield self._ct(buf) - crc = crc or 0 for buf in yieldfile(src): if not self.pre_crc: crc = zlib.crc32(buf, crc) @@ -239,7 +262,7 @@ class StreamZip(object): crc &= 0xFFFFFFFF - self.items.append([name, sz, ts, crc, h_pos]) + self.items.append((name, sz, ts, crc, h_pos)) z64 = sz >= 4 * 1024 * 1024 * 1024 @@ -247,11 +270,11 @@ class StreamZip(object): buf = gen_fdesc(sz, crc, z64) yield self._ct(buf) - def gen(self): + def gen(self) -> Generator[bytes, None, None]: errors = [] for f in self.fgen: if "err" in f: - errors.append([f["vp"], f["err"]]) + errors.append((f["vp"], f["err"])) continue try: @@ -259,7 +282,7 @@ class StreamZip(object): yield x except: ex = min_ex(5, True).replace("\n", "\n-- ") - errors.append([f["vp"], ex]) + errors.append((f["vp"], ex)) if errors: errf, txt = errdesc(errors) diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py index 7d2cb3e4..ae4c5f86 100644 --- a/copyparty/tcpsrv.py +++ b/copyparty/tcpsrv.py @@ -2,12 +2,15 @@ from __future__ import print_function, unicode_literals import re -import sys import socket +import sys -from .__init__ import MACOS, ANYWIN, unicode +from .__init__ import ANYWIN, MACOS, TYPE_CHECKING, unicode from .util import chkcmd +if TYPE_CHECKING: + from .svchub import SvcHub + class TcpSrv(object): """ @@ -15,16 +18,16 @@ class TcpSrv(object): which then uses the least busy HttpSrv to handle it """ - def __init__(self, hub): + def __init__(self, hub: "SvcHub"): self.hub = hub self.args = hub.args self.log = hub.log self.stopping = False - self.srv = [] + self.srv: list[socket.socket] = [] self.nsrv = 0 - ok = {} + ok: dict[str, list[int]] = {} for ip in self.args.i: ok[ip] = [] for port in self.args.p: @@ -34,8 +37,8 @@ class TcpSrv(object): ok[ip].append(port) except Exception as 
ex: if self.args.ign_ebind or self.args.ign_ebind_all: - m = "could not listen on {}:{}: {}" - self.log("tcpsrv", m.format(ip, port, ex), c=3) + t = "could not listen on {}:{}: {}" + self.log("tcpsrv", t.format(ip, port, ex), c=3) else: raise @@ -55,9 +58,9 @@ class TcpSrv(object): eps[x] = "external" msgs = [] - title_tab = {} + title_tab: dict[str, dict[str, int]] = {} title_vars = [x[1:] for x in self.args.wintitle.split(" ") if x.startswith("$")] - m = "available @ {}://{}:{}/ (\033[33m{}\033[0m)" + t = "available @ {}://{}:{}/ (\033[33m{}\033[0m)" for ip, desc in sorted(eps.items(), key=lambda x: x[1]): for port in sorted(self.args.p): if port not in ok.get(ip, ok.get("0.0.0.0", [])): @@ -69,7 +72,7 @@ class TcpSrv(object): elif self.args.https_only or port == 443: proto = "https" - msgs.append(m.format(proto, ip, port, desc)) + msgs.append(t.format(proto, ip, port, desc)) if not self.args.wintitle: continue @@ -98,13 +101,13 @@ class TcpSrv(object): if msgs: msgs[-1] += "\n" - for m in msgs: - self.log("tcpsrv", m) + for t in msgs: + self.log("tcpsrv", t) if self.args.wintitle: self._set_wintitle(title_tab) - def _listen(self, ip, port): + def _listen(self, ip: str, port: int) -> None: srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -120,7 +123,7 @@ class TcpSrv(object): raise raise Exception(e) - def run(self): + def run(self) -> None: for srv in self.srv: srv.listen(self.args.nc) ip, port = srv.getsockname() @@ -130,9 +133,9 @@ class TcpSrv(object): if self.args.q: print(msg) - self.hub.broker.put(False, "listen", srv) + self.hub.broker.say("listen", srv) - def shutdown(self): + def shutdown(self) -> None: self.stopping = True try: for srv in self.srv: @@ -142,14 +145,14 @@ class TcpSrv(object): self.log("tcpsrv", "ok bye") - def ips_linux_ifconfig(self): + def ips_linux_ifconfig(self) -> dict[str, str]: # for termux try: txt, _ = 
chkcmd(["ifconfig"]) except: return {} - eps = {} + eps: dict[str, str] = {} dev = None ip = None up = None @@ -171,7 +174,7 @@ class TcpSrv(object): return eps - def ips_linux(self): + def ips_linux(self) -> dict[str, str]: try: txt, _ = chkcmd(["ip", "addr"]) except: @@ -180,21 +183,21 @@ class TcpSrv(object): r = re.compile(r"^\s+inet ([^ ]+)/.* (.*)") ri = re.compile(r"^\s*[0-9]+\s*:.*") up = False - eps = {} + eps: dict[str, str] = {} for ln in txt.split("\n"): if ri.match(ln): up = "UP" in re.split("[>,< ]", ln) try: - ip, dev = r.match(ln.rstrip()).groups() + ip, dev = r.match(ln.rstrip()).groups() # type: ignore eps[ip] = dev + ("" if up else ", \033[31mLINK-DOWN") except: pass return eps - def ips_macos(self): - eps = {} + def ips_macos(self) -> dict[str, str]: + eps: dict[str, str] = {} try: txt, _ = chkcmd(["ifconfig"]) except: @@ -202,7 +205,7 @@ class TcpSrv(object): rdev = re.compile(r"^([^ ]+):") rip = re.compile(r"^\tinet ([0-9\.]+) ") - dev = None + dev = "UNKNOWN" for ln in txt.split("\n"): m = rdev.match(ln) if m: @@ -211,17 +214,17 @@ class TcpSrv(object): m = rip.match(ln) if m: eps[m.group(1)] = dev - dev = None + dev = "UNKNOWN" return eps - def ips_windows_ipconfig(self): - eps = {} - offs = {} + def ips_windows_ipconfig(self) -> tuple[dict[str, str], set[str]]: + eps: dict[str, str] = {} + offs: set[str] = set() try: txt, _ = chkcmd(["ipconfig"]) except: - return eps + return eps, offs rdev = re.compile(r"(^[^ ].*):$") rip = re.compile(r"^ +IPv?4? 
[^:]+: *([0-9\.]{7,15})$") @@ -231,12 +234,12 @@ class TcpSrv(object): m = rdev.match(ln) if m: if dev and dev not in eps.values(): - offs[dev] = 1 + offs.add(dev) dev = m.group(1).split(" adapter ", 1)[-1] if dev and roff.match(ln): - offs[dev] = 1 + offs.add(dev) dev = None m = rip.match(ln) @@ -245,12 +248,12 @@ class TcpSrv(object): dev = None if dev and dev not in eps.values(): - offs[dev] = 1 + offs.add(dev) return eps, offs - def ips_windows_netsh(self): - eps = {} + def ips_windows_netsh(self) -> dict[str, str]: + eps: dict[str, str] = {} try: txt, _ = chkcmd("netsh interface ip show address".split()) except: @@ -270,7 +273,7 @@ class TcpSrv(object): return eps - def detect_interfaces(self, listen_ips): + def detect_interfaces(self, listen_ips: list[str]) -> dict[str, str]: if MACOS: eps = self.ips_macos() elif ANYWIN: @@ -317,7 +320,7 @@ class TcpSrv(object): return eps - def _set_wintitle(self, vs): + def _set_wintitle(self, vs: dict[str, dict[str, int]]) -> None: vs["all"] = vs.get("all", {"Local-Only": 1}) vs["pub"] = vs.get("pub", vs["all"]) diff --git a/copyparty/th_cli.py b/copyparty/th_cli.py index e8a18e0e..9eb49a4f 100644 --- a/copyparty/th_cli.py +++ b/copyparty/th_cli.py @@ -3,13 +3,23 @@ from __future__ import print_function, unicode_literals import os -from .util import Cooldown -from .th_srv import thumb_path, HAVE_WEBP +from .__init__ import TYPE_CHECKING +from .authsrv import VFS from .bos import bos +from .th_srv import HAVE_WEBP, thumb_path +from .util import Cooldown + +try: + from typing import Optional, Union +except: + pass + +if TYPE_CHECKING: + from .httpsrv import HttpSrv class ThumbCli(object): - def __init__(self, hsrv): + def __init__(self, hsrv: "HttpSrv") -> None: self.broker = hsrv.broker self.log_func = hsrv.log self.args = hsrv.args @@ -34,10 +44,10 @@ class ThumbCli(object): d = next((x for x in self.args.th_dec if x in ("vips", "pil")), None) self.can_webp = HAVE_WEBP or d == "vips" - def log(self, msg, c=0): + def 
log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("thumbcli", msg, c) - def get(self, dbv, rem, mtime, fmt): + def get(self, dbv: VFS, rem: str, mtime: float, fmt: str) -> Optional[str]: ptop = dbv.realpath ext = rem.rsplit(".")[-1].lower() if ext not in self.thumbable or "dthumb" in dbv.flags: @@ -106,17 +116,17 @@ class ThumbCli(object): if ret: tdir = os.path.dirname(tpath) if self.cooldown.poke(tdir): - self.broker.put(False, "thumbsrv.poke", tdir) + self.broker.say("thumbsrv.poke", tdir) if want_opus: # audio files expire individually if self.cooldown.poke(tpath): - self.broker.put(False, "thumbsrv.poke", tpath) + self.broker.say("thumbsrv.poke", tpath) return ret if abort: return None - x = self.broker.put(True, "thumbsrv.get", ptop, rem, mtime, fmt) - return x.get() + x = self.broker.ask("thumbsrv.get", ptop, rem, mtime, fmt) + return x.get() # type: ignore diff --git a/copyparty/th_srv.py b/copyparty/th_srv.py index dc3965c7..2a9e5f02 100644 --- a/copyparty/th_srv.py +++ b/copyparty/th_srv.py @@ -1,18 +1,28 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import time -import shutil import base64 import hashlib -import threading +import os +import shutil import subprocess as sp +import threading +import time -from .util import fsenc, vsplit, statdir, runcmd, Queue, Cooldown, BytesIO, min_ex +from queue import Queue + +from .__init__ import TYPE_CHECKING from .bos import bos from .mtag import HAVE_FFMPEG, HAVE_FFPROBE, ffprobe +from .util import BytesIO, Cooldown, fsenc, min_ex, runcmd, statdir, vsplit +try: + from typing import Optional, Union +except: + pass + +if TYPE_CHECKING: + from .svchub import SvcHub HAVE_PIL = False HAVE_HEIF = False @@ -20,7 +30,7 @@ HAVE_AVIF = False HAVE_WEBP = False try: - from PIL import Image, ImageOps, ExifTags + from PIL import ExifTags, Image, ImageOps HAVE_PIL = True try: @@ -47,14 +57,13 @@ except: pass try: - import pyvips - HAVE_VIPS = True + import pyvips except: 
HAVE_VIPS = False -def thumb_path(histpath, rem, mtime, fmt): +def thumb_path(histpath: str, rem: str, mtime: float, fmt: str) -> str: # base16 = 16 = 256 # b64-lc = 38 = 1444 # base64 = 64 = 4096 @@ -80,7 +89,7 @@ def thumb_path(histpath, rem, mtime, fmt): class ThumbSrv(object): - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub self.asrv = hub.asrv self.args = hub.args @@ -91,17 +100,17 @@ class ThumbSrv(object): self.poke_cd = Cooldown(self.args.th_poke) self.mutex = threading.Lock() - self.busy = {} + self.busy: dict[str, list[threading.Condition]] = {} self.stopping = False self.nthr = max(1, self.args.th_mt) - self.q = Queue(self.nthr * 4) + self.q: Queue[Optional[tuple[str, str]]] = Queue(self.nthr * 4) for n in range(self.nthr): - t = threading.Thread( + thr = threading.Thread( target=self.worker, name="thumb-{}-{}".format(n, self.nthr) ) - t.daemon = True - t.start() + thr.daemon = True + thr.start() want_ff = not self.args.no_vthumb or not self.args.no_athumb if want_ff and (not HAVE_FFMPEG or not HAVE_FFPROBE): @@ -122,7 +131,7 @@ class ThumbSrv(object): t.start() self.fmt_pil, self.fmt_vips, self.fmt_ffi, self.fmt_ffv, self.fmt_ffa = [ - {x: True for x in y.split(",")} + set(y.split(",")) for y in [ self.args.th_r_pil, self.args.th_r_vips, @@ -134,37 +143,37 @@ class ThumbSrv(object): if not HAVE_HEIF: for f in "heif heifs heic heics".split(" "): - self.fmt_pil.pop(f, None) + self.fmt_pil.discard(f) if not HAVE_AVIF: for f in "avif avifs".split(" "): - self.fmt_pil.pop(f, None) + self.fmt_pil.discard(f) - self.thumbable = {} + self.thumbable: set[str] = set() if "pil" in self.args.th_dec: - self.thumbable.update(self.fmt_pil) + self.thumbable |= self.fmt_pil if "vips" in self.args.th_dec: - self.thumbable.update(self.fmt_vips) + self.thumbable |= self.fmt_vips if "ff" in self.args.th_dec: - for t in [self.fmt_ffi, self.fmt_ffv, self.fmt_ffa]: - self.thumbable.update(t) + for zss in [self.fmt_ffi, self.fmt_ffv, 
self.fmt_ffa]: + self.thumbable |= zss - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("thumb", msg, c) - def shutdown(self): + def shutdown(self) -> None: self.stopping = True for _ in range(self.nthr): self.q.put(None) - def stopped(self): + def stopped(self) -> bool: with self.mutex: return not self.nthr - def get(self, ptop, rem, mtime, fmt): + def get(self, ptop: str, rem: str, mtime: float, fmt: str) -> Optional[str]: histpath = self.asrv.vfs.histtab.get(ptop) if not histpath: self.log("no histpath for [{}]".format(ptop)) @@ -191,7 +200,7 @@ class ThumbSrv(object): do_conv = True if do_conv: - self.q.put([abspath, tpath]) + self.q.put((abspath, tpath)) self.log("conv {} \033[0m{}".format(tpath, abspath), c=6) while not self.stopping: @@ -212,7 +221,7 @@ class ThumbSrv(object): return None - def getcfg(self): + def getcfg(self) -> dict[str, set[str]]: return { "thumbable": self.thumbable, "pil": self.fmt_pil, @@ -222,7 +231,7 @@ class ThumbSrv(object): "ffa": self.fmt_ffa, } - def worker(self): + def worker(self) -> None: while not self.stopping: task = self.q.get() if not task: @@ -253,7 +262,7 @@ class ThumbSrv(object): except: msg = "{} could not create thumbnail of {}\n{}" msg = msg.format(fun.__name__, abspath, min_ex()) - c = 1 if " "Image.Image": # exif_transpose is expensive (loads full image + unconditional copy) r = max(*self.res) * 2 im.thumbnail((r, r), resample=Image.LANCZOS) @@ -295,7 +304,7 @@ class ThumbSrv(object): return im - def conv_pil(self, abspath, tpath): + def conv_pil(self, abspath: str, tpath: str) -> None: with Image.open(fsenc(abspath)) as im: try: im = self.fancy_pillow(im) @@ -324,7 +333,7 @@ class ThumbSrv(object): im.save(tpath, **args) - def conv_vips(self, abspath, tpath): + def conv_vips(self, abspath: str, tpath: str) -> None: crops = ["centre", "none"] if self.args.th_no_crop: crops = ["none"] @@ -342,18 +351,17 @@ class ThumbSrv(object): img.write_to_file(tpath, Q=40) - 
def conv_ffmpeg(self, abspath, tpath): + def conv_ffmpeg(self, abspath: str, tpath: str) -> None: ret, _ = ffprobe(abspath) if not ret: return ext = abspath.rsplit(".")[-1].lower() if ext in ["h264", "h265"] or ext in self.fmt_ffi: - seek = [] + seek: list[bytes] = [] else: dur = ret[".dur"][1] if ".dur" in ret else 4 - seek = "{:.0f}".format(dur / 3) - seek = [b"-ss", seek.encode("utf-8")] + seek = [b"-ss", "{:.0f}".format(dur / 3).encode("utf-8")] scale = "scale={0}:{1}:force_original_aspect_ratio=" if self.args.th_no_crop: @@ -361,7 +369,7 @@ class ThumbSrv(object): else: scale += "increase,crop={0}:{1},setsar=1:1" - scale = scale.format(*list(self.res)).encode("utf-8") + bscale = scale.format(*list(self.res)).encode("utf-8") # fmt: off cmd = [ b"ffmpeg", @@ -373,7 +381,7 @@ class ThumbSrv(object): cmd += [ b"-i", fsenc(abspath), b"-map", b"0:v:0", - b"-vf", scale, + b"-vf", bscale, b"-frames:v", b"1", b"-metadata:s:v:0", b"rotate=0", ] @@ -395,14 +403,14 @@ class ThumbSrv(object): cmd += [fsenc(tpath)] self._run_ff(cmd) - def _run_ff(self, cmd): + def _run_ff(self, cmd: list[bytes]) -> None: # self.log((b" ".join(cmd)).decode("utf-8")) ret, _, serr = runcmd(cmd, timeout=self.args.th_convt) if not ret: return - c = "1;30" - m = "FFmpeg failed (probably a corrupt video file):\n" + c: Union[str, int] = "1;30" + t = "FFmpeg failed (probably a corrupt video file):\n" if cmd[-1].lower().endswith(b".webp") and ( "Error selecting an encoder" in serr or "Automatic encoder selection failed" in serr @@ -410,14 +418,14 @@ class ThumbSrv(object): or "Please choose an encoder manually" in serr ): self.args.th_ff_jpg = True - m = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n" + t = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n" c = 1 if ( "Requested resampling engine is unavailable" in serr or "output pad on Parsed_aresample_" in serr ): - m = "FFmpeg failed because 
it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n" + t = "FFmpeg failed because it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n" c = 1 lines = serr.strip("\n").split("\n") @@ -428,10 +436,10 @@ class ThumbSrv(object): if len(txt) > 5000: txt = txt[:2500] + "...\nff: [...]\nff: ..." + txt[-2500:] - self.log(m + txt, c=c) + self.log(t + txt, c=c) raise sp.CalledProcessError(ret, (cmd[0], b"...", cmd[-1])) - def conv_spec(self, abspath, tpath): + def conv_spec(self, abspath: str, tpath: str) -> None: ret, _ = ffprobe(abspath) if "ac" not in ret: raise Exception("not audio") @@ -473,7 +481,7 @@ class ThumbSrv(object): cmd += [fsenc(tpath)] self._run_ff(cmd) - def conv_opus(self, abspath, tpath): + def conv_opus(self, abspath: str, tpath: str) -> None: if self.args.no_acode: raise Exception("disabled in server config") @@ -521,7 +529,7 @@ class ThumbSrv(object): # fmt: on self._run_ff(cmd) - def poke(self, tdir): + def poke(self, tdir: str) -> None: if not self.poke_cd.poke(tdir): return @@ -533,7 +541,7 @@ class ThumbSrv(object): except: pass - def cleaner(self): + def cleaner(self) -> None: interval = self.args.th_clean while True: time.sleep(interval) @@ -548,14 +556,14 @@ class ThumbSrv(object): self.log("\033[Jcln ok; rm {} dirs".format(ndirs)) - def clean(self, histpath): + def clean(self, histpath: str) -> int: ret = 0 for cat in ["th", "ac"]: - ret += self._clean(histpath, cat, None) + ret += self._clean(histpath, cat, "") return ret - def _clean(self, histpath, cat, thumbpath): + def _clean(self, histpath: str, cat: str, thumbpath: str) -> int: if not thumbpath: thumbpath = os.path.join(histpath, cat) @@ -564,10 +572,10 @@ class ThumbSrv(object): maxage = getattr(self.args, cat + "_maxage") now = time.time() prev_b64 = None - prev_fp = None + prev_fp = "" try: - ents = statdir(self.log, not self.args.no_scandir, False, thumbpath) - ents = sorted(list(ents)) + t1 = statdir(self.log_func, not 
self.args.no_scandir, False, thumbpath) + ents = sorted(list(t1)) except: return 0 diff --git a/copyparty/u2idx.py b/copyparty/u2idx.py index 073ed3d1..8a342042 100644 --- a/copyparty/u2idx.py +++ b/copyparty/u2idx.py @@ -1,34 +1,37 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import time import calendar +import os +import re import threading +import time from operator import itemgetter -from .__init__ import ANYWIN, unicode -from .util import absreal, s3dec, Pebkac, min_ex, gen_filekey, quotep +from .__init__ import ANYWIN, TYPE_CHECKING, unicode from .bos import bos from .up2k import up2k_wark_from_hashlist +from .util import HAVE_SQLITE3, Pebkac, absreal, gen_filekey, min_ex, quotep, s3dec - -try: - HAVE_SQLITE3 = True +if HAVE_SQLITE3: import sqlite3 -except: - HAVE_SQLITE3 = False - try: from pathlib import Path except: pass +try: + from typing import Any, Optional, Union +except: + pass + +if TYPE_CHECKING: + from .httpconn import HttpConn + class U2idx(object): - def __init__(self, conn): + def __init__(self, conn: "HttpConn") -> None: self.log_func = conn.log_func self.asrv = conn.asrv self.args = conn.args @@ -38,19 +41,21 @@ class U2idx(object): self.log("your python does not have sqlite3; searching will be disabled") return - self.active_id = None - self.active_cur = None - self.cur = {} - self.mem_cur = sqlite3.connect(":memory:") + self.active_id = "" + self.active_cur: Optional["sqlite3.Cursor"] = None + self.cur: dict[str, "sqlite3.Cursor"] = {} + self.mem_cur = sqlite3.connect(":memory:").cursor() self.mem_cur.execute(r"create table a (b text)") - self.p_end = None - self.p_dur = 0 + self.p_end = 0.0 + self.p_dur = 0.0 - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("u2idx", msg, c) - def fsearch(self, vols, body): + def fsearch( + self, vols: list[tuple[str, str, dict[str, Any]]], body: dict[str, Any] + ) -> list[dict[str, Any]]: """search by 
up2k hashlist""" if not HAVE_SQLITE3: return [] @@ -60,14 +65,14 @@ class U2idx(object): wark = up2k_wark_from_hashlist(self.args.salt, fsize, fhash) uq = "substr(w,1,16) = ? and w = ?" - uv = [wark[:16], wark] + uv: list[Union[str, int]] = [wark[:16], wark] try: return self.run_query(vols, uq, uv, True, False, 99999)[0] except: raise Pebkac(500, min_ex()) - def get_cur(self, ptop): + def get_cur(self, ptop: str) -> Optional["sqlite3.Cursor"]: if not HAVE_SQLITE3: return None @@ -103,13 +108,16 @@ class U2idx(object): self.cur[ptop] = cur return cur - def search(self, vols, uq, lim): + def search( + self, vols: list[tuple[str, str, dict[str, Any]]], uq: str, lim: int + ) -> tuple[list[dict[str, Any]], list[str]]: """search by query params""" if not HAVE_SQLITE3: - return [] + return [], [] q = "" - va = [] + v: Union[str, int] = "" + va: list[Union[str, int]] = [] have_up = False # query has up.* operands have_mt = False is_key = True @@ -202,7 +210,7 @@ class U2idx(object): "%Y", ]: try: - v = calendar.timegm(time.strptime(v, fmt)) + v = calendar.timegm(time.strptime(str(v), fmt)) break except: pass @@ -230,11 +238,12 @@ class U2idx(object): # lowercase tag searches m = ptn_lc.search(q) - if not m or not ptn_lcv.search(unicode(v)): + zs = unicode(v) + if not m or not ptn_lcv.search(zs): continue va.pop() - va.append(v.lower()) + va.append(zs.lower()) q = q[: m.start()] field, oper = m.groups() @@ -248,8 +257,16 @@ class U2idx(object): except Exception as ex: raise Pebkac(500, repr(ex)) - def run_query(self, vols, uq, uv, have_up, have_mt, lim): - done_flag = [] + def run_query( + self, + vols: list[tuple[str, str, dict[str, Any]]], + uq: str, + uv: list[Union[str, int]], + have_up: bool, + have_mt: bool, + lim: int, + ) -> tuple[list[dict[str, Any]], list[str]]: + done_flag: list[bool] = [] self.active_id = "{:.6f}_{}".format( time.time(), threading.current_thread().ident ) @@ -266,13 +283,11 @@ class U2idx(object): if not uq or not uv: uq = "select * from up" - 
uv = () + uv = [] elif have_mt: uq = "select up.*, substr(up.w,1,16) mtw from up where " + uq - uv = tuple(uv) else: uq = "select up.* from up where " + uq - uv = tuple(uv) self.log("qs: {!r} {!r}".format(uq, uv)) @@ -292,11 +307,10 @@ class U2idx(object): v = vtop + "/" vuv.append(v) - vuv = tuple(vuv) sret = [] fk = flags.get("fk") - c = cur.execute(uq, vuv) + c = cur.execute(uq, tuple(vuv)) for hit in c: w, ts, sz, rd, fn, ip, at = hit[:7] lim -= 1 @@ -340,7 +354,7 @@ class U2idx(object): # print("[{}] {}".format(ptop, sret)) done_flag.append(True) - self.active_id = None + self.active_id = "" # undupe hits from multiple metadata keys if len(ret) > 1: @@ -354,11 +368,12 @@ class U2idx(object): return ret, list(taglist.keys()) - def terminator(self, identifier, done_flag): + def terminator(self, identifier: str, done_flag: list[bool]) -> None: for _ in range(self.timeout): time.sleep(1) if done_flag: return if identifier == self.active_id: + assert self.active_cur self.active_cur.connection.interrupt() diff --git a/copyparty/up2k.py b/copyparty/up2k.py index 3f4e02e4..aff68917 100644 --- a/copyparty/up2k.py +++ b/copyparty/up2k.py @@ -1,60 +1,83 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import time -import math -import json -import gzip -import stat -import shutil import base64 +import gzip import hashlib -import threading -import traceback +import json +import math +import os +import re +import shutil +import stat import subprocess as sp +import threading +import time +import traceback from copy import deepcopy -from .__init__ import WINDOWS, ANYWIN, PY2 -from .util import ( - Pebkac, - Queue, - ProgressPrinter, - SYMTIME, - fsenc, - absreal, - sanitize_fn, - ren_open, - atomic_move, - quotep, - vsplit, - w8b64enc, - w8b64dec, - s3enc, - s3dec, - rmdirs, - statdir, - s2hms, - min_ex, -) -from .bos import bos -from .authsrv import AuthSrv, LEELOO_DALLAS -from .mtag import MTag, MParser +from queue import 
Queue -try: - HAVE_SQLITE3 = True +from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS +from .authsrv import LEELOO_DALLAS, VFS, AuthSrv +from .bos import bos +from .mtag import MParser, MTag +from .util import ( + HAVE_SQLITE3, + SYMTIME, + Pebkac, + ProgressPrinter, + absreal, + atomic_move, + fsenc, + min_ex, + quotep, + ren_open, + rmdirs, + s2hms, + s3dec, + s3enc, + sanitize_fn, + statdir, + vsplit, + w8b64dec, + w8b64enc, +) + +if HAVE_SQLITE3: import sqlite3 -except: - HAVE_SQLITE3 = False DB_VER = 5 +try: + from typing import Any, Optional, Pattern, Union +except: + pass + +if TYPE_CHECKING: + from .svchub import SvcHub + + +class Dbw(object): + def __init__(self, c: "sqlite3.Cursor", n: int, t: float) -> None: + self.c = c + self.n = n + self.t = t + + +class Mpqe(object): + def __init__(self, mtp: dict[str, MParser], entags: set[str], w: str, abspath: str): + # mtp empty = mtag + self.mtp = mtp + self.entags = entags + self.w = w + self.abspath = abspath + class Up2k(object): - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub - self.asrv = hub.asrv # type: AuthSrv + self.asrv: AuthSrv = hub.asrv self.args = hub.args self.log_func = hub.log @@ -62,25 +85,32 @@ class Up2k(object): self.salt = self.args.salt # state + self.gid = 0 self.mutex = threading.Lock() + self.pp: Optional[ProgressPrinter] = None self.rescan_cond = threading.Condition() - self.hashq = Queue() - self.tagq = Queue() + self.need_rescan: set[str] = set() + + self.registry: dict[str, dict[str, dict[str, Any]]] = {} + self.flags: dict[str, dict[str, Any]] = {} + self.droppable: dict[str, list[str]] = {} + self.volstate: dict[str, str] = {} + self.dupesched: dict[str, list[tuple[str, str, float]]] = {} + self.snap_persist_interval = 300 # persist unfinished index every 5 min + self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity + self.snap_prev: dict[str, Optional[tuple[int, float]]] = {} + + self.mtag: Optional[MTag] = 
None + self.entags: dict[str, set[str]] = {} + self.mtp_parsers: dict[str, dict[str, MParser]] = {} + self.pending_tags: list[tuple[set[str], str, str, dict[str, Any]]] = [] + self.hashq: Queue[tuple[str, str, str, str, float]] = Queue() + self.tagq: Queue[tuple[str, str, str, str]] = Queue() self.n_hashq = 0 self.n_tagq = 0 - self.gid = 0 - self.volstate = {} - self.need_rescan = {} - self.dupesched = {} - self.registry = {} - self.droppable = {} - self.entags = {} - self.flags = {} - self.cur = {} - self.mtag = None - self.pending_tags = None - self.mtp_parsers = {} + self.mpool_used = False + self.cur: dict[str, "sqlite3.Cursor"] = {} self.mem_cur = None self.sqlite_ver = None self.no_expr_idx = False @@ -94,7 +124,7 @@ class Up2k(object): if ANYWIN: # usually fails to set lastmod too quickly - self.lastmod_q = [] + self.lastmod_q: list[tuple[str, int, tuple[int, int]]] = [] thr = threading.Thread(target=self._lastmodder, name="up2k-lastmod") thr.daemon = True thr.start() @@ -108,7 +138,7 @@ class Up2k(object): if self.args.no_fastboot: self.deferred_init() - def init_vols(self): + def init_vols(self) -> None: if self.args.no_fastboot: return @@ -116,15 +146,15 @@ class Up2k(object): t.daemon = True t.start() - def reload(self): + def reload(self) -> None: self.gid += 1 self.log("reload #{} initiated".format(self.gid)) all_vols = self.asrv.vfs.all_vols self.rescan(all_vols, list(all_vols.keys()), True) - def deferred_init(self): + def deferred_init(self) -> None: all_vols = self.asrv.vfs.all_vols - have_e2d = self.init_indexes(all_vols) + have_e2d = self.init_indexes(all_vols, []) thr = threading.Thread(target=self._snapshot, name="up2k-snapshot") thr.daemon = True @@ -150,11 +180,11 @@ class Up2k(object): thr.daemon = True thr.start() - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("up2k", msg + "\033[K", c) - def get_state(self): - mtpq = 0 + def get_state(self) -> str: + mtpq: Union[int, str] = 0 q = "select 
count(w) from mt where k = 't:mtp'" got_lock = False if PY2 else self.mutex.acquire(timeout=0.5) if got_lock: @@ -165,19 +195,19 @@ class Up2k(object): pass self.mutex.release() else: - mtpq = "?" + mtpq = "(?)" ret = { "volstate": self.volstate, - "scanning": hasattr(self, "pp"), + "scanning": bool(self.pp), "hashq": self.n_hashq, "tagq": self.n_tagq, "mtpq": mtpq, } return json.dumps(ret, indent=4) - def rescan(self, all_vols, scan_vols, wait): - if not wait and hasattr(self, "pp"): + def rescan(self, all_vols: dict[str, VFS], scan_vols: list[str], wait: bool) -> str: + if not wait and self.pp: return "cannot initiate; scan is already in progress" args = (all_vols, scan_vols) @@ -188,11 +218,11 @@ class Up2k(object): ) t.daemon = True t.start() - return None + return "" - def _sched_rescan(self): + def _sched_rescan(self) -> None: volage = {} - cooldown = 0 + cooldown = 0.0 timeout = time.time() + 3 while True: timeout = max(timeout, cooldown) @@ -204,7 +234,7 @@ class Up2k(object): if now < cooldown: continue - if hasattr(self, "pp"): + if self.pp: cooldown = now + 5 continue @@ -220,19 +250,19 @@ class Up2k(object): deadline = volage[vp] + maxage if deadline <= now: - self.need_rescan[vp] = 1 + self.need_rescan.add(vp) timeout = min(timeout, deadline) - vols = list(sorted(self.need_rescan.keys())) - self.need_rescan = {} + vols = list(sorted(self.need_rescan)) + self.need_rescan.clear() if vols: cooldown = now + 10 err = self.rescan(self.asrv.vfs.all_vols, vols, False) if err: for v in vols: - self.need_rescan[v] = True + self.need_rescan.add(v) continue @@ -272,7 +302,7 @@ class Up2k(object): if vp: fvp = "{}/{}".format(vp, fvp) - self._handle_rm(LEELOO_DALLAS, None, fvp) + self._handle_rm(LEELOO_DALLAS, "", fvp) nrm += 1 if nrm: @@ -288,12 +318,12 @@ class Up2k(object): if hits: timeout = min(timeout, now + lifetime - (now - hits[0])) - def _vis_job_progress(self, job): + def _vis_job_progress(self, job: dict[str, Any]) -> str: perc = 100 - (len(job["need"]) 
* 100.0 / len(job["hash"])) path = os.path.join(job["ptop"], job["prel"], job["name"]) return "{:5.1f}% {}".format(perc, path) - def _vis_reg_progress(self, reg): + def _vis_reg_progress(self, reg: dict[str, dict[str, Any]]) -> list[str]: ret = [] for _, job in reg.items(): if job["need"]: @@ -301,7 +331,7 @@ class Up2k(object): return ret - def _expr_idx_filter(self, flags): + def _expr_idx_filter(self, flags: dict[str, Any]) -> tuple[bool, dict[str, Any]]: if not self.no_expr_idx: return False, flags @@ -311,19 +341,19 @@ class Up2k(object): return True, ret - def init_indexes(self, all_vols, scan_vols=None): + def init_indexes(self, all_vols: dict[str, VFS], scan_vols: list[str]) -> bool: gid = self.gid - while hasattr(self, "pp") and gid == self.gid: + while self.pp and gid == self.gid: time.sleep(0.1) if gid != self.gid: - return + return False if gid: self.log("reload #{} running".format(self.gid)) self.pp = ProgressPrinter() - vols = all_vols.values() + vols = list(all_vols.values()) t0 = time.time() have_e2d = False @@ -377,9 +407,9 @@ class Up2k(object): # e2ds(a) volumes first for vol in vols: - en = {} + en: set[str] = set() if "mte" in vol.flags: - en = {k: True for k in vol.flags["mte"].split(",")} + en = set(vol.flags["mte"].split(",")) self.entags[vol.realpath] = en @@ -393,11 +423,11 @@ class Up2k(object): need_vac[vol] = True if "e2ts" not in vol.flags: - m = "online, idle" + t = "online, idle" else: - m = "online (tags pending)" + t = "online (tags pending)" - self.volstate[vol.vpath] = m + self.volstate[vol.vpath] = t # open the rest + do any e2ts(a) needed_mutagen = False @@ -405,9 +435,9 @@ class Up2k(object): if "e2ts" not in vol.flags: continue - m = "online (reading tags)" - self.volstate[vol.vpath] = m - self.log("{} [{}]".format(m, vol.realpath)) + t = "online (reading tags)" + self.volstate[vol.vpath] = t + self.log("{} [{}]".format(t, vol.realpath)) nadd, nrm, success = self._build_tags_index(vol) if not success: @@ -419,7 +449,9 @@ 
class Up2k(object): self.volstate[vol.vpath] = "online (mtp soon)" for vol in need_vac: - cur, _ = self.register_vpath(vol.realpath, vol.flags) + reg = self.register_vpath(vol.realpath, vol.flags) + assert reg + cur, _ = reg with self.mutex: cur.connection.commit() cur.execute("vacuum") @@ -435,23 +467,25 @@ class Up2k(object): thr = None if self.mtag: - m = "online (running mtp)" + t = "online (running mtp)" if scan_vols: thr = threading.Thread(target=self._run_all_mtp, name="up2k-mtp-scan") thr.daemon = True else: - del self.pp - m = "online, idle" + self.pp = None + t = "online, idle" for vol in vols: - self.volstate[vol.vpath] = m + self.volstate[vol.vpath] = t if thr: thr.start() return have_e2d - def register_vpath(self, ptop, flags): + def register_vpath( + self, ptop: str, flags: dict[str, Any] + ) -> Optional[tuple["sqlite3.Cursor", str]]: histpath = self.asrv.vfs.histtab.get(ptop) if not histpath: self.log("no histpath for [{}]".format(ptop)) @@ -460,7 +494,7 @@ class Up2k(object): db_path = os.path.join(histpath, "up2k.db") if ptop in self.registry: try: - return [self.cur[ptop], db_path] + return self.cur[ptop], db_path except: return None @@ -514,14 +548,14 @@ class Up2k(object): else: drp = [x for x in drp if x in reg] - m = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or [])) - m = [m] + self._vis_reg_progress(reg) - self.log("\n".join(m)) + t = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or [])) + ta = [t] + self._vis_reg_progress(reg) + self.log("\n".join(ta)) self.flags[ptop] = flags self.registry[ptop] = reg self.droppable[ptop] = drp or [] - self.regdrop(ptop, None) + self.regdrop(ptop, "") if not HAVE_SQLITE3 or "e2d" not in flags or "d2d" in flags: return None @@ -530,23 +564,25 @@ class Up2k(object): try: cur = self._open_db(db_path) self.cur[ptop] = cur - return [cur, db_path] + return cur, db_path except: msg = "cannot use database at [{}]:\n{}" self.log(msg.format(ptop, traceback.format_exc())) 
return None - def _build_file_index(self, vol, all_vols): + def _build_file_index(self, vol: VFS, all_vols: list[VFS]) -> tuple[bool, bool]: do_vac = False top = vol.realpath rei = vol.flags.get("noidx") reh = vol.flags.get("nohash") with self.mutex: - cur, _ = self.register_vpath(top, vol.flags) + reg = self.register_vpath(top, vol.flags) + assert reg and self.pp + cur, _ = reg - dbw = [cur, 0, time.time()] - self.pp.n = next(dbw[0].execute("select count(w) from up"))[0] + db = Dbw(cur, 0, time.time()) + self.pp.n = next(db.c.execute("select count(w) from up"))[0] excl = [ vol.realpath + "/" + d.vpath[len(vol.vpath) :].lstrip("/") @@ -558,37 +594,47 @@ class Up2k(object): if WINDOWS: excl = [x.replace("/", "\\") for x in excl] - excl = set(excl) rtop = absreal(top) n_add = n_rm = 0 try: - n_add = self._build_dir(dbw, top, excl, top, rtop, rei, reh, []) - n_rm = self._drop_lost(dbw[0], top) + n_add = self._build_dir(db, top, set(excl), top, rtop, rei, reh, []) + n_rm = self._drop_lost(db.c, top) except: - m = "failed to index volume [{}]:\n{}" - self.log(m.format(top, min_ex()), c=1) + t = "failed to index volume [{}]:\n{}" + self.log(t.format(top, min_ex()), c=1) - if dbw[1]: - self.log("commit {} new files".format(dbw[1])) + if db.n: + self.log("commit {} new files".format(db.n)) - dbw[0].connection.commit() + db.c.connection.commit() - return True, n_add or n_rm or do_vac + return True, bool(n_add or n_rm or do_vac) - def _build_dir(self, dbw, top, excl, cdir, rcdir, rei, reh, seen): + def _build_dir( + self, + db: Dbw, + top: str, + excl: set[str], + cdir: str, + rcdir: str, + rei: Optional[Pattern[str]], + reh: Optional[Pattern[str]], + seen: list[str], + ) -> int: if rcdir in seen: - m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}" - self.log(m.format(seen[-1], rcdir, cdir), 3) + t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}" + self.log(t.format(seen[-1], rcdir, cdir), 3) return 0 seen = seen + [rcdir] + assert self.pp and 
self.mem_cur self.pp.msg = "a{} {}".format(self.pp.n, cdir) ret = 0 seen_files = {} # != inames; files-only for dropcheck g = statdir(self.log_func, not self.args.no_scandir, False, cdir) - g = sorted(g) - inames = {x[0]: 1 for x in g} - for iname, inf in g: + gl = sorted(g) + inames = {x[0]: 1 for x in gl} + for iname, inf in gl: abspath = os.path.join(cdir, iname) if rei and rei.search(abspath): continue @@ -605,10 +651,10 @@ class Up2k(object): continue # self.log(" dir: {}".format(abspath)) try: - ret += self._build_dir(dbw, top, excl, abspath, rap, rei, reh, seen) + ret += self._build_dir(db, top, excl, abspath, rap, rei, reh, seen) except: - m = "failed to index subdir [{}]:\n{}" - self.log(m.format(abspath, min_ex()), c=1) + t = "failed to index subdir [{}]:\n{}" + self.log(t.format(abspath, min_ex()), c=1) elif not stat.S_ISREG(inf.st_mode): self.log("skip type-{:x} file [{}]".format(inf.st_mode, abspath)) else: @@ -632,31 +678,31 @@ class Up2k(object): rd, fn = rp.rsplit("/", 1) if "/" in rp else ["", rp] sql = "select w, mt, sz from up where rd = ? and fn = ?" 
try: - c = dbw[0].execute(sql, (rd, fn)) + c = db.c.execute(sql, (rd, fn)) except: - c = dbw[0].execute(sql, s3enc(self.mem_cur, rd, fn)) + c = db.c.execute(sql, s3enc(self.mem_cur, rd, fn)) in_db = list(c.fetchall()) if in_db: self.pp.n -= 1 dw, dts, dsz = in_db[0] if len(in_db) > 1: - m = "WARN: multiple entries: [{}] => [{}] |{}|\n{}" + t = "WARN: multiple entries: [{}] => [{}] |{}|\n{}" rep_db = "\n".join([repr(x) for x in in_db]) - self.log(m.format(top, rp, len(in_db), rep_db)) + self.log(t.format(top, rp, len(in_db), rep_db)) dts = -1 if dts == lmod and dsz == sz and (nohash or dw[0] != "#"): continue - m = "reindex [{}] => [{}] ({}/{}) ({}/{})".format( + t = "reindex [{}] => [{}] ({}/{}) ({}/{})".format( top, rp, dts, lmod, dsz, sz ) - self.log(m) - self.db_rm(dbw[0], rd, fn) + self.log(t) + self.db_rm(db.c, rd, fn) ret += 1 - dbw[1] += 1 - in_db = None + db.n += 1 + in_db = [] self.pp.msg = "a{} {}".format(self.pp.n, abspath) @@ -674,15 +720,15 @@ class Up2k(object): wark = up2k_wark_from_hashlist(self.salt, sz, hashes) - self.db_add(dbw[0], wark, rd, fn, lmod, sz, "", 0) - dbw[1] += 1 + self.db_add(db.c, wark, rd, fn, lmod, sz, "", 0) + db.n += 1 ret += 1 - td = time.time() - dbw[2] - if dbw[1] >= 4096 or td >= 60: - self.log("commit {} new files".format(dbw[1])) - dbw[0].connection.commit() - dbw[1] = 0 - dbw[2] = time.time() + td = time.time() - db.t + if db.n >= 4096 or td >= 60: + self.log("commit {} new files".format(db.n)) + db.c.connection.commit() + db.n = 0 + db.t = time.time() # drop missing files rd = cdir[len(top) + 1 :].strip("/") @@ -691,25 +737,26 @@ class Up2k(object): q = "select fn from up where rd = ?" 
try: - c = dbw[0].execute(q, (rd,)) + c = db.c.execute(q, (rd,)) except: - c = dbw[0].execute(q, ("//" + w8b64enc(rd),)) + c = db.c.execute(q, ("//" + w8b64enc(rd),)) hits = [w8b64dec(x[2:]) if x.startswith("//") else x for (x,) in c] rm_files = [x for x in hits if x not in seen_files] n_rm = len(rm_files) for fn in rm_files: - self.db_rm(dbw[0], rd, fn) + self.db_rm(db.c, rd, fn) if n_rm: self.log("forgot {} deleted files".format(n_rm)) return ret - def _drop_lost(self, cur, top): + def _drop_lost(self, cur: "sqlite3.Cursor", top: str) -> int: rm = [] n_rm = 0 nchecked = 0 + assert self.pp # `_build_dir` did all the files, now do dirs ndirs = next(cur.execute("select count(distinct rd) from up"))[0] c = cur.execute("select distinct rd from up order by rd desc") @@ -743,13 +790,16 @@ class Up2k(object): return n_rm - def _build_tags_index(self, vol): + def _build_tags_index(self, vol: VFS) -> tuple[int, int, bool]: ptop = vol.realpath with self.mutex: - _, db_path = self.register_vpath(ptop, vol.flags) - entags = self.entags[ptop] - flags = self.flags[ptop] - cur = self.cur[ptop] + reg = self.register_vpath(ptop, vol.flags) + + assert reg and self.pp and self.mtag + _, db_path = reg + entags = self.entags[ptop] + flags = self.flags[ptop] + cur = self.cur[ptop] n_add = 0 n_rm = 0 @@ -794,7 +844,8 @@ class Up2k(object): if not self.mtag: return n_add, n_rm, False - mpool = False + mpool: Optional[Queue[Mpqe]] = None + if self.mtag.prefer_mt and self.args.mtag_mt > 1: mpool = self._start_mpool() @@ -819,11 +870,10 @@ class Up2k(object): abspath = os.path.join(ptop, rd, fn) self.pp.msg = "c{} {}".format(n_left, abspath) - args = [entags, w, abspath] if not mpool: - n_tags = self._tag_file(c3, *args) + n_tags = self._tag_file(c3, entags, w, abspath) else: - mpool.put(["mtag"] + args) + mpool.put(Mpqe({}, entags, w, abspath)) # not registry cursor; do not self.mutex: n_tags = len(self._flush_mpool(c3)) @@ -850,7 +900,7 @@ class Up2k(object): return n_add, n_rm, True - 
def _flush_mpool(self, wcur): + def _flush_mpool(self, wcur: "sqlite3.Cursor") -> list[str]: ret = [] for x in self.pending_tags: self._tag_file(wcur, *x) @@ -859,7 +909,7 @@ class Up2k(object): self.pending_tags = [] return ret - def _run_all_mtp(self): + def _run_all_mtp(self) -> None: gid = self.gid t0 = time.time() for ptop, flags in self.flags.items(): @@ -870,12 +920,12 @@ class Up2k(object): msg = "mtp finished in {:.2f} sec ({})" self.log(msg.format(td, s2hms(td, True))) - del self.pp + self.pp = None for k in list(self.volstate.keys()): if "OFFLINE" not in self.volstate[k]: self.volstate[k] = "online, idle" - def _run_one_mtp(self, ptop, gid): + def _run_one_mtp(self, ptop: str, gid: int) -> None: if gid != self.gid: return @@ -915,8 +965,8 @@ class Up2k(object): break q = "select w from mt where k = 't:mtp' limit ?" - warks = cur.execute(q, (batch_sz,)).fetchall() - warks = [x[0] for x in warks] + zq = cur.execute(q, (batch_sz,)).fetchall() + warks = [str(x[0]) for x in zq] jobs = [] for w in warks: q = "select rd, fn from up where substr(w,1,16)=? limit 1" @@ -925,8 +975,8 @@ class Up2k(object): abspath = os.path.join(ptop, rd, fn) q = "select k from mt where w = ?" 
- have = cur.execute(q, (w,)).fetchall() - have = [x[0] for x in have] + zq2 = cur.execute(q, (w,)).fetchall() + have: dict[str, Union[str, float]] = {x[0]: 1 for x in zq2} parsers = self._get_parsers(ptop, have, abspath) if not parsers: @@ -937,7 +987,7 @@ class Up2k(object): if w in in_progress: continue - jobs.append([parsers, None, w, abspath]) + jobs.append(Mpqe(parsers, set(), w, abspath)) in_progress[w] = True with self.mutex: @@ -997,7 +1047,9 @@ class Up2k(object): wcur.close() cur.close() - def _get_parsers(self, ptop, have, abspath): + def _get_parsers( + self, ptop: str, have: dict[str, Union[str, float]], abspath: str + ) -> dict[str, MParser]: try: all_parsers = self.mtp_parsers[ptop] except: @@ -1030,16 +1082,16 @@ class Up2k(object): parsers = {k: v for k, v in parsers.items() if v.force or k not in have} return parsers - def _start_mpool(self): + def _start_mpool(self) -> Queue[Mpqe]: # mp.pool.ThreadPool and concurrent.futures.ThreadPoolExecutor # both do crazy runahead so lets reinvent another wheel nw = max(1, self.args.mtag_mt) - - if self.pending_tags is None: + assert self.mtag + if not self.mpool_used: + self.mpool_used = True self.log("using {}x {}".format(nw, self.mtag.backend)) - self.pending_tags = [] - mpool = Queue(nw) + mpool: Queue[Mpqe] = Queue(nw) for _ in range(nw): thr = threading.Thread( target=self._tag_thr, args=(mpool,), name="up2k-mpool" @@ -1049,50 +1101,55 @@ class Up2k(object): return mpool - def _stop_mpool(self, mpool): + def _stop_mpool(self, mpool: Queue[Mpqe]) -> None: if not mpool: return for _ in range(mpool.maxsize): - mpool.put(None) + mpool.put(Mpqe({}, set(), "", "")) mpool.join() - def _tag_thr(self, q): + def _tag_thr(self, q: Queue[Mpqe]) -> None: + assert self.mtag while True: - task = q.get() - if not task: + qe = q.get() + if not qe.w: q.task_done() return try: - parser, entags, wark, abspath = task - if parser == "mtag": - tags = self.mtag.get(abspath) + if not qe.mtp: + tags = self.mtag.get(qe.abspath) 
else: - tags = self.mtag.get_bin(parser, abspath) + tags = self.mtag.get_bin(qe.mtp, qe.abspath) vtags = [ "\033[36m{} \033[33m{}".format(k, v) for k, v in tags.items() ] if vtags: - self.log("{}\033[0m [{}]".format(" ".join(vtags), abspath)) + self.log("{}\033[0m [{}]".format(" ".join(vtags), qe.abspath)) with self.mutex: - self.pending_tags.append([entags, wark, abspath, tags]) + self.pending_tags.append((qe.entags, qe.w, qe.abspath, tags)) except: ex = traceback.format_exc() - if parser == "mtag": - parser = self.mtag.backend - - self._log_tag_err(parser, abspath, ex) + self._log_tag_err(qe.mtp or self.mtag.backend, qe.abspath, ex) q.task_done() - def _log_tag_err(self, parser, abspath, ex): + def _log_tag_err(self, parser: Any, abspath: str, ex: Any) -> None: msg = "{} failed to read tags from {}:\n{}".format(parser, abspath, ex) self.log(msg.lstrip(), c=1 if " int: + assert self.mtag if tags is None: try: tags = self.mtag.get(abspath) @@ -1127,12 +1184,12 @@ class Up2k(object): return ret - def _orz(self, db_path): + def _orz(self, db_path: str) -> "sqlite3.Cursor": timeout = int(max(self.args.srch_time, 5) * 1.2) return sqlite3.connect(db_path, timeout, check_same_thread=False).cursor() # x.set_trace_callback(trace) - def _open_db(self, db_path): + def _open_db(self, db_path: str) -> "sqlite3.Cursor": existed = bos.path.exists(db_path) cur = self._orz(db_path) ver = self._read_ver(cur) @@ -1141,8 +1198,8 @@ class Up2k(object): if ver == 4: try: - m = "creating backup before upgrade: " - cur = self._backup_db(db_path, cur, ver, m) + t = "creating backup before upgrade: " + cur = self._backup_db(db_path, cur, ver, t) self._upgrade_v4(cur) ver = 5 except: @@ -1157,8 +1214,8 @@ class Up2k(object): self.log("WARN: could not list files; DB corrupt?\n" + min_ex()) if (ver or 0) > DB_VER: - m = "database is version {}, this copyparty only supports versions <= {}" - raise Exception(m.format(ver, DB_VER)) + t = "database is version {}, this copyparty only supports 
versions <= {}" + raise Exception(t.format(ver, DB_VER)) msg = "creating new DB (old is bad); backup: " if ver: @@ -1171,7 +1228,9 @@ class Up2k(object): bos.unlink(db_path) return self._create_db(db_path, None) - def _backup_db(self, db_path, cur, ver, msg): + def _backup_db( + self, db_path: str, cur: "sqlite3.Cursor", ver: Optional[int], msg: str + ) -> "sqlite3.Cursor": bak = "{}.bak.{:x}.v{}".format(db_path, int(time.time()), ver) self.log(msg + bak) try: @@ -1180,8 +1239,8 @@ class Up2k(object): cur.connection.backup(c2) return cur except: - m = "native sqlite3 backup failed; using fallback method:\n" - self.log(m + min_ex()) + t = "native sqlite3 backup failed; using fallback method:\n" + self.log(t + min_ex()) finally: c2.close() @@ -1192,7 +1251,7 @@ class Up2k(object): shutil.copy2(fsenc(db_path), fsenc(bak)) return self._orz(db_path) - def _read_ver(self, cur): + def _read_ver(self, cur: "sqlite3.Cursor") -> Optional[int]: for tab in ["ki", "kv"]: try: c = cur.execute(r"select v from {} where k = 'sver'".format(tab)) @@ -1202,8 +1261,11 @@ class Up2k(object): rows = c.fetchall() if rows: return int(rows[0][0]) + return None - def _create_db(self, db_path, cur): + def _create_db( + self, db_path: str, cur: Optional["sqlite3.Cursor"] + ) -> "sqlite3.Cursor": """ collision in 2^(n/2) files where n = bits (6 bits/ch) 10*6/2 = 2^30 = 1'073'741'824, 24.1mb idx 1<<(3*10) @@ -1236,7 +1298,7 @@ class Up2k(object): self.log("created DB at {}".format(db_path)) return cur - def _upgrade_v4(self, cur): + def _upgrade_v4(self, cur: "sqlite3.Cursor") -> None: for cmd in [ r"alter table up add column ip text", r"alter table up add column at int", @@ -1247,7 +1309,7 @@ class Up2k(object): cur.connection.commit() - def handle_json(self, cj): + def handle_json(self, cj: dict[str, Any]) -> dict[str, Any]: with self.mutex: if not self.register_vpath(cj["ptop"], cj["vcfg"]): if cj["ptop"] not in self.registry: @@ -1269,13 +1331,13 @@ class Up2k(object): if cur: if 
self.no_expr_idx: q = r"select * from up where w = ?" - argv = (wark,) + argv = [wark] else: q = r"select * from up where substr(w,1,16) = ? and w = ?" - argv = (wark[:16], wark) + argv = [wark[:16], wark] - alts = [] - cur = cur.execute(q, argv) + alts: list[tuple[int, int, dict[str, Any]]] = [] + cur = cur.execute(q, tuple(argv)) for _, dtime, dsize, dp_dir, dp_fn, ip, at in cur: if dp_dir.startswith("//") or dp_fn.startswith("//"): dp_dir, dp_fn = s3dec(dp_dir, dp_fn) @@ -1307,7 +1369,7 @@ class Up2k(object): + (2 if dp_dir == cj["prel"] else 0) + (1 if dp_fn == cj["name"] else 0) ) - alts.append([score, -len(alts), j]) + alts.append((score, -len(alts), j)) job = sorted(alts, reverse=True)[0][2] if alts else None if job and wark in reg: @@ -1344,7 +1406,7 @@ class Up2k(object): # registry is size-constrained + can only contain one unique wark; # let want_recheck trigger symlink (if still in reg) or reupload if cur: - dupe = [cj["prel"], cj["name"], cj["lmod"]] + dupe = (cj["prel"], cj["name"], cj["lmod"]) try: self.dupesched[src].append(dupe) except: @@ -1431,17 +1493,19 @@ class Up2k(object): "wark": wark, } - def _untaken(self, fdir, fname, ts, ip): + def _untaken(self, fdir: str, fname: str, ts: float, ip: str) -> str: if self.args.nw: return fname # TODO broker which avoid this race and # provides a new filename if taken (same as bup) suffix = "-{:.6f}-{}".format(ts, ip.replace(":", ".")) - with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as f: - return f["orz"][1] + with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as zfw: + return zfw["orz"][1] - def _symlink(self, src, dst, verbose=True, lmod=None): + def _symlink( + self, src: str, dst: str, verbose: bool = True, lmod: float = 0 + ) -> None: if verbose: self.log("linking dupe:\n {0}\n {1}".format(src, dst)) @@ -1475,9 +1539,9 @@ class Up2k(object): break nc += 1 if nc > 1: - lsrc = nsrc[nc:] + zsl = nsrc[nc:] hops = len(ndst[nc:]) - 1 - lsrc = "../" * hops + "/".join(lsrc) + lsrc = "../" * hops + 
"/".join(zsl) try: if self.args.hardlink: @@ -1498,11 +1562,13 @@ class Up2k(object): if lmod and (not linked or SYMTIME): times = (int(time.time()), int(lmod)) if ANYWIN: - self.lastmod_q.append([dst, 0, times]) + self.lastmod_q.append((dst, 0, times)) else: bos.utime(dst, times, False) - def handle_chunk(self, ptop, wark, chash): + def handle_chunk( + self, ptop: str, wark: str, chash: str + ) -> tuple[int, list[int], str, float]: with self.mutex: job = self.registry[ptop].get(wark) if not job: @@ -1523,8 +1589,8 @@ class Up2k(object): if chash in job["busy"]: nh = len(job["hash"]) idx = job["hash"].index(chash) - m = "that chunk is already being written to:\n {}\n {} {}/{}\n {}" - raise Pebkac(400, m.format(wark, chash, idx, nh, job["name"])) + t = "that chunk is already being written to:\n {}\n {} {}/{}\n {}" + raise Pebkac(400, t.format(wark, chash, idx, nh, job["name"])) job["busy"][chash] = 1 @@ -1535,17 +1601,17 @@ class Up2k(object): path = os.path.join(job["ptop"], job["prel"], job["tnam"]) - return [chunksize, ofs, path, job["lmod"]] + return chunksize, ofs, path, job["lmod"] - def release_chunk(self, ptop, wark, chash): + def release_chunk(self, ptop: str, wark: str, chash: str) -> bool: with self.mutex: job = self.registry[ptop].get(wark) if job: job["busy"].pop(chash, None) - return [True] + return True - def confirm_chunk(self, ptop, wark, chash): + def confirm_chunk(self, ptop: str, wark: str, chash: str) -> tuple[int, str]: with self.mutex: try: job = self.registry[ptop][wark] @@ -1553,14 +1619,14 @@ class Up2k(object): src = os.path.join(pdir, job["tnam"]) dst = os.path.join(pdir, job["name"]) except Exception as ex: - return "confirm_chunk, wark, " + repr(ex) + return "confirm_chunk, wark, " + repr(ex) # type: ignore job["busy"].pop(chash, None) try: job["need"].remove(chash) except Exception as ex: - return "confirm_chunk, chash, " + repr(ex) + return "confirm_chunk, chash, " + repr(ex) # type: ignore ret = len(job["need"]) if ret > 0: @@ 
-1576,35 +1642,35 @@ class Up2k(object): return ret, dst - def finish_upload(self, ptop, wark): + def finish_upload(self, ptop: str, wark: str) -> None: with self.mutex: self._finish_upload(ptop, wark) - def _finish_upload(self, ptop, wark): + def _finish_upload(self, ptop: str, wark: str) -> None: try: job = self.registry[ptop][wark] pdir = os.path.join(job["ptop"], job["prel"]) src = os.path.join(pdir, job["tnam"]) dst = os.path.join(pdir, job["name"]) except Exception as ex: - return "finish_upload, wark, " + repr(ex) + raise Pebkac(500, "finish_upload, wark, " + repr(ex)) # self.log("--- " + wark + " " + dst + " finish_upload atomic " + dst, 4) atomic_move(src, dst) times = (int(time.time()), int(job["lmod"])) if ANYWIN: - a = [dst, job["size"], times] - self.lastmod_q.append(a) + z1 = (dst, job["size"], times) + self.lastmod_q.append(z1) elif not job["hash"]: try: bos.utime(dst, times) except: pass - a = [job[x] for x in "ptop wark prel name lmod size addr".split()] - a += [job.get("at") or time.time()] - if self.idx_wark(*a): + z2 = [job[x] for x in "ptop wark prel name lmod size addr".split()] + z2 += [job.get("at") or time.time()] + if self.idx_wark(*z2): del self.registry[ptop][wark] else: self.regdrop(ptop, wark) @@ -1622,27 +1688,37 @@ class Up2k(object): self._symlink(dst, d2, lmod=lmod) if cur: self.db_rm(cur, rd, fn) - self.db_add(cur, wark, rd, fn, *a[-4:]) + self.db_add(cur, wark, rd, fn, *z2[-4:]) if cur: cur.connection.commit() - def regdrop(self, ptop, wark): - t = self.droppable[ptop] + def regdrop(self, ptop: str, wark: str) -> None: + olds = self.droppable[ptop] if wark: - t.append(wark) + olds.append(wark) - if len(t) <= self.args.reg_cap: + if len(olds) <= self.args.reg_cap: return - n = len(t) - int(self.args.reg_cap / 2) - m = "up2k-registry [{}] has {} droppables; discarding {}" - self.log(m.format(ptop, len(t), n)) - for k in t[:n]: + n = len(olds) - int(self.args.reg_cap / 2) + t = "up2k-registry [{}] has {} droppables; discarding {}" + 
self.log(t.format(ptop, len(olds), n)) + for k in olds[:n]: self.registry[ptop].pop(k, None) - self.droppable[ptop] = t[n:] + self.droppable[ptop] = olds[n:] - def idx_wark(self, ptop, wark, rd, fn, lmod, sz, ip, at): + def idx_wark( + self, + ptop: str, + wark: str, + rd: str, + fn: str, + lmod: float, + sz: int, + ip: str, + at: float, + ) -> bool: cur = self.cur.get(ptop) if not cur: return False @@ -1652,29 +1728,41 @@ class Up2k(object): cur.connection.commit() if "e2t" in self.flags[ptop]: - self.tagq.put([ptop, wark, rd, fn]) + self.tagq.put((ptop, wark, rd, fn)) self.n_tagq += 1 return True - def db_rm(self, db, rd, fn): + def db_rm(self, db: "sqlite3.Cursor", rd: str, fn: str) -> None: sql = "delete from up where rd = ? and fn = ?" try: db.execute(sql, (rd, fn)) except: + assert self.mem_cur db.execute(sql, s3enc(self.mem_cur, rd, fn)) - def db_add(self, db, wark, rd, fn, ts, sz, ip, at): + def db_add( + self, + db: "sqlite3.Cursor", + wark: str, + rd: str, + fn: str, + ts: float, + sz: int, + ip: str, + at: float, + ) -> None: sql = "insert into up values (?,?,?,?,?,?,?)" v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0)) try: db.execute(sql, v) except: + assert self.mem_cur rd, fn = s3enc(self.mem_cur, rd, fn) v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0)) db.execute(sql, v) - def handle_rm(self, uname, ip, vpaths): + def handle_rm(self, uname: str, ip: str, vpaths: list[str]) -> str: n_files = 0 ok = {} ng = {} @@ -1687,12 +1775,14 @@ class Up2k(object): ng[k] = 1 ng = {k: 1 for k in ng if k not in ok} - ok = len(ok) - ng = len(ng) + iok = len(ok) + ing = len(ng) - return "deleted {} files (and {}/{} folders)".format(n_files, ok, ok + ng) + return "deleted {} files (and {}/{} folders)".format(n_files, iok, iok + ing) - def _handle_rm(self, uname, ip, vpath): + def _handle_rm( + self, uname: str, ip: str, vpath: str + ) -> tuple[int, list[str], list[str]]: try: permsets = [[True, False, False, True]] vn, rem = self.asrv.vfs.get(vpath, 
uname, *permsets[0]) @@ -1709,18 +1799,18 @@ class Up2k(object): vn, rem = vn.get_dbv(rem) _, _, _, _, dip, dat = self._find_from_vpath(vn.realpath, rem) - m = "you cannot delete this: " + t = "you cannot delete this: " if not dip: - m += "file not found" + t += "file not found" elif dip != ip: - m += "not uploaded by (You)" + t += "not uploaded by (You)" elif dat < time.time() - self.args.unpost: - m += "uploaded too long ago" + t += "uploaded too long ago" else: - m = None + t = "" - if m: - raise Pebkac(400, m) + if t: + raise Pebkac(400, t) ptop = vn.realpath atop = vn.canonical(rem, False) @@ -1731,16 +1821,19 @@ class Up2k(object): raise Pebkac(400, "file not found on disk (already deleted?)") scandir = not self.args.no_scandir - if stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode): + if stat.S_ISDIR(st.st_mode): + g = vn.walk("", rem, [], uname, permsets, True, scandir, True) + if unpost: + raise Pebkac(400, "cannot unpost folders") + elif stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode): dbv, vrem = self.asrv.vfs.get(vpath, uname, *permsets[0]) dbv, vrem = dbv.get_dbv(vrem) voldir = vsplit(vrem)[0] vpath_dir = vsplit(vpath)[0] - g = [[dbv, voldir, vpath_dir, adir, [[fn, 0]], [], []]] + g = [(dbv, voldir, vpath_dir, adir, [(fn, 0)], [], {})] # type: ignore else: - g = vn.walk("", rem, [], uname, permsets, True, scandir, True) - if unpost: - raise Pebkac(400, "cannot unpost folders") + self.log("rm: skip type-{:x} file [{}]".format(st.st_mode, atop)) + return 0, [], [] n_files = 0 for dbv, vrem, _, adir, files, rd, vd in g: @@ -1766,7 +1859,7 @@ class Up2k(object): rm = rmdirs(self.log_func, scandir, True, atop, 1) return n_files, rm[0], rm[1] - def handle_mv(self, uname, svp, dvp): + def handle_mv(self, uname: str, svp: str, dvp: str) -> str: svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) sabs = svn.canonical(srem, False) @@ -1808,7 +1901,7 @@ class Up2k(object): rmdirs(self.log_func, scandir, True, sabs, 1) 
return "k" - def _mv_file(self, uname, svp, dvp): + def _mv_file(self, uname: str, svp: str, dvp: str) -> str: svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) @@ -1834,29 +1927,33 @@ class Up2k(object): if bos.path.islink(sabs): dlabs = absreal(sabs) - m = "moving symlink from [{}] to [{}], target [{}]" - self.log(m.format(sabs, dabs, dlabs)) + t = "moving symlink from [{}] to [{}], target [{}]" + self.log(t.format(sabs, dabs, dlabs)) mt = bos.path.getmtime(sabs, False) bos.unlink(sabs) self._symlink(dlabs, dabs, False, lmod=mt) # folders are too scary, schedule rescan of both vols - self.need_rescan[svn.vpath] = 1 - self.need_rescan[dvn.vpath] = 1 + self.need_rescan.add(svn.vpath) + self.need_rescan.add(dvn.vpath) with self.rescan_cond: self.rescan_cond.notify_all() return "k" - c1, w, ftime, fsize, ip, at = self._find_from_vpath(svn.realpath, srem) + c1, w, ftime_, fsize_, ip, at = self._find_from_vpath(svn.realpath, srem) c2 = self.cur.get(dvn.realpath) - if ftime is None: + if ftime_ is None: st = bos.stat(sabs) ftime = st.st_mtime fsize = st.st_size + else: + ftime = ftime_ + fsize = fsize_ or 0 if w: + assert c1 if c2 and c2 != c1: self._copy_tags(c1, c2, w) @@ -1865,7 +1962,7 @@ class Up2k(object): c1.connection.commit() if c2: - self.db_add(c2, w, drd, dfn, ftime, fsize, ip, at) + self.db_add(c2, w, drd, dfn, ftime, fsize, ip or "", at or 0) c2.connection.commit() else: self.log("not found in src db: [{}]".format(svp)) @@ -1873,7 +1970,9 @@ class Up2k(object): bos.rename(sabs, dabs) return "k" - def _copy_tags(self, csrc, cdst, wark): + def _copy_tags( + self, csrc: "sqlite3.Cursor", cdst: "sqlite3.Cursor", wark: str + ) -> None: """copy all tags for wark from src-db to dst-db""" w = wark[:16] @@ -1883,16 +1982,26 @@ class Up2k(object): for _, k, v in csrc.execute("select * from mt where w=?", (w,)): cdst.execute("insert into mt values(?,?,?)", (w, k, v)) - def _find_from_vpath(self, ptop, vrem): + def 
_find_from_vpath( + self, ptop: str, vrem: str + ) -> tuple[ + Optional["sqlite3.Cursor"], + Optional[str], + Optional[int], + Optional[int], + Optional[str], + Optional[int], + ]: cur = self.cur.get(ptop) if not cur: - return [None] * 6 + return None, None, None, None, None, None rd, fn = vsplit(vrem) q = "select w, mt, sz, ip, at from up where rd=? and fn=? limit 1" try: c = cur.execute(q, (rd, fn)) except: + assert self.mem_cur c = cur.execute(q, s3enc(self.mem_cur, rd, fn)) hit = c.fetchone() @@ -1901,14 +2010,21 @@ class Up2k(object): return cur, wark, ftime, fsize, ip, at return cur, None, None, None, None, None - def _forget_file(self, ptop, vrem, cur, wark, drop_tags): + def _forget_file( + self, + ptop: str, + vrem: str, + cur: Optional["sqlite3.Cursor"], + wark: Optional[str], + drop_tags: bool, + ) -> None: """forgets file in db, fixes symlinks, does not delete""" srd, sfn = vsplit(vrem) self.log("forgetting {}".format(vrem)) - if wark: + if wark and cur: self.log("found {} in db".format(wark)) if drop_tags: - if self._relink(wark, ptop, vrem, None): + if self._relink(wark, ptop, vrem, ""): drop_tags = False if drop_tags: @@ -1919,20 +2035,24 @@ class Up2k(object): reg = self.registry.get(ptop) if reg: - if not wark: - wark = [ + vdir = vsplit(vrem)[0] + wark = wark or next( + ( x for x, y in reg.items() - if sfn in [y["name"], y.get("tnam")] and y["prel"] == vrem - ] - - if wark and wark in reg: - m = "forgetting partial upload {} ({})" - p = self._vis_job_progress(wark) - self.log(m.format(wark, p)) + if sfn in [y["name"], y.get("tnam")] and y["prel"] == vdir + ), + "", + ) + job = reg.get(wark) if wark else None + if job: + t = "forgetting partial upload {} ({})" + p = self._vis_job_progress(job) + self.log(t.format(wark, p)) + assert wark del reg[wark] - def _relink(self, wark, sptop, srem, dabs): + def _relink(self, wark: str, sptop: str, srem: str, dabs: str) -> int: """ update symlinks from file at svn/srem to dabs (rename), or to first remaining 
full if no dabs (delete) @@ -1953,13 +2073,13 @@ class Up2k(object): if not dupes: return 0 - full = {} - links = {} + full: dict[str, tuple[str, str]] = {} + links: dict[str, tuple[str, str]] = {} for ptop, vp in dupes: ap = os.path.join(ptop, vp) try: d = links if bos.path.islink(ap) else full - d[ap] = [ptop, vp] + d[ap] = (ptop, vp) except: self.log("relink: not found: [{}]".format(ap)) @@ -1973,13 +2093,13 @@ class Up2k(object): bos.rename(sabs, slabs) bos.utime(slabs, (int(time.time()), int(mt)), False) self._symlink(slabs, sabs, False) - full[slabs] = [ptop, rem] + full[slabs] = (ptop, rem) sabs = slabs if not dabs: dabs = list(sorted(full.keys()))[0] - for alink in links.keys(): + for alink in links: lmod = None try: if alink != sabs and absreal(alink) != sabs: @@ -1991,11 +2111,11 @@ class Up2k(object): except: pass - self._symlink(dabs, alink, False, lmod=lmod) + self._symlink(dabs, alink, False, lmod=lmod or 0) return len(full) + len(links) - def _get_wark(self, cj): + def _get_wark(self, cj: dict[str, Any]) -> str: if len(cj["name"]) > 1024 or len(cj["hash"]) > 512 * 1024: # 16TiB raise Pebkac(400, "name or numchunks not according to spec") @@ -2020,15 +2140,14 @@ class Up2k(object): return wark - def _hashlist_from_file(self, path): - pp = self.pp if hasattr(self, "pp") else None + def _hashlist_from_file(self, path: str) -> list[str]: fsz = bos.path.getsize(path) csz = up2k_chunksize(fsz) ret = [] with open(fsenc(path), "rb", 512 * 1024) as f: while fsz > 0: - if pp: - pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path) + if self.pp: + self.pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path) hashobj = hashlib.sha512() rem = min(csz, fsz) @@ -2047,7 +2166,7 @@ class Up2k(object): return ret - def _new_upload(self, job): + def _new_upload(self, job: dict[str, Any]) -> None: self.registry[job["ptop"]][job["wark"]] = job pdir = os.path.join(job["ptop"], job["prel"]) job["name"] = self._untaken(pdir, job["name"], job["t0"], job["addr"]) @@ 
-2066,8 +2185,8 @@ class Up2k(object): dip = job["addr"].replace(":", ".") suffix = "-{:.6f}-{}".format(job["t0"], dip) - with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as f: - f, job["tnam"] = f["orz"] + with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as zfw: + f, job["tnam"] = zfw["orz"] if ( ANYWIN and self.args.sparse @@ -2086,7 +2205,7 @@ class Up2k(object): if not job["hash"]: self._finish_upload(job["ptop"], job["wark"]) - def _lastmodder(self): + def _lastmodder(self) -> None: while True: ready = self.lastmod_q self.lastmod_q = [] @@ -2098,8 +2217,8 @@ class Up2k(object): try: bos.utime(path, times, False) except: - m = "lmod: failed to utime ({}, {}):\n{}" - self.log(m.format(path, times, min_ex())) + t = "lmod: failed to utime ({}, {}):\n{}" + self.log(t.format(path, times, min_ex())) if self.args.sparse and self.args.sparse * 1024 * 1024 <= sz: try: @@ -2107,21 +2226,22 @@ class Up2k(object): except: self.log("could not unsparse [{}]".format(path), 3) - def _snapshot(self): - self.snap_persist_interval = 300 # persist unfinished index every 5 min - self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity - self.snap_prev = {} + def _snapshot(self) -> None: + slp = self.snap_persist_interval while True: - time.sleep(self.snap_persist_interval) - if not hasattr(self, "pp"): + time.sleep(slp) + if self.pp: + slp = 5 + else: + slp = self.snap_persist_interval self.do_snapshot() - def do_snapshot(self): + def do_snapshot(self) -> None: with self.mutex: for k, reg in self.registry.items(): self._snap_reg(k, reg) - def _snap_reg(self, ptop, reg): + def _snap_reg(self, ptop: str, reg: dict[str, dict[str, Any]]) -> None: now = time.time() histpath = self.asrv.vfs.histtab.get(ptop) if not histpath: @@ -2133,9 +2253,9 @@ class Up2k(object): if x["need"] and now - x["poke"] > self.snap_discard_interval ] if rm: - m = "dropping {} abandoned uploads in {}".format(len(rm), ptop) + t = "dropping {} abandoned uploads in {}".format(len(rm), 
ptop) vis = [self._vis_job_progress(x) for x in rm] - self.log("\n".join([m] + vis)) + self.log("\n".join([t] + vis)) for job in rm: del reg[job["wark"]] try: @@ -2159,8 +2279,8 @@ class Up2k(object): bos.unlink(path) return - newest = max(x["poke"] for _, x in reg.items()) if reg else 0 - etag = [len(reg), newest] + newest = float(max(x["poke"] for _, x in reg.items()) if reg else 0) + etag = (len(reg), newest) if etag == self.snap_prev.get(ptop): return @@ -2177,10 +2297,11 @@ class Up2k(object): self.log("snap: {} |{}|".format(path, len(reg.keys()))) self.snap_prev[ptop] = etag - def _tagger(self): + def _tagger(self) -> None: with self.mutex: self.n_tagq += 1 + assert self.mtag while True: with self.mutex: self.n_tagq -= 1 @@ -2218,7 +2339,7 @@ class Up2k(object): self.log("tagged {} ({}+{})".format(abspath, ntags1, len(tags) - ntags1)) - def _hasher(self): + def _hasher(self) -> None: with self.mutex: self.n_hashq += 1 @@ -2240,20 +2361,21 @@ class Up2k(object): with self.mutex: self.idx_wark(ptop, wark, rd, fn, inf.st_mtime, inf.st_size, ip, at) - def hash_file(self, ptop, flags, rd, fn, ip, at): + def hash_file( + self, ptop: str, flags: dict[str, Any], rd: str, fn: str, ip: str, at: float + ) -> None: with self.mutex: self.register_vpath(ptop, flags) - self.hashq.put([ptop, rd, fn, ip, at]) + self.hashq.put((ptop, rd, fn, ip, at)) self.n_hashq += 1 # self.log("hashq {} push {}/{}/{}".format(self.n_hashq, ptop, rd, fn)) - def shutdown(self): - if hasattr(self, "snap_prev"): - self.log("writing snapshot") - self.do_snapshot() + def shutdown(self) -> None: + self.log("writing snapshot") + self.do_snapshot() -def up2k_chunksize(filesize): +def up2k_chunksize(filesize: int) -> int: chunksize = 1024 * 1024 stepsize = 512 * 1024 while True: @@ -2266,18 +2388,17 @@ def up2k_chunksize(filesize): stepsize *= mul -def up2k_wark_from_hashlist(salt, filesize, hashes): +def up2k_wark_from_hashlist(salt: str, filesize: int, hashes: list[str]) -> str: 
"""server-reproducible file identifier, independent of name or location""" - ident = [salt, str(filesize)] - ident.extend(hashes) - ident = "\n".join(ident) + values = [salt, str(filesize)] + hashes + vstr = "\n".join(values) - wark = hashlib.sha512(ident.encode("utf-8")).digest()[:33] + wark = hashlib.sha512(vstr.encode("utf-8")).digest()[:33] wark = base64.urlsafe_b64encode(wark) return wark.decode("ascii") -def up2k_wark_from_metadata(salt, sz, lastmod, rd, fn): +def up2k_wark_from_metadata(salt: str, sz: int, lastmod: int, rd: str, fn: str) -> str: ret = fsenc("{}\n{}\n{}\n{}\n{}".format(salt, lastmod, sz, rd, fn)) ret = base64.urlsafe_b64encode(hashlib.sha512(ret).digest()) return "#{}".format(ret.decode("ascii"))[:44] diff --git a/copyparty/util.py b/copyparty/util.py index a01fdfdb..2ca710f5 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -1,51 +1,79 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import sys -import stat -import time import base64 +import contextlib +import hashlib +import mimetypes +import os +import platform +import re import select -import struct import signal import socket -import hashlib -import platform -import traceback -import threading -import mimetypes -import contextlib +import stat +import struct import subprocess as sp # nosec -from datetime import datetime +import sys +import threading +import time +import traceback from collections import Counter +from datetime import datetime -from .__init__ import PY2, WINDOWS, ANYWIN, VT100 +from .__init__ import ANYWIN, PY2, TYPE_CHECKING, VT100, WINDOWS from .stolen import surrogateescape +try: + HAVE_SQLITE3 = True + import sqlite3 # pylint: disable=unused-import # typechk +except: + HAVE_SQLITE3 = False + +try: + import types + from collections.abc import Callable, Iterable + + import typing + from typing import Any, Generator, Optional, Protocol, Union + + class RootLogger(Protocol): + def __call__(self, src: str, msg: str, c: 
Union[int, str] = 0) -> None: + return None + + class NamedLogger(Protocol): + def __call__(self, msg: str, c: Union[int, str] = 0) -> None: + return None + +except: + pass + +if TYPE_CHECKING: + from .authsrv import VFS + + FAKE_MP = False try: - if FAKE_MP: - import multiprocessing.dummy as mp # noqa: F401 # pylint: disable=unused-import + if not FAKE_MP: + import multiprocessing as mp else: - import multiprocessing as mp # noqa: F401 # pylint: disable=unused-import + import multiprocessing.dummy as mp # type: ignore except ImportError: # support jython - mp = None + mp = None # type: ignore if not PY2: - from urllib.parse import unquote_to_bytes as unquote + from io import BytesIO from urllib.parse import quote_from_bytes as quote - from queue import Queue # pylint: disable=unused-import - from io import BytesIO # pylint: disable=unused-import + from urllib.parse import unquote_to_bytes as unquote else: - from urllib import unquote # pylint: disable=no-name-in-module + from StringIO import StringIO as BytesIO from urllib import quote # pylint: disable=no-name-in-module - from Queue import Queue # pylint: disable=import-error,no-name-in-module - from StringIO import StringIO as BytesIO # pylint: disable=unused-import + from urllib import unquote # pylint: disable=no-name-in-module +_: Any = (mp, BytesIO, quote, unquote) +__all__ = ["mp", "BytesIO", "quote", "unquote"] try: struct.unpack(b">i", b"idgi") @@ -53,20 +81,21 @@ try: sunpack = struct.unpack except: - def spack(f, *a, **ka): - return struct.pack(f.decode("ascii"), *a, **ka) + def spack(fmt: bytes, *a: Any) -> bytes: + return struct.pack(fmt.decode("ascii"), *a) - def sunpack(f, *a, **ka): - return struct.unpack(f.decode("ascii"), *a, **ka) + def sunpack(fmt: bytes, a: bytes) -> tuple[Any, ...]: + return struct.unpack(fmt.decode("ascii"), a) ansi_re = re.compile("\033\\[[^mK]*[mK]") surrogateescape.register_surrogateescape() -FS_ENCODING = sys.getfilesystemencoding() if WINDOWS and PY2: FS_ENCODING = 
"utf-8" +else: + FS_ENCODING = sys.getfilesystemencoding() SYMTIME = sys.version_info >= (3, 6) and os.utime in os.supports_follow_symlinks @@ -116,7 +145,7 @@ MIMES = { } -def _add_mimes(): +def _add_mimes() -> None: for ln in """text css html csv application json wasm xml pdf rtf zip image webp jpeg png gif bmp @@ -170,18 +199,18 @@ REKOBO_LKEY = {k.lower(): v for k, v in REKOBO_KEY.items()} class Cooldown(object): - def __init__(self, maxage): + def __init__(self, maxage: float) -> None: self.maxage = maxage self.mutex = threading.Lock() - self.hist = {} - self.oldest = 0 + self.hist: dict[str, float] = {} + self.oldest = 0.0 - def poke(self, key): + def poke(self, key: str) -> bool: with self.mutex: now = time.time() ret = False - pv = self.hist.get(key, 0) + pv: float = self.hist.get(key, 0) if now - pv > self.maxage: self.hist[key] = now ret = True @@ -204,12 +233,12 @@ class _Unrecv(object): undo any number of socket recv ops """ - def __init__(self, s, log): - self.s = s # type: socket.socket + def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None: + self.s = s self.log = log - self.buf = b"" + self.buf: bytes = b"" - def recv(self, nbytes): + def recv(self, nbytes: int) -> bytes: if self.buf: ret = self.buf[:nbytes] self.buf = self.buf[nbytes:] @@ -221,25 +250,25 @@ class _Unrecv(object): return ret - def recv_ex(self, nbytes, raise_on_trunc=True): + def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes: """read an exact number of bytes""" ret = b"" try: while nbytes > len(ret): ret += self.recv(nbytes - len(ret)) except OSError: - m = "client only sent {} of {} expected bytes".format(len(ret), nbytes) + t = "client only sent {} of {} expected bytes".format(len(ret), nbytes) if len(ret) <= 16: - m += "; got {!r}".format(ret) + t += "; got {!r}".format(ret) if raise_on_trunc: - raise UnrecvEOF(5, m) + raise UnrecvEOF(5, t) elif self.log: - self.log(m, 3) + self.log(t, 3) return ret - def unrecv(self, buf): + def 
unrecv(self, buf: bytes) -> None: self.buf = buf + self.buf @@ -248,28 +277,28 @@ class _LUnrecv(object): with expensive debug logging """ - def __init__(self, s, log): + def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None: self.s = s self.log = log self.buf = b"" - def recv(self, nbytes): + def recv(self, nbytes: int) -> bytes: if self.buf: ret = self.buf[:nbytes] self.buf = self.buf[nbytes:] - m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" - self.log(m.format(ret, self.buf)) + t = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" + print(t.format(ret, self.buf)) return ret ret = self.s.recv(nbytes) - m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m" - self.log(m.format(ret)) + t = "\033[0;7mur:recv\033[0;1;33m {}\033[0m" + print(t.format(ret)) if not ret: raise UnrecvEOF("client stopped sending data") return ret - def recv_ex(self, nbytes, raise_on_trunc=True): + def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes: """read an exact number of bytes""" try: ret = self.recv(nbytes) @@ -285,18 +314,18 @@ class _LUnrecv(object): err = True if err: - m = "client only sent {} of {} expected bytes".format(len(ret), nbytes) + t = "client only sent {} of {} expected bytes".format(len(ret), nbytes) if raise_on_trunc: - raise UnrecvEOF(m) + raise UnrecvEOF(t) elif self.log: - self.log(m, 3) + self.log(t, 3) return ret - def unrecv(self, buf): + def unrecv(self, buf: bytes) -> None: self.buf = buf + self.buf - m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" - self.log(m.format(buf, self.buf)) + t = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m" + print(t.format(buf, self.buf)) Unrecv = _Unrecv @@ -304,14 +333,14 @@ Unrecv = _Unrecv class FHC(object): class CE(object): - def __init__(self, fh): - self.ts = 0 + def __init__(self, fh: typing.BinaryIO) -> None: + self.ts: float = 0 self.fhs = [fh] - def __init__(self): - self.cache = {} 
+ def __init__(self) -> None: + self.cache: dict[str, FHC.CE] = {} - def close(self, path): + def close(self, path: str) -> None: try: ce = self.cache[path] except: @@ -322,7 +351,7 @@ class FHC(object): del self.cache[path] - def clean(self): + def clean(self) -> None: if not self.cache: return @@ -337,10 +366,10 @@ class FHC(object): self.cache = keep - def pop(self, path): + def pop(self, path: str) -> typing.BinaryIO: return self.cache[path].fhs.pop() - def put(self, path, fh): + def put(self, path: str, fh: typing.BinaryIO) -> None: try: ce = self.cache[path] ce.fhs.append(fh) @@ -356,14 +385,15 @@ class ProgressPrinter(threading.Thread): periodically print progress info without linefeeds """ - def __init__(self): + def __init__(self) -> None: threading.Thread.__init__(self, name="pp") self.daemon = True - self.msg = None + self.msg = "" self.end = False + self.n = -1 self.start() - def run(self): + def run(self) -> None: msg = None fmt = " {}\033[K\r" if VT100 else " {} $\r" while not self.end: @@ -384,7 +414,7 @@ class ProgressPrinter(threading.Thread): sys.stdout.flush() # necessary on win10 even w/ stderr btw -def uprint(msg): +def uprint(msg: str) -> None: try: print(msg, end="") except UnicodeEncodeError: @@ -394,17 +424,17 @@ def uprint(msg): print(msg.encode("ascii", "replace").decode(), end="") -def nuprint(msg): +def nuprint(msg: str) -> None: uprint("{}\n".format(msg)) -def rice_tid(): +def rice_tid() -> str: tid = threading.current_thread().ident c = sunpack(b"B" * 5, spack(b">Q", tid)[-5:]) return "".join("\033[1;37;48;5;{0}m{0:02x}".format(x) for x in c) + "\033[0m" -def trace(*args, **kwargs): +def trace(*args: Any, **kwargs: Any) -> None: t = time.time() stack = "".join( "\033[36m{}\033[33m{}".format(x[0].split(os.sep)[-1][:-3], x[1]) @@ -423,15 +453,15 @@ def trace(*args, **kwargs): nuprint(msg) -def alltrace(): - threads = {} +def alltrace() -> str: + threads: dict[str, types.FrameType] = {} names = dict([(t.ident, t.name) for t in 
threading.enumerate()]) for tid, stack in sys._current_frames().items(): name = "{} ({:x})".format(names.get(tid), tid) threads[name] = stack - rret = [] - bret = [] + rret: list[str] = [] + bret: list[str] = [] for name, stack in sorted(threads.items()): ret = ["\n\n# {}".format(name)] pad = None @@ -451,20 +481,20 @@ def alltrace(): return "\n".join(rret + bret) -def start_stackmon(arg_str, nid): +def start_stackmon(arg_str: str, nid: int) -> None: suffix = "-{}".format(nid) if nid else "" fp, f = arg_str.rsplit(",", 1) - f = int(f) + zi = int(f) t = threading.Thread( target=stackmon, - args=(fp, f, suffix), + args=(fp, zi, suffix), name="stackmon" + suffix, ) t.daemon = True t.start() -def stackmon(fp, ival, suffix): +def stackmon(fp: str, ival: float, suffix: str) -> None: ctr = 0 while True: ctr += 1 @@ -474,7 +504,9 @@ def stackmon(fp, ival, suffix): f.write(st.encode("utf-8", "replace")) -def start_log_thrs(logger, ival, nid): +def start_log_thrs( + logger: Callable[[str, str, int], None], ival: float, nid: int +) -> None: ival = float(ival) tname = lname = "log-thrs" if nid: @@ -490,7 +522,7 @@ def start_log_thrs(logger, ival, nid): t.start() -def log_thrs(log, ival, name): +def log_thrs(log: Callable[[str, str, int], None], ival: float, name: str) -> None: while True: time.sleep(ival) tv = [x.name for x in threading.enumerate()] @@ -507,7 +539,7 @@ def log_thrs(log, ival, name): log(name, "\033[0m \033[33m".join(tv), 3) -def vol_san(vols, txt): +def vol_san(vols: list["VFS"], txt: bytes) -> bytes: for vol in vols: txt = txt.replace(vol.realpath.encode("utf-8"), vol.vpath.encode("utf-8")) txt = txt.replace( @@ -518,24 +550,26 @@ def vol_san(vols, txt): return txt -def min_ex(max_lines=8, reverse=False): +def min_ex(max_lines: int = 8, reverse: bool = False) -> str: et, ev, tb = sys.exc_info() - tb = traceback.extract_tb(tb) + stb = traceback.extract_tb(tb) fmt = "{} @ {} <{}>: {}" - ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt 
in tb] - ex.append("[{}] {}".format(et.__name__, ev)) + ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in stb] + ex.append("[{}] {}".format(et.__name__ if et else "(anonymous)", ev)) return "\n".join(ex[-max_lines:][:: -1 if reverse else 1]) @contextlib.contextmanager -def ren_open(fname, *args, **kwargs): +def ren_open( + fname: str, *args: Any, **kwargs: Any +) -> Generator[dict[str, tuple[typing.IO[Any], str]], None, None]: fun = kwargs.pop("fun", open) fdir = kwargs.pop("fdir", None) suffix = kwargs.pop("suffix", None) if fname == os.devnull: with fun(fname, *args, **kwargs) as f: - yield {"orz": [f, fname]} + yield {"orz": (f, fname)} return if suffix: @@ -575,7 +609,7 @@ def ren_open(fname, *args, **kwargs): with open(fsenc(fp2), "wb") as f2: f2.write(orig_name.encode("utf-8")) - yield {"orz": [f, fname]} + yield {"orz": (f, fname)} return except OSError as ex_: @@ -584,9 +618,9 @@ def ren_open(fname, *args, **kwargs): raise if not b64: - b64 = (bname + ext).encode("utf-8", "replace") - b64 = hashlib.sha512(b64).digest()[:12] - b64 = base64.urlsafe_b64encode(b64).decode("utf-8") + zs = (bname + ext).encode("utf-8", "replace") + zs = hashlib.sha512(zs).digest()[:12] + b64 = base64.urlsafe_b64encode(zs).decode("utf-8") badlen = len(fname) while len(fname) >= badlen: @@ -608,8 +642,8 @@ def ren_open(fname, *args, **kwargs): class MultipartParser(object): - def __init__(self, log_func, sr, http_headers): - self.sr = sr # type: Unrecv + def __init__(self, log_func: NamedLogger, sr: Unrecv, http_headers: dict[str, str]): + self.sr = sr self.log = log_func self.headers = http_headers @@ -622,10 +656,14 @@ class MultipartParser(object): r'^content-disposition:(?: *|.*; *)filename="(.*)"', re.IGNORECASE ) - self.boundary = None - self.gen = None + self.boundary = b"" + self.gen: Optional[ + Generator[ + tuple[str, Optional[str], Generator[bytes, None, None]], None, None + ] + ] = None - def _read_header(self): + def _read_header(self) -> 
tuple[str, Optional[str]]: """ returns [fieldname, filename] after eating a block of multipart headers while doing a decent job at dealing with the absolute mess that is @@ -641,7 +679,8 @@ class MultipartParser(object): # rfc-7578 overrides rfc-2388 so this is not-impl # (opera >=9 <11.10 is the only thing i've ever seen use it) raise Pebkac( - "you can't use that browser to upload multiple files at once" + 400, + "you can't use that browser to upload multiple files at once", ) continue @@ -655,12 +694,12 @@ class MultipartParser(object): raise Pebkac(400, "not form-data: {}".format(ln)) try: - field = self.re_cdisp_field.match(ln).group(1) + field = self.re_cdisp_field.match(ln).group(1) # type: ignore except: raise Pebkac(400, "missing field name: {}".format(ln)) try: - fn = self.re_cdisp_file.match(ln).group(1) + fn = self.re_cdisp_file.match(ln).group(1) # type: ignore except: # this is not a file upload, we're done return field, None @@ -687,11 +726,10 @@ class MultipartParser(object): esc = False for ch in fn: if esc: - if ch in ['"', "\\"]: - ret += '"' - else: - ret += esc + ch esc = False + if ch not in ['"', "\\"]: + ret += "\\" + ret += ch elif ch == "\\": esc = True elif ch == '"': @@ -699,9 +737,11 @@ class MultipartParser(object): else: ret += ch - return [field, ret] + return field, ret - def _read_data(self): + raise Pebkac(400, "server expected a multipart header but you never sent one") + + def _read_data(self) -> Generator[bytes, None, None]: blen = len(self.boundary) bufsz = 32 * 1024 while True: @@ -748,7 +788,9 @@ class MultipartParser(object): yield buf - def _run_gen(self): + def _run_gen( + self, + ) -> Generator[tuple[str, Optional[str], Generator[bytes, None, None]], None, None]: """ yields [fieldname, unsanitized_filename, fieldvalue] where fieldvalue yields chunks of data @@ -756,7 +798,7 @@ class MultipartParser(object): run = True while run: fieldname, filename = self._read_header() - yield [fieldname, filename, self._read_data()] + 
yield (fieldname, filename, self._read_data()) tail = self.sr.recv_ex(2, False) @@ -766,19 +808,19 @@ class MultipartParser(object): run = False if tail != b"\r\n": - m = "protocol error after field value: want b'\\r\\n', got {!r}" - raise Pebkac(400, m.format(tail)) + t = "protocol error after field value: want b'\\r\\n', got {!r}" + raise Pebkac(400, t.format(tail)) - def _read_value(self, iterator, max_len): + def _read_value(self, iterable: Iterable[bytes], max_len: int) -> bytes: ret = b"" - for buf in iterator: + for buf in iterable: ret += buf if len(ret) > max_len: raise Pebkac(400, "field length is too long") return ret - def parse(self): + def parse(self) -> None: # spec says there might be junk before the first boundary, # can't have the leading \r\n if that's not the case self.boundary = b"--" + get_boundary(self.headers).encode("utf-8") @@ -793,11 +835,12 @@ class MultipartParser(object): self.boundary = b"\r\n" + self.boundary self.gen = self._run_gen() - def require(self, field_name, max_len): + def require(self, field_name: str, max_len: int) -> str: """ returns the value of the next field in the multipart body, raises if the field name is not as expected """ + assert self.gen p_field, _, p_data = next(self.gen) if p_field != field_name: raise Pebkac( @@ -806,14 +849,15 @@ class MultipartParser(object): return self._read_value(p_data, max_len).decode("utf-8", "surrogateescape") - def drop(self): + def drop(self) -> None: """discards the remaining multipart body""" + assert self.gen for _, _, data in self.gen: for _ in data: pass -def get_boundary(headers): +def get_boundary(headers: dict[str, str]) -> str: # boundaries contain a-z A-Z 0-9 ' ( ) + _ , - . / : = ? 
# (whitespace allowed except as the last char) ptn = r"^multipart/form-data *; *(.*; *)?boundary=([^;]+)" @@ -825,14 +869,14 @@ def get_boundary(headers): return m.group(2) -def read_header(sr): +def read_header(sr: Unrecv) -> list[str]: ret = b"" while True: try: ret += sr.recv(1024) except: if not ret: - return None + return [] raise Pebkac( 400, @@ -853,7 +897,7 @@ def read_header(sr): return ret[:ofs].decode("utf-8", "surrogateescape").lstrip("\r\n").split("\r\n") -def gen_filekey(salt, fspath, fsize, inode): +def gen_filekey(salt: str, fspath: str, fsize: int, inode: int) -> str: return base64.urlsafe_b64encode( hashlib.sha512( "{} {} {} {}".format(salt, fspath, fsize, inode).encode("utf-8", "replace") @@ -861,7 +905,7 @@ def gen_filekey(salt, fspath, fsize, inode): ).decode("ascii") -def gencookie(k, v, dur): +def gencookie(k: str, v: str, dur: Optional[int]) -> str: v = v.replace(";", "") if dur: dt = datetime.utcfromtimestamp(time.time() + dur) @@ -872,7 +916,7 @@ def gencookie(k, v, dur): return "{}={}; Path=/; Expires={}; SameSite=Lax".format(k, v, exp) -def humansize(sz, terse=False): +def humansize(sz: float, terse: bool = False) -> str: for unit in ["B", "KiB", "MiB", "GiB", "TiB"]: if sz < 1024: break @@ -887,18 +931,18 @@ def humansize(sz, terse=False): return ret.replace("iB", "").replace(" ", "") -def unhumanize(sz): +def unhumanize(sz: str) -> int: try: - return float(sz) + return int(sz) except: pass - mul = sz[-1:].lower() - mul = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mul, 1) - return float(sz[:-1]) * mul + mc = sz[-1:].lower() + mi = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mc, 1) + return int(float(sz[:-1]) * mi) -def get_spd(nbyte, t0, t=None): +def get_spd(nbyte: int, t0: float, t: Optional[float] = None) -> str: if t is None: t = time.time() @@ -908,7 +952,7 @@ def get_spd(nbyte, t0, t=None): return "{} \033[0m{}/s\033[0m".format(s1, s2) -def s2hms(s, optional_h=False): +def s2hms(s: float, optional_h: 
bool = False) -> str: s = int(s) h, s = divmod(s, 3600) m, s = divmod(s, 60) @@ -918,7 +962,7 @@ def s2hms(s, optional_h=False): return "{}:{:02}:{:02}".format(h, m, s) -def uncyg(path): +def uncyg(path: str) -> str: if len(path) < 2 or not path.startswith("/"): return path @@ -928,8 +972,8 @@ def uncyg(path): return "{}:\\{}".format(path[1], path[3:]) -def undot(path): - ret = [] +def undot(path: str) -> str: + ret: list[str] = [] for node in path.split("/"): if node in ["", "."]: continue @@ -944,7 +988,7 @@ def undot(path): return "/".join(ret) -def sanitize_fn(fn, ok, bad): +def sanitize_fn(fn: str, ok: str, bad: list[str]) -> str: if "/" not in ok: fn = fn.replace("\\", "/").split("/")[-1] @@ -976,7 +1020,7 @@ def sanitize_fn(fn, ok, bad): return fn.strip() -def relchk(rp): +def relchk(rp: str) -> str: if ANYWIN: if "\n" in rp or "\r" in rp: return "x\nx" @@ -985,8 +1029,10 @@ def relchk(rp): if p != rp: return "[{}]".format(p) + return "" -def absreal(fpath): + +def absreal(fpath: str) -> str: try: return fsdec(os.path.abspath(os.path.realpath(fsenc(fpath)))) except: @@ -999,26 +1045,26 @@ def absreal(fpath): return os.path.abspath(os.path.realpath(fpath)) -def u8safe(txt): +def u8safe(txt: str) -> str: try: return txt.encode("utf-8", "xmlcharrefreplace").decode("utf-8", "replace") except: return txt.encode("utf-8", "replace").decode("utf-8", "replace") -def exclude_dotfiles(filepaths): +def exclude_dotfiles(filepaths: list[str]) -> list[str]: return [x for x in filepaths if not x.split("/")[-1].startswith(".")] -def http_ts(ts): +def http_ts(ts: int) -> str: file_dt = datetime.utcfromtimestamp(ts) return file_dt.strftime(HTTP_TS_FMT) -def html_escape(s, quote=False, crlf=False): +def html_escape(s: str, quot: bool = False, crlf: bool = False) -> str: """html.escape but also newlines""" s = s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") - if quote: + if quot: s = s.replace('"', "&quot;").replace("'", "&#x27;") if crlf: s = s.replace("\r", "&#13;").replace("\n", "&#10;")
@@ -1026,10 +1072,10 @@ def html_escape(s, quote=False, crlf=False): return s -def html_bescape(s, quote=False, crlf=False): +def html_bescape(s: bytes, quot: bool = False, crlf: bool = False) -> bytes: """html.escape but bytestrings""" s = s.replace(b"&", b"&amp;").replace(b"<", b"&lt;").replace(b">", b"&gt;") - if quote: + if quot: s = s.replace(b'"', b"&quot;").replace(b"'", b"&#x27;") if crlf: s = s.replace(b"\r", b"&#13;").replace(b"\n", b"&#10;") @@ -1037,18 +1083,20 @@ def html_bescape(s, quote=False, crlf=False): return s -def quotep(txt): +def quotep(txt: str) -> str: """url quoter which deals with bytes correctly""" btxt = w8enc(txt) quot1 = quote(btxt, safe=b"/") if not PY2: - quot1 = quot1.encode("ascii") + quot2 = quot1.encode("ascii") + else: + quot2 = quot1 - quot2 = quot1.replace(b" ", b"+") - return w8dec(quot2) + quot3 = quot2.replace(b" ", b"+") + return w8dec(quot3) -def unquotep(txt): +def unquotep(txt: str) -> str: """url unquoter which deals with bytes correctly""" btxt = w8enc(txt) # btxt = btxt.replace(b"+", b" ") @@ -1056,14 +1104,14 @@ return w8dec(unq2) -def vsplit(vpath): +def vsplit(vpath: str) -> tuple[str, str]: if "/" not in vpath: return "", vpath - return vpath.rsplit("/", 1) + return vpath.rsplit("/", 1) # type: ignore -def w8dec(txt): +def w8dec(txt: bytes) -> str: """decodes filesystem-bytes to wtf8""" if PY2: return surrogateescape.decodefilename(txt) @@ -1071,7 +1119,7 @@ def w8dec(txt): return txt.decode(FS_ENCODING, "surrogateescape") -def w8enc(txt): +def w8enc(txt: str) -> bytes: """encodes wtf8 to filesystem-bytes""" if PY2: return surrogateescape.encodefilename(txt) @@ -1079,12 +1127,12 @@ def w8enc(txt): return txt.encode(FS_ENCODING, "surrogateescape") -def w8b64dec(txt): +def w8b64dec(txt: str) -> str: """decodes base64(filesystem-bytes) to wtf8""" return w8dec(base64.urlsafe_b64decode(txt.encode("ascii"))) -def w8b64enc(txt): +def w8b64enc(txt: str) -> str: """encodes wtf8 to base64(filesystem-bytes)""" return
base64.urlsafe_b64encode(w8enc(txt)).decode("ascii") @@ -1102,8 +1150,8 @@ else: fsdec = w8dec -def s3enc(mem_cur, rd, fn): - ret = [] +def s3enc(mem_cur: "sqlite3.Cursor", rd: str, fn: str) -> tuple[str, str]: + ret: list[str] = [] for v in [rd, fn]: try: mem_cur.execute("select * from a where b = ?", (v,)) @@ -1112,10 +1160,10 @@ def s3enc(mem_cur, rd, fn): ret.append("//" + w8b64enc(v)) # self.log("mojien [{}] {}".format(v, ret[-1][2:])) - return tuple(ret) + return ret[0], ret[1] -def s3dec(rd, fn): +def s3dec(rd: str, fn: str) -> tuple[str, str]: ret = [] for v in [rd, fn]: if v.startswith("//"): @@ -1124,12 +1172,12 @@ def s3dec(rd, fn): else: ret.append(v) - return tuple(ret) + return ret[0], ret[1] -def atomic_move(src, dst): - src = fsenc(src) - dst = fsenc(dst) +def atomic_move(usrc: str, udst: str) -> None: + src = fsenc(usrc) + dst = fsenc(udst) if not PY2: os.replace(src, dst) else: @@ -1139,7 +1187,7 @@ def atomic_move(src, dst): os.rename(src, dst) -def read_socket(sr, total_size): +def read_socket(sr: Unrecv, total_size: int) -> Generator[bytes, None, None]: remains = total_size while remains > 0: bufsz = 32 * 1024 @@ -1149,14 +1197,14 @@ def read_socket(sr, total_size): try: buf = sr.recv(bufsz) except OSError: - m = "client d/c during binary post after {} bytes, {} bytes remaining" - raise Pebkac(400, m.format(total_size - remains, remains)) + t = "client d/c during binary post after {} bytes, {} bytes remaining" + raise Pebkac(400, t.format(total_size - remains, remains)) remains -= len(buf) yield buf -def read_socket_unbounded(sr): +def read_socket_unbounded(sr: Unrecv) -> Generator[bytes, None, None]: try: while True: yield sr.recv(32 * 1024) @@ -1164,7 +1212,9 @@ def read_socket_unbounded(sr): return -def read_socket_chunked(sr, log=None): +def read_socket_chunked( + sr: Unrecv, log: Optional[NamedLogger] = None +) -> Generator[bytes, None, None]: err = "upload aborted: expected chunk length, got [{}] |{}| instead" while True: buf = b"" @@ 
-1191,8 +1241,8 @@ def read_socket_chunked(sr, log=None): if x == b"\r\n": return - m = "protocol error after final chunk: want b'\\r\\n', got {!r}" - raise Pebkac(400, m.format(x)) + t = "protocol error after final chunk: want b'\\r\\n', got {!r}" + raise Pebkac(400, t.format(x)) if log: log("receiving {} byte chunk".format(chunklen)) @@ -1202,11 +1252,11 @@ def read_socket_chunked(sr, log=None): x = sr.recv_ex(2, False) if x != b"\r\n": - m = "protocol error in chunk separator: want b'\\r\\n', got {!r}" - raise Pebkac(400, m.format(x)) + t = "protocol error in chunk separator: want b'\\r\\n', got {!r}" + raise Pebkac(400, t.format(x)) -def yieldfile(fn): +def yieldfile(fn: str) -> Generator[bytes, None, None]: with open(fsenc(fn), "rb", 512 * 1024) as f: while True: buf = f.read(64 * 1024) @@ -1216,7 +1266,11 @@ def yieldfile(fn): yield buf -def hashcopy(fin, fout, slp=0): +def hashcopy( + fin: Union[typing.BinaryIO, Generator[bytes, None, None]], + fout: Union[typing.BinaryIO, typing.IO[Any]], + slp: int = 0, +) -> tuple[int, str, str]: hashobj = hashlib.sha512() tlen = 0 for buf in fin: @@ -1232,7 +1286,15 @@ def hashcopy(fin, fout, slp=0): return tlen, hashobj.hexdigest(), digest_b64 -def sendfile_py(log, lower, upper, f, s, bufsz, slp): +def sendfile_py( + log: NamedLogger, + lower: int, + upper: int, + f: typing.BinaryIO, + s: socket.socket, + bufsz: int, + slp: int, +) -> int: remains = upper - lower f.seek(lower) while remains > 0: @@ -1252,25 +1314,37 @@ def sendfile_py(log, lower, upper, f, s, bufsz, slp): return 0 -def sendfile_kern(log, lower, upper, f, s, bufsz, slp): +def sendfile_kern( + log: NamedLogger, + lower: int, + upper: int, + f: typing.BinaryIO, + s: socket.socket, + bufsz: int, + slp: int, +) -> int: out_fd = s.fileno() in_fd = f.fileno() ofs = lower - stuck = None + stuck = 0.0 while ofs < upper: stuck = stuck or time.time() try: req = min(2 ** 30, upper - ofs) select.select([], [out_fd], [], 10) n = os.sendfile(out_fd, in_fd, ofs, req) - 
stuck = None - except Exception as ex: + stuck = 0 + except OSError as ex: d = time.time() - stuck log("sendfile stuck for {:.3f} sec: {!r}".format(d, ex)) if d < 3600 and ex.errno == 11: # eagain continue n = 0 + except Exception as ex: + n = 0 + d = time.time() - stuck + log("sendfile failed after {:.3f} sec: {!r}".format(d, ex)) if n <= 0: return upper - ofs @@ -1281,7 +1355,9 @@ def sendfile_kern(log, lower, upper, f, s, bufsz, slp): return 0 -def statdir(logger, scandir, lstat, top): +def statdir( + logger: Optional[RootLogger], scandir: bool, lstat: bool, top: str +) -> Generator[tuple[str, os.stat_result], None, None]: if lstat and ANYWIN: lstat = False @@ -1295,30 +1371,42 @@ def statdir(logger, scandir, lstat, top): with os.scandir(btop) as dh: for fh in dh: try: - yield [fsdec(fh.name), fh.stat(follow_symlinks=not lstat)] + yield (fsdec(fh.name), fh.stat(follow_symlinks=not lstat)) except Exception as ex: + if not logger: + continue + logger(src, "[s] {} @ {}".format(repr(ex), fsdec(fh.path)), 6) else: src = "listdir" - fun = os.lstat if lstat else os.stat + fun: Any = os.lstat if lstat else os.stat for name in os.listdir(btop): abspath = os.path.join(btop, name) try: - yield [fsdec(name), fun(abspath)] + yield (fsdec(name), fun(abspath)) except Exception as ex: + if not logger: + continue + logger(src, "[s] {} @ {}".format(repr(ex), fsdec(abspath)), 6) except Exception as ex: - logger(src, "{} @ {}".format(repr(ex), top), 1) + t = "{} @ {}".format(repr(ex), top) + if logger: + logger(src, t, 1) + else: + print(t) -def rmdirs(logger, scandir, lstat, top, depth): +def rmdirs( + logger: RootLogger, scandir: bool, lstat: bool, top: str, depth: int +) -> tuple[list[str], list[str]]: if not os.path.exists(fsenc(top)) or not os.path.isdir(fsenc(top)): top = os.path.dirname(top) depth -= 1 - dirs = statdir(logger, scandir, lstat, top) - dirs = [x[0] for x in dirs if stat.S_ISDIR(x[1].st_mode)] + stats = statdir(logger, scandir, lstat, top) + dirs = [x[0] for x 
in stats if stat.S_ISDIR(x[1].st_mode)] dirs = [os.path.join(top, x) for x in dirs] ok = [] ng = [] @@ -1337,7 +1425,7 @@ def rmdirs(logger, scandir, lstat, top, depth): return ok, ng -def unescape_cookie(orig): +def unescape_cookie(orig: str) -> str: # mw=idk; doot=qwe%2Crty%3Basd+fgh%2Bjkl%25zxc%26vbn # qwe,rty;asd fgh+jkl%zxc&vbn ret = "" esc = "" @@ -1365,7 +1453,7 @@ def unescape_cookie(orig): return ret -def guess_mime(url, fallback="application/octet-stream"): +def guess_mime(url: str, fallback: str = "application/octet-stream") -> str: try: _, ext = url.rsplit(".", 1) except: @@ -1387,7 +1475,9 @@ def guess_mime(url, fallback="application/octet-stream"): return ret -def runcmd(argv, timeout=None, **ka): +def runcmd( + argv: Union[list[bytes], list[str]], timeout: Optional[int] = None, **ka: Any +) -> tuple[int, str, str]: p = sp.Popen(argv, stdout=sp.PIPE, stderr=sp.PIPE, **ka) if not timeout or PY2: stdout, stderr = p.communicate() @@ -1400,10 +1490,10 @@ def runcmd(argv, timeout=None, **ka): stdout = stdout.decode("utf-8", "replace") stderr = stderr.decode("utf-8", "replace") - return [p.returncode, stdout, stderr] + return p.returncode, stdout, stderr -def chkcmd(argv, **ka): +def chkcmd(argv: Union[list[bytes], list[str]], **ka: Any) -> tuple[str, str]: ok, sout, serr = runcmd(argv, **ka) if ok != 0: retchk(ok, argv, serr) @@ -1412,7 +1502,7 @@ def chkcmd(argv, **ka): return sout, serr -def mchkcmd(argv, timeout=10): +def mchkcmd(argv: Union[list[bytes], list[str]], timeout: int = 10) -> None: if PY2: with open(os.devnull, "wb") as f: rv = sp.call(argv, stdout=f, stderr=f) @@ -1423,7 +1513,14 @@ def mchkcmd(argv, timeout=10): raise sp.CalledProcessError(rv, (argv[0], b"...", argv[-1])) -def retchk(rc, cmd, serr, logger=None, color=None, verbose=False): +def retchk( + rc: int, + cmd: Union[list[bytes], list[str]], + serr: str, + logger: Optional[NamedLogger] = None, + color: Union[int, str] = 0, + verbose: bool = False, +) -> None: if rc < 0: rc = 128 - 
rc @@ -1446,33 +1543,33 @@ def retchk(rc, cmd, serr, logger=None, color=None, verbose=False): s = "invalid retcode" if s: - m = "{} <{}>".format(rc, s) + t = "{} <{}>".format(rc, s) else: - m = str(rc) + t = str(rc) try: - c = " ".join([fsdec(x) for x in cmd]) + c = " ".join([fsdec(x) for x in cmd]) # type: ignore except: c = str(cmd) - m = "error {} from [{}]".format(m, c) + t = "error {} from [{}]".format(t, c) if serr: - m += "\n" + serr + t += "\n" + serr if logger: - logger(m, color) + logger(t, color) else: - raise Exception(m) + raise Exception(t) -def gzip_orig_sz(fn): +def gzip_orig_sz(fn: str) -> int: with open(fsenc(fn), "rb") as f: f.seek(-4, 2) rv = f.read(4) - return sunpack(b"I", rv)[0] + return sunpack(b"I", rv)[0] # type: ignore -def py_desc(): +def py_desc() -> str: interp = platform.python_implementation() py_ver = ".".join([str(x) for x in sys.version_info]) ofs = py_ver.find(".final.") @@ -1487,15 +1584,15 @@ def py_desc(): host_os = platform.system() compiler = platform.python_compiler() - os_ver = re.search(r"([0-9]+\.[0-9\.]+)", platform.version()) - os_ver = os_ver.group(1) if os_ver else "" + m = re.search(r"([0-9]+\.[0-9\.]+)", platform.version()) + os_ver = m.group(1) if m else "" return "{:>9} v{} on {}{} {} [{}]".format( interp, py_ver, host_os, bitness, os_ver, compiler ) -def align_tab(lines): +def align_tab(lines: list[str]) -> list[str]: rows = [] ncols = 0 for ln in lines: @@ -1512,9 +1609,9 @@ def align_tab(lines): class Pebkac(Exception): - def __init__(self, code, msg=None): + def __init__(self, code: int, msg: Optional[str] = None) -> None: super(Pebkac, self).__init__(msg or HTTPCODE[code]) self.code = code - def __repr__(self): + def __repr__(self) -> str: return "Pebkac({}, {})".format(self.code, repr(self.args)) diff --git a/scripts/make-pypi-release.sh b/scripts/make-pypi-release.sh index 6a568051..7e723d4b 100755 --- a/scripts/make-pypi-release.sh +++ b/scripts/make-pypi-release.sh @@ -90,6 +90,15 @@ function have() { 
have setuptools have wheel have twine + +# remove type hints to support python < 3.9 +rm -rf build/pypi +mkdir -p build/pypi +cp -pR setup.py README.md LICENSE copyparty tests bin scripts/strip_hints build/pypi/ +cd build/pypi +tar --strip-components=2 -xf ../strip-hints-0.1.10.tar.gz strip-hints-0.1.10/src/strip_hints +python3 -c 'from strip_hints.a import uh; uh("copyparty")' + ./setup.py clean2 ./setup.py sdist bdist_wheel --universal diff --git a/scripts/make-sfx.sh b/scripts/make-sfx.sh index 8f9856f3..7402c517 100755 --- a/scripts/make-sfx.sh +++ b/scripts/make-sfx.sh @@ -76,7 +76,7 @@ while [ ! -z "$1" ]; do no-hl) no_hl=1 ; ;; no-dd) no_dd=1 ; ;; no-cm) no_cm=1 ; ;; - fast) zopf=100 ; ;; + fast) zopf= ; ;; lang) shift;langs="$1"; ;; *) help ; ;; esac @@ -106,7 +106,7 @@ tmpdir="$( [ $repack ] && { old="$tmpdir/pe-copyparty" echo "repack of files in $old" - cp -pR "$old/"*{dep-j2,dep-ftp,copyparty} . + cp -pR "$old/"*{j2,ftp,copyparty} . } [ $repack ] || { @@ -130,8 +130,8 @@ tmpdir="$( mv MarkupSafe-*/src/markupsafe . rm -rf MarkupSafe-* markupsafe/_speedups.c - mkdir dep-j2/ - mv {markupsafe,jinja2} dep-j2/ + mkdir j2/ + mv {markupsafe,jinja2} j2/ echo collecting pyftpdlib f="../build/pyftpdlib-1.5.6.tar.gz" @@ -143,8 +143,8 @@ tmpdir="$( mv pyftpdlib-release-*/pyftpdlib . 
rm -rf pyftpdlib-release-* pyftpdlib/test - mkdir dep-ftp/ - mv pyftpdlib dep-ftp/ + mkdir ftp/ + mv pyftpdlib ftp/ echo collecting asyncore, asynchat for n in asyncore.py asynchat.py; do @@ -154,6 +154,24 @@ tmpdir="$( wget -O$f "$url" || curl -L "$url" >$f) done + # enable this to dynamically remove type hints at startup, + # in case a future python version can use them for performance + true || ( + echo collecting strip-hints + f=../build/strip-hints-0.1.10.tar.gz + [ -e $f ] || + (url=https://files.pythonhosted.org/packages/9c/d4/312ddce71ee10f7e0ab762afc027e07a918f1c0e1be5b0069db5b0e7542d/strip-hints-0.1.10.tar.gz; + wget -O$f "$url" || curl -L "$url" >$f) + + tar -zxf $f + mv strip-hints-0.1.10/src/strip_hints . + rm -rf strip-hints-* strip_hints/import_hooks* + sed -ri 's/[a-z].* as import_hooks$/"""a"""/' strip_hints/*.py + + cp -pR ../scripts/strip_hints/ . + ) + cp -pR ../scripts/py2/ . + # msys2 tar is bad, make the best of it echo collecting source [ $clean ] && { @@ -170,6 +188,9 @@ tmpdir="$( for n in asyncore.py asynchat.py; do awk 'NR<4||NR>27;NR==4{print"# license: https://opensource.org/licenses/ISC\n"}' ../build/$n >copyparty/vend/$n done + + # remove type hints before build instead + (cd copyparty; python3 ../../scripts/strip_hints/a.py; rm uh) } ver= @@ -274,17 +295,23 @@ rm have tmv "$f" done -[ $repack ] || -find | grep -E '\.py$' | - grep -vE '__version__' | - tr '\n' '\0' | - xargs -0 "$pybin" ../scripts/uncomment.py +[ $repack ] || { + # uncomment + find | grep -E '\.py$' | + grep -vE '__version__' | + tr '\n' '\0' | + xargs -0 "$pybin" ../scripts/uncomment.py -f=dep-j2/jinja2/constants.py + # py2-compat + #find | grep -E '\.py$' | while IFS= read -r x; do + # sed -ri '/: TypeAlias = /d' "$x"; done +} + +f=j2/jinja2/constants.py awk '/^LOREM_IPSUM_WORDS/{o=1;print "LOREM_IPSUM_WORDS = u\"a\"";next} !o; /"""/{o=0}' <$f >t tmv "$f" -grep -rLE '^#[^a-z]*coding: utf-8' dep-j2 | +grep -rLE '^#[^a-z]*coding: utf-8' j2 | while IFS= read -r f; do 
(echo "# coding: utf-8"; cat "$f") >t tmv "$f" @@ -313,7 +340,7 @@ find | grep -E '\.(js|html)$' | while IFS= read -r f; do done gzres() { - command -v pigz && + command -v pigz && [ $zopf ] && pk="pigz -11 -I $zopf" || pk='gzip' @@ -354,7 +381,8 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l) } [ $use_zdir ] && { arcs=("$zdir"/arc.*) - arc="${arcs[$RANDOM % ${#arcs[@]} ] }" + n=$(( $RANDOM % ${#arcs[@]} )) + arc="${arcs[n]}" echo "using $arc" tar -xf "$arc" for f in copyparty/web/*.gz; do @@ -364,7 +392,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l) echo gen tarlist -for d in copyparty dep-j2 dep-ftp; do find $d -type f; done | +for d in copyparty j2 ftp py2; do find $d -type f; done | # strip_hints sed -r 's/(.*)\.(.*)/\2 \1/' | LC_ALL=C sort | sed -r 's/([^ ]*) (.*)/\2.\1/' | grep -vE '/list1?$' > list1 diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh index 48aaf075..1977a55a 100755 --- a/scripts/run-tests.sh +++ b/scripts/run-tests.sh @@ -1,13 +1,23 @@ #!/bin/bash set -ex +rm -rf unt +mkdir -p unt/srv +cp -pR copyparty tests unt/ +cd unt +python3 ../scripts/strip_hints/a.py + pids=() for py in python{2,3}; do + PYTHONPATH= + [ $py = python2 ] && PYTHONPATH=../scripts/py2 + export PYTHONPATH + nice $py -m unittest discover -s tests >/dev/null & pids+=($!) done -python3 scripts/test/smoketest.py & +python3 ../scripts/test/smoketest.py & pids+=($!) 
for pid in ${pids[@]}; do diff --git a/scripts/sfx.py b/scripts/sfx.py index 4723a0da..2cb20c8a 100644 --- a/scripts/sfx.py +++ b/scripts/sfx.py @@ -379,9 +379,20 @@ def run(tmp, j2, ftp): t.daemon = True t.start() - ld = (("", ""), (j2, "dep-j2"), (ftp, "dep-ftp")) + ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2")) ld = [os.path.join(tmp, b) for a, b in ld if not a] + # skip 1 + # enable this to dynamically remove type hints at startup, + # in case a future python version can use them for performance + if sys.version_info < (3, 10) and False: + sys.path.insert(0, ld[0]) + + from strip_hints.a import uh + + uh(tmp + "/copyparty") + # skip 0 + if any([re.match(r"^-.*j[0-9]", x) for x in sys.argv]): run_s(ld) else: diff --git a/scripts/sfx.sh b/scripts/sfx.sh index 1f53c8db..1496f970 100644 --- a/scripts/sfx.sh +++ b/scripts/sfx.sh @@ -47,7 +47,7 @@ grep -E '/(python|pypy)[0-9\.-]*$' >$dir/pys || true printf '\033[1;30mlooking for jinja2 in [%s]\033[0m\n' "$_py" >&2 $_py -c 'import jinja2' 2>/dev/null || continue printf '%s\n' "$_py" - mv $dir/{,x.}dep-j2 + mv $dir/{,x.}j2 break done)" diff --git a/scripts/strip_hints/a.py b/scripts/strip_hints/a.py new file mode 100644 index 00000000..06bf4f6b --- /dev/null +++ b/scripts/strip_hints/a.py @@ -0,0 +1,57 @@ +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import re +import os +import sys +from strip_hints import strip_file_to_string + + +# list unique types used in hints: +# rm -rf unt && cp -pR copyparty unt && (cd unt && python3 ../scripts/strip_hints/a.py) +# diff -wNarU1 copyparty unt | grep -E '^\-' | sed -r 's/[^][, ]+://g; s/[^][, ]+[[(]//g; s/[],()<>{} -]/\n/g' | grep -E .. 
| sort | uniq -c | sort -n + + +def pr(m): + sys.stderr.write(m) + sys.stderr.flush() + + +def uh(top): + if os.path.exists(top + "/uh"): + return + + libs = "typing|types|collections\.abc" + ptn = re.compile(r"^(\s*)(from (?:{0}) import |import (?:{0})\b).*".format(libs)) + + # pr("building support for your python ver") + pr("unhinting") + for (dp, _, fns) in os.walk(top): + for fn in fns: + if not fn.endswith(".py"): + continue + + pr(".") + fp = os.path.join(dp, fn) + cs = strip_file_to_string(fp, no_ast=True, to_empty=True) + + # remove expensive imports too + lns = [] + for ln in cs.split("\n"): + m = ptn.match(ln) + if m: + ln = m.group(1) + "raise Exception()" + + lns.append(ln) + + cs = "\n".join(lns) + with open(fp, "wb") as f: + f.write(cs.encode("utf-8")) + + pr("k\n\n") + with open(top + "/uh", "wb") as f: + f.write(b"a") + + +if __name__ == "__main__": + uh(".") diff --git a/scripts/test/race.py b/scripts/test/race.py index 09922e60..77ef0e46 100644 --- a/scripts/test/race.py +++ b/scripts/test/race.py @@ -58,13 +58,13 @@ class CState(threading.Thread): remotes.append("?") remotes_ok = False - m = [] + ta = [] for conn, remote in zip(self.cs, remotes): stage = len(conn.st) - m.append(f"\033[3{colors[stage]}m{remote}") + ta.append(f"\033[3{colors[stage]}m{remote}") - m = " ".join(m) - print(f"{m}\033[0m\n\033[A", end="") + t = " ".join(ta) + print(f"{t}\033[0m\n\033[A", end="") def allget(cs, urls): diff --git a/scripts/test/smoketest.py b/scripts/test/smoketest.py index 8508f127..a37fc44a 100644 --- a/scripts/test/smoketest.py +++ b/scripts/test/smoketest.py @@ -72,6 +72,8 @@ def tc1(vflags): for _ in range(10): try: os.mkdir(td) + if os.path.exists(td): + break except: time.sleep(0.1) # win10 diff --git a/tests/test_vfs.py b/tests/test_vfs.py index cb7f08fe..4818d180 100644 --- a/tests/test_vfs.py +++ b/tests/test_vfs.py @@ -85,7 +85,7 @@ class TestVFS(unittest.TestCase): pass def assertAxs(self, dct, lst): - t1 = list(sorted(dct.keys())) + t1 = 
list(sorted(dct)) t2 = list(sorted(lst)) self.assertEqual(t1, t2) @@ -208,10 +208,10 @@ class TestVFS(unittest.TestCase): self.assertEqual(n.realpath, os.path.join(td, "a")) self.assertAxs(n.axs.uread, ["*"]) self.assertAxs(n.axs.uwrite, []) - self.assertEqual(vfs.can_access("/", "*"), [False, False, False, False, False]) - self.assertEqual(vfs.can_access("/", "k"), [True, True, False, False, False]) - self.assertEqual(vfs.can_access("/a", "*"), [True, False, False, False, False]) - self.assertEqual(vfs.can_access("/a", "k"), [True, False, False, False, False]) + self.assertEqual(vfs.can_access("/", "*"), (False, False, False, False, False)) + self.assertEqual(vfs.can_access("/", "k"), (True, True, False, False, False)) + self.assertEqual(vfs.can_access("/a", "*"), (True, False, False, False, False)) + self.assertEqual(vfs.can_access("/a", "k"), (True, False, False, False, False)) # breadth-first construction vfs = AuthSrv( @@ -279,7 +279,7 @@ class TestVFS(unittest.TestCase): n = au.vfs # root was not defined, so PWD with no access to anyone self.assertEqual(n.vpath, "") - self.assertEqual(n.realpath, None) + self.assertEqual(n.realpath, "") self.assertAxs(n.axs.uread, []) self.assertAxs(n.axs.uwrite, []) self.assertEqual(len(n.nodes), 1) diff --git a/tests/util.py b/tests/util.py index 55ec58ea..92ecaa88 100644 --- a/tests/util.py +++ b/tests/util.py @@ -90,7 +90,10 @@ def get_ramdisk(): class NullBroker(object): - def put(*args): + def say(*args): + pass + + def ask(*args): pass