mirror of
https://github.com/9001/copyparty.git
synced 2025-08-18 01:22:13 -06:00
url-param / header `ck` specifies hashing algo; md5 sha1 sha256 sha512 b2 blake2 b2s blake2s value 'no' or blank disables checksumming, for when copyparty is running on ancient gear and you don't really care about file integrity
3880 lines
101 KiB
Python
3880 lines
101 KiB
Python
# coding: utf-8
|
||
from __future__ import print_function, unicode_literals
|
||
|
||
import argparse
|
||
import base64
|
||
import binascii
|
||
import codecs
|
||
import errno
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import logging
|
||
import math
|
||
import mimetypes
|
||
import os
|
||
import platform
|
||
import re
|
||
import select
|
||
import shutil
|
||
import signal
|
||
import socket
|
||
import stat
|
||
import struct
|
||
import subprocess as sp # nosec
|
||
import sys
|
||
import threading
|
||
import time
|
||
import traceback
|
||
from collections import Counter
|
||
|
||
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
|
||
from queue import Queue
|
||
|
||
from .__init__ import (
|
||
ANYWIN,
|
||
EXE,
|
||
MACOS,
|
||
PY2,
|
||
PY36,
|
||
TYPE_CHECKING,
|
||
VT100,
|
||
WINDOWS,
|
||
EnvParams,
|
||
)
|
||
from .__version__ import S_BUILD_DT, S_VERSION
|
||
from .stolen import surrogateescape
|
||
|
||
try:
|
||
from datetime import datetime, timezone
|
||
|
||
UTC = timezone.utc
|
||
except:
|
||
from datetime import datetime, timedelta, tzinfo
|
||
|
||
TD_ZERO = timedelta(0)
|
||
|
||
class _UTC(tzinfo):
|
||
def utcoffset(self, dt):
|
||
return TD_ZERO
|
||
|
||
def tzname(self, dt):
|
||
return "UTC"
|
||
|
||
def dst(self, dt):
|
||
return TD_ZERO
|
||
|
||
UTC = _UTC()
|
||
|
||
|
||
if PY2:
|
||
range = xrange # type: ignore
|
||
|
||
|
||
if sys.version_info >= (3, 7) or (
|
||
PY36 and platform.python_implementation() == "CPython"
|
||
):
|
||
ODict = dict
|
||
else:
|
||
from collections import OrderedDict as ODict
|
||
|
||
|
||
def _ens(want: str) -> tuple[int, ...]:
|
||
ret: list[int] = []
|
||
for v in want.split():
|
||
try:
|
||
ret.append(getattr(errno, v))
|
||
except:
|
||
pass
|
||
|
||
return tuple(ret)
|
||
|
||
|
||
# WSAECONNRESET - forcibly closed by remote
|
||
# WSAENOTSOCK - no longer a socket
|
||
# EUNATCH - can't assign requested address (wifi down)
|
||
E_SCK = _ens("ENOTCONN EUNATCH EBADF WSAENOTSOCK WSAECONNRESET")
|
||
E_ADDR_NOT_AVAIL = _ens("EADDRNOTAVAIL WSAEADDRNOTAVAIL")
|
||
E_ADDR_IN_USE = _ens("EADDRINUSE WSAEADDRINUSE")
|
||
E_ACCESS = _ens("EACCES WSAEACCES")
|
||
E_UNREACH = _ens("EHOSTUNREACH WSAEHOSTUNREACH ENETUNREACH WSAENETUNREACH")
|
||
|
||
IP6ALL = "0:0:0:0:0:0:0:0"
|
||
|
||
|
||
try:
|
||
import ctypes
|
||
import fcntl
|
||
import termios
|
||
except:
|
||
pass
|
||
|
||
try:
|
||
if os.environ.get("PRTY_NO_SQLITE"):
|
||
raise Exception()
|
||
|
||
HAVE_SQLITE3 = True
|
||
import sqlite3
|
||
|
||
assert hasattr(sqlite3, "connect") # graalpy
|
||
except:
|
||
HAVE_SQLITE3 = False
|
||
|
||
try:
|
||
if os.environ.get("PRTY_NO_PSUTIL"):
|
||
raise Exception()
|
||
|
||
HAVE_PSUTIL = True
|
||
import psutil
|
||
except:
|
||
HAVE_PSUTIL = False
|
||
|
||
if True: # pylint: disable=using-constant-test
|
||
import types
|
||
from collections.abc import Callable, Iterable
|
||
|
||
import typing
|
||
from typing import IO, Any, Generator, Optional, Pattern, Protocol, Union
|
||
|
||
try:
|
||
from typing import LiteralString
|
||
except:
|
||
pass
|
||
|
||
class RootLogger(Protocol):
|
||
def __call__(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
|
||
return None
|
||
|
||
class NamedLogger(Protocol):
|
||
def __call__(self, msg: str, c: Union[int, str] = 0) -> None:
|
||
return None
|
||
|
||
|
||
if TYPE_CHECKING:
|
||
import magic
|
||
|
||
from .authsrv import VFS
|
||
from .broker_util import BrokerCli
|
||
from .up2k import Up2k
|
||
|
||
FAKE_MP = False
|
||
|
||
try:
|
||
if os.environ.get("PRTY_NO_MP"):
|
||
raise ImportError()
|
||
|
||
import multiprocessing as mp
|
||
|
||
# import multiprocessing.dummy as mp
|
||
except ImportError:
|
||
# support jython
|
||
mp = None # type: ignore
|
||
|
||
if not PY2:
|
||
from io import BytesIO
|
||
else:
|
||
from StringIO import StringIO as BytesIO # type: ignore
|
||
|
||
|
||
try:
|
||
if os.environ.get("PRTY_NO_IPV6"):
|
||
raise Exception()
|
||
|
||
socket.inet_pton(socket.AF_INET6, "::1")
|
||
HAVE_IPV6 = True
|
||
except:
|
||
|
||
def inet_pton(fam, ip):
|
||
return socket.inet_aton(ip)
|
||
|
||
socket.inet_pton = inet_pton
|
||
HAVE_IPV6 = False
|
||
|
||
|
||
try:
|
||
struct.unpack(b">i", b"idgi")
|
||
spack = struct.pack # type: ignore
|
||
sunpack = struct.unpack # type: ignore
|
||
except:
|
||
|
||
def spack(fmt: bytes, *a: Any) -> bytes:
|
||
return struct.pack(fmt.decode("ascii"), *a)
|
||
|
||
def sunpack(fmt: bytes, a: bytes) -> tuple[Any, ...]:
|
||
return struct.unpack(fmt.decode("ascii"), a)
|
||
|
||
|
||
try:
|
||
BITNESS = struct.calcsize(b"P") * 8
|
||
except:
|
||
BITNESS = struct.calcsize("P") * 8
|
||
|
||
|
||
ansi_re = re.compile("\033\\[[^mK]*[mK]")
|
||
|
||
|
||
BOS_SEP = ("%s" % (os.sep,)).encode("ascii")
|
||
|
||
|
||
surrogateescape.register_surrogateescape()
|
||
if WINDOWS and PY2:
|
||
FS_ENCODING = "utf-8"
|
||
else:
|
||
FS_ENCODING = sys.getfilesystemencoding()
|
||
|
||
|
||
SYMTIME = PY36 and os.utime in os.supports_follow_symlinks
|
||
|
||
META_NOBOTS = '<meta name="robots" content="noindex, nofollow">\n'
|
||
|
||
FFMPEG_URL = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-git-full.7z"
|
||
|
||
HTTPCODE = {
|
||
200: "OK",
|
||
201: "Created",
|
||
204: "No Content",
|
||
206: "Partial Content",
|
||
207: "Multi-Status",
|
||
301: "Moved Permanently",
|
||
302: "Found",
|
||
304: "Not Modified",
|
||
400: "Bad Request",
|
||
401: "Unauthorized",
|
||
403: "Forbidden",
|
||
404: "Not Found",
|
||
405: "Method Not Allowed",
|
||
409: "Conflict",
|
||
411: "Length Required",
|
||
412: "Precondition Failed",
|
||
413: "Payload Too Large",
|
||
416: "Requested Range Not Satisfiable",
|
||
422: "Unprocessable Entity",
|
||
423: "Locked",
|
||
429: "Too Many Requests",
|
||
500: "Internal Server Error",
|
||
501: "Not Implemented",
|
||
503: "Service Unavailable",
|
||
999: "MissingNo",
|
||
}
|
||
|
||
|
||
IMPLICATIONS = [
|
||
["e2dsa", "e2ds"],
|
||
["e2ds", "e2d"],
|
||
["e2tsr", "e2ts"],
|
||
["e2ts", "e2t"],
|
||
["e2t", "e2d"],
|
||
["e2vu", "e2v"],
|
||
["e2vp", "e2v"],
|
||
["e2v", "e2d"],
|
||
["hardlink_only", "hardlink"],
|
||
["hardlink", "dedup"],
|
||
["tftpvv", "tftpv"],
|
||
["smbw", "smb"],
|
||
["smb1", "smb"],
|
||
["smbvvv", "smbvv"],
|
||
["smbvv", "smbv"],
|
||
["smbv", "smb"],
|
||
["zv", "zmv"],
|
||
["zv", "zsv"],
|
||
["z", "zm"],
|
||
["z", "zs"],
|
||
["zmvv", "zmv"],
|
||
["zm4", "zm"],
|
||
["zm6", "zm"],
|
||
["zmv", "zm"],
|
||
["zms", "zm"],
|
||
["zsv", "zs"],
|
||
]
|
||
if ANYWIN:
|
||
IMPLICATIONS.extend([["z", "zm4"]])
|
||
|
||
|
||
UNPLICATIONS = [["no_dav", "daw"]]
|
||
|
||
|
||
DAV_ALLPROP_L = [
|
||
"contentclass",
|
||
"creationdate",
|
||
"defaultdocument",
|
||
"displayname",
|
||
"getcontentlanguage",
|
||
"getcontentlength",
|
||
"getcontenttype",
|
||
"getlastmodified",
|
||
"href",
|
||
"iscollection",
|
||
"ishidden",
|
||
"isreadonly",
|
||
"isroot",
|
||
"isstructureddocument",
|
||
"lastaccessed",
|
||
"name",
|
||
"parentname",
|
||
"resourcetype",
|
||
"supportedlock",
|
||
]
|
||
DAV_ALLPROPS = set(DAV_ALLPROP_L)
|
||
|
||
|
||
MIMES = {
|
||
"opus": "audio/ogg; codecs=opus",
|
||
}
|
||
|
||
|
||
def _add_mimes() -> None:
    """populate the module-level MIMES table with ext -> mimetype mappings.

    # `mimetypes` is woefully unpopulated on windows
    # but will be used as fallback on linux
    """

    # first table: extensions whose mimetype subtype equals the extension
    # itself; each line is "<toplevel-type> <ext> <ext> ..."
    for ln in """text css html csv
application json wasm xml pdf rtf zip jar fits wasm
image webp jpeg png gif bmp jxl jp2 jxs jxr tiff bpg heic heif avif
audio aac ogg wav flac ape amr
video webm mp4 mpeg
font woff woff2 otf ttf
""".splitlines():
        k, vs = ln.split(" ", 1)
        for v in vs.strip().split():
            MIMES[v] = "{}/{}".format(k, v)

    # second table: extensions with an explicit subtype;
    # each line is "<toplevel-type> <ext>=<subtype> <ext>=<subtype> ..."
    for ln in """text md=plain txt=plain js=javascript
application 7z=x-7z-compressed tar=x-tar bz2=x-bzip2 gz=gzip rar=x-rar-compressed zst=zstd xz=x-xz lz=lzip cpio=x-cpio
application msi=x-ms-installer cab=vnd.ms-cab-compressed rpm=x-rpm crx=x-chrome-extension
application epub=epub+zip mobi=x-mobipocket-ebook lit=x-ms-reader rss=rss+xml atom=atom+xml torrent=x-bittorrent
application p7s=pkcs7-signature dcm=dicom shx=vnd.shx shp=vnd.shp dbf=x-dbf gml=gml+xml gpx=gpx+xml amf=x-amf
application swf=x-shockwave-flash m3u=vnd.apple.mpegurl db3=vnd.sqlite3 sqlite=vnd.sqlite3
text ass=plain ssa=plain
image jpg=jpeg xpm=x-xpixmap psd=vnd.adobe.photoshop jpf=jpx tif=tiff ico=x-icon djvu=vnd.djvu
image heic=heic-sequence heif=heif-sequence hdr=vnd.radiance svg=svg+xml
audio caf=x-caf mp3=mpeg m4a=mp4 mid=midi mpc=musepack aif=aiff au=basic qcp=qcelp
video mkv=x-matroska mov=quicktime avi=x-msvideo m4v=x-m4v ts=mp2t
video asf=x-ms-asf flv=x-flv 3gp=3gpp 3g2=3gpp2 rmvb=vnd.rn-realmedia-vbr
font ttc=collection
""".splitlines():
        k, ems = ln.split(" ", 1)
        for em in ems.strip().split():
            ext, mime = em.split("=")
            MIMES[ext] = "{}/{}".format(k, mime)


_add_mimes()
|
||
|
||
|
||
EXTS: dict[str, str] = {v: k for k, v in MIMES.items()}
|
||
|
||
EXTS["vnd.mozilla.apng"] = "png"
|
||
|
||
MAGIC_MAP = {"jpeg": "jpg"}
|
||
|
||
|
||
DEF_EXP = "self.ip self.ua self.uname self.host cfg.name cfg.logout vf.scan vf.thsize hdr.cf_ipcountry srv.itime srv.htime"
|
||
|
||
DEF_MTE = ".files,circle,album,.tn,artist,title,.bpm,key,.dur,.q,.vq,.aq,vc,ac,fmt,res,.fps,ahash,vhash"
|
||
|
||
DEF_MTH = ".vq,.aq,vc,ac,fmt,res,.fps"
|
||
|
||
|
||
REKOBO_KEY = {
|
||
v: ln.split(" ", 1)[0]
|
||
for ln in """
|
||
1B 6d B
|
||
2B 7d Gb F#
|
||
3B 8d Db C#
|
||
4B 9d Ab G#
|
||
5B 10d Eb D#
|
||
6B 11d Bb A#
|
||
7B 12d F
|
||
8B 1d C
|
||
9B 2d G
|
||
10B 3d D
|
||
11B 4d A
|
||
12B 5d E
|
||
1A 6m Abm G#m
|
||
2A 7m Ebm D#m
|
||
3A 8m Bbm A#m
|
||
4A 9m Fm
|
||
5A 10m Cm
|
||
6A 11m Gm
|
||
7A 12m Dm
|
||
8A 1m Am
|
||
9A 2m Em
|
||
10A 3m Bm
|
||
11A 4m Gbm F#m
|
||
12A 5m Dbm C#m
|
||
""".strip().split(
|
||
"\n"
|
||
)
|
||
for v in ln.strip().split(" ")[1:]
|
||
if v
|
||
}
|
||
|
||
REKOBO_LKEY = {k.lower(): v for k, v in REKOBO_KEY.items()}
|
||
|
||
|
||
_exestr = "python3 python ffmpeg ffprobe cfssl cfssljson cfssl-certinfo"
|
||
CMD_EXEB = set(_exestr.encode("utf-8").split())
|
||
CMD_EXES = set(_exestr.split())
|
||
|
||
|
||
# mostly from https://github.com/github/gitignore/blob/main/Global/macOS.gitignore
|
||
APPLESAN_TXT = r"/(__MACOS|Icon\r\r)|/\.(_|DS_Store|AppleDouble|LSOverride|DocumentRevisions-|fseventsd|Spotlight-|TemporaryItems|Trashes|VolumeIcon\.icns|com\.apple\.timemachine\.donotpresent|AppleDB|AppleDesktop|apdisk)"
|
||
APPLESAN_RE = re.compile(APPLESAN_TXT)
|
||
|
||
|
||
HUMANSIZE_UNITS = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB")
|
||
|
||
UNHUMANIZE_UNITS = {
|
||
"b": 1,
|
||
"k": 1024,
|
||
"m": 1024 * 1024,
|
||
"g": 1024 * 1024 * 1024,
|
||
"t": 1024 * 1024 * 1024 * 1024,
|
||
"p": 1024 * 1024 * 1024 * 1024 * 1024,
|
||
"e": 1024 * 1024 * 1024 * 1024 * 1024 * 1024,
|
||
}
|
||
|
||
VF_CAREFUL = {"mv_re_t": 5, "rm_re_t": 5, "mv_re_r": 0.1, "rm_re_r": 0.1}
|
||
|
||
|
||
def read_ram() -> tuple[float, float]:
    """best-effort read of (total, available) system RAM in GiB,
    truncated to two decimals; returns (0, 0) when /proc/meminfo
    is missing or unparseable (e.g. non-linux platforms)"""
    total = avail = 0
    try:
        with open("/proc/meminfo", "rb", 0x10000) as f:
            lines = f.read(0x10000).decode("ascii", "replace").split("\n")

        def _gib(ptn_txt: str) -> float:
            # values in meminfo are kB; 0x100000 kB == 1 GiB
            ptn = re.compile(ptn_txt)
            ln = next(x for x in lines if ptn.match(x))
            return int((int(ln.split()[1]) / 0x100000) * 100) / 100

        total = _gib("^MemTotal:.* kB")
        avail = _gib("^MemAvailable:.* kB")
    except:
        pass
    return total, avail
|
||
|
||
|
||
RAM_TOTAL, RAM_AVAIL = read_ram()
|
||
|
||
|
||
pybin = sys.executable or ""
|
||
if EXE:
|
||
pybin = ""
|
||
for zsg in "python3 python".split():
|
||
try:
|
||
if ANYWIN:
|
||
zsg += ".exe"
|
||
|
||
zsg = shutil.which(zsg)
|
||
if zsg:
|
||
pybin = zsg
|
||
break
|
||
except:
|
||
pass
|
||
|
||
|
||
def py_desc() -> str:
    """one-line summary of interpreter + host for the startup banner,
    e.g. '  CPython v3.11.4 on Linux64 6.1 [GCC 12.2.0]'"""
    impl = platform.python_implementation()

    # "3.11.4.final.0" -> "3.11.4"
    ver = ".".join(str(x) for x in sys.version_info)
    ofs = ver.find(".final.")
    if ofs > 0:
        ver = ver[:ofs]

    # some builds embed a url in the compiler string; drop it
    cc = platform.python_compiler().split("http")[0]

    # first dotted number in the platform version string, if any
    m = re.search(r"([0-9]+\.[0-9\.]+)", platform.version())
    os_ver = m.group(1) if m else ""

    return "{:>9} v{} on {}{} {} [{}]".format(
        impl, ver, platform.system(), BITNESS, os_ver, cc
    )
|
||
|
||
|
||
def _sqlite_ver() -> str:
|
||
assert sqlite3 # type: ignore # !rm
|
||
try:
|
||
co = sqlite3.connect(":memory:")
|
||
cur = co.cursor()
|
||
try:
|
||
vs = cur.execute("select * from pragma_compile_options").fetchall()
|
||
except:
|
||
vs = cur.execute("pragma compile_options").fetchall()
|
||
|
||
v = next(x[0].split("=")[1] for x in vs if x[0].startswith("THREADSAFE="))
|
||
cur.close()
|
||
co.close()
|
||
except:
|
||
v = "W"
|
||
|
||
return "{}*{}".format(sqlite3.sqlite_version, v)
|
||
|
||
|
||
try:
|
||
SQLITE_VER = _sqlite_ver()
|
||
except:
|
||
SQLITE_VER = "(None)"
|
||
|
||
try:
|
||
from jinja2 import __version__ as JINJA_VER
|
||
except:
|
||
JINJA_VER = "(None)"
|
||
|
||
try:
|
||
from pyftpdlib.__init__ import __ver__ as PYFTPD_VER
|
||
except:
|
||
PYFTPD_VER = "(None)"
|
||
|
||
try:
|
||
from partftpy.__init__ import __version__ as PARTFTPY_VER
|
||
except:
|
||
PARTFTPY_VER = "(None)"
|
||
|
||
|
||
PY_DESC = py_desc()
|
||
|
||
VERSIONS = (
|
||
"copyparty v{} ({})\n{}\n sqlite {} | jinja {} | pyftpd {} | tftp {}".format(
|
||
S_VERSION, S_BUILD_DT, PY_DESC, SQLITE_VER, JINJA_VER, PYFTPD_VER, PARTFTPY_VER
|
||
)
|
||
)
|
||
|
||
|
||
try:
    # fast urlsafe-base64 via binascii + translate; the translation
    # tables swap the two non-alphanumeric symbols between the
    # standard and urlsafe alphabets
    _b64_enc_tl = bytes.maketrans(b"+/", b"-_")
    _b64_dec_tl = bytes.maketrans(b"-_", b"+/")

    def ub64enc(bs: bytes) -> bytes:
        """urlsafe-base64 encode (no trailing newline)"""
        x = binascii.b2a_base64(bs, newline=False)
        return x.translate(_b64_enc_tl)

    def ub64dec(bs: bytes) -> bytes:
        """urlsafe-base64 decode"""
        bs = bs.translate(_b64_dec_tl)
        return binascii.a2b_base64(bs)

    def b64enc(bs: bytes) -> bytes:
        """standard base64 encode (no trailing newline)"""
        return binascii.b2a_base64(bs, newline=False)

    def b64dec(bs: bytes) -> bytes:
        """standard base64 decode"""
        return binascii.a2b_base64(bs)

    # smoketest against the stdlib implementation; input chosen to
    # exercise both substituted symbols (+ and /)
    zb = b">>>????"
    zb2 = base64.urlsafe_b64encode(zb)
    if zb2 != ub64enc(zb) or zb != ub64dec(zb2):
        raise Exception("bad smoke")

except Exception as ex:
    # fall back to the (slower) stdlib codecs if binascii lacks
    # the newline kwarg or the smoketest failed
    ub64enc = base64.urlsafe_b64encode  # type: ignore
    ub64dec = base64.urlsafe_b64decode  # type: ignore
    b64enc = base64.b64encode  # type: ignore
    b64dec = base64.b64decode  # type: ignore
    if not PY36:
        print("using fallback base64 codec due to %r" % (ex,))
|
||
|
||
|
||
class Daemon(threading.Thread):
    """convenience wrapper around threading.Thread: always a daemon
    thread, and autostarts unless r is False; termination signals
    are blocked in the worker so the main thread handles them"""

    def __init__(
        self,
        target: Any,
        name: Optional[str] = None,
        a: Optional[Iterable[Any]] = None,
        r: bool = True,
        ka: Optional[dict[Any, Any]] = None,
    ) -> None:
        threading.Thread.__init__(self, name=name)
        self.fun = target
        self.a = a if a else ()
        self.ka = ka if ka else {}
        self.daemon = True
        if r:
            self.start()

    def run(self):
        # keep ctrl-c / SIGTERM / SIGUSR1 delivery on the main thread
        if not ANYWIN and not PY2:
            sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1]
            signal.pthread_sigmask(signal.SIG_BLOCK, sigs)

        self.fun(*self.a, **self.ka)
|
||
|
||
|
||
class Netdev(object):
    """a network interface (ip, interface-index, name, description);
    ordering and equality are based on the "<idx>-<name><desc>" form"""

    def __init__(self, ip: str, idx: int, name: str, desc: str):
        self.ip = ip
        self.idx = idx
        self.name = name
        self.desc = desc

    def __str__(self):
        return "{}-{}{}".format(self.idx, self.name, self.desc)

    def __repr__(self):
        return "'{}-{}'".format(self.idx, self.name)

    def __lt__(self, rhs):
        return str(self) < str(rhs)

    def __eq__(self, rhs):
        return str(self) == str(rhs)

    def __hash__(self):
        # fix: defining __eq__ alone sets __hash__ to None, making
        # Netdev unusable in sets/dict-keys; hash consistently w/ __eq__
        return hash(str(self))
|
||
|
||
|
||
class Cooldown(object):
    """rate-limiter: poke(key) returns True at most once per maxage
    seconds for any given key"""

    def __init__(self, maxage: float) -> None:
        self.maxage = maxage
        self.mutex = threading.Lock()
        self.hist: dict[str, float] = {}  # key -> last accepted poke
        self.oldest = 0.0  # timestamp of oldest entry in hist

    def poke(self, key: str) -> bool:
        """returns True if key is cold (not poked within maxage sec)"""
        with self.mutex:
            now = time.time()

            ret = False
            pv: float = self.hist.get(key, 0)
            if now - pv > self.maxage:
                self.hist[key] = now
                ret = True

            # fix: was `self.oldest - now > ...` which is always
            # negative, so expired entries were never pruned and
            # hist grew without bound
            if now - self.oldest > self.maxage * 2:
                self.hist = {
                    k: v for k, v in self.hist.items() if now - v < self.maxage
                }
                # min() instead of sorted()[0]; also survives empty hist
                self.oldest = min(self.hist.values()) if self.hist else now

            return ret
|
||
|
||
|
||
class HLog(logging.Handler):
    """bridge from the stdlib logging module into copyparty's own
    RootLogger; maps log-levels to terminal color codes and cleans
    up messages from the embedded ftp/smb/tftp servers"""

    def __init__(self, log_func: "RootLogger") -> None:
        logging.Handler.__init__(self)
        self.log_func = log_func
        # "ip:port-[" prefix that pyftpdlib puts on client messages
        self.ptn_ftp = re.compile(r"^([0-9a-f:\.]+:[0-9]{1,5})-\[")
        # impacket startup noise to drop entirely
        self.ptn_smb_ign = re.compile(r"^(Callback added|Config file parsed)")

    def __repr__(self) -> str:
        level = logging.getLevelName(self.level)
        return "<%s cpp(%s)>" % (self.__class__.__name__, level)

    def flush(self) -> None:
        # nothing buffered here; log_func does its own flushing
        pass

    def emit(self, record: logging.LogRecord) -> None:
        msg = self.format(record)
        lv = record.levelno
        # level -> color: debug=6, info=0 (plain), warning=3, error+=1
        if lv < logging.INFO:
            c = 6
        elif lv < logging.WARNING:
            c = 0
        elif lv < logging.ERROR:
            c = 3
        else:
            c = 1

        if record.name == "pyftpdlib":
            # replace the logger name with the client ip and strip
            # the ip-prefix from the message itself
            m = self.ptn_ftp.match(msg)
            if m:
                ip = m.group(1)
                msg = msg[len(ip) + 1 :]
                if ip.startswith("::ffff:"):
                    # unwrap ipv4-mapped ipv6 addresses
                    record.name = ip[7:]
                else:
                    record.name = ip
        elif record.name.startswith("impacket"):
            if self.ptn_smb_ign.match(msg):
                return
        elif record.name.startswith("partftpy."):
            # drop the "partftpy." prefix for brevity
            record.name = record.name[9:]

        # logger names are truncated to the last 21 chars
        self.log_func(record.name[-21:], msg, c)
|
||
|
||
|
||
class NetMap(object):
    """maps a client ip to the server ip it most likely connected to,
    by longest-prefix match against the configured networks"""

    def __init__(
        self,
        ips: list[str],
        cidrs: list[str],
        keep_lo=False,
        strict_cidr=False,
        defer_mutex=False,
    ) -> None:
        """
        ips: list of plain ipv4/ipv6 IPs, not cidr
        cidrs: list of cidr-notation IPs (ip/prefix)
        """

        # fails multiprocessing; defer assignment
        self.mutex: Optional[threading.Lock] = None if defer_mutex else threading.Lock()

        # "::" wildcard: expand into the concrete ipv6 networks,
        # and also imply the ipv4 wildcard
        if "::" in ips:
            ips = [x for x in ips if x != "::"] + list(
                [x.split("/")[0] for x in cidrs if ":" in x]
            )
            ips.append("0.0.0.0")

        # "0.0.0.0" wildcard: expand into the concrete ipv4 networks
        if "0.0.0.0" in ips:
            ips = [x for x in ips if x != "0.0.0.0"] + list(
                [x.split("/")[0] for x in cidrs if ":" not in x]
            )

        if not keep_lo:
            ips = [x for x in ips if x not in ("::1", "127.0.0.1")]

        # annotate each ip with its network prefix (helper defined
        # elsewhere in this module)
        ips = find_prefix(ips, cidrs)

        self.cache: dict[str, str] = {}  # client-ip -> resolved server-ip
        self.b2sip: dict[bytes, str] = {}  # packed ip -> server ip string
        self.b2net: dict[bytes, Union[IPv4Network, IPv6Network]] = {}
        self.bip: list[bytes] = []  # packed ips, sorted for lookup
        for ip in ips:
            v6 = ":" in ip
            fam = socket.AF_INET6 if v6 else socket.AF_INET
            bip = socket.inet_pton(fam, ip.split("/")[0])
            self.bip.append(bip)
            self.b2sip[bip] = ip.split("/")[0]
            self.b2net[bip] = (IPv6Network if v6 else IPv4Network)(ip, strict_cidr)

        self.bip.sort(reverse=True)

    def map(self, ip: str) -> str:
        """return the server-ip for a client-ip ("" if no match)"""
        # unwrap ipv4-mapped ipv6 addresses
        if ip.startswith("::ffff:"):
            ip = ip[7:]

        try:
            return self.cache[ip]
        except:
            # intentionally crash the calling thread if unset:
            assert self.mutex  # type: ignore # !rm

            with self.mutex:
                return self._map(ip)

    def _map(self, ip: str) -> str:
        # slow path: find the first (highest) network containing ip
        v6 = ":" in ip
        ci = IPv6Address(ip) if v6 else IPv4Address(ip)
        bip = next((x for x in self.bip if ci in self.b2net[x]), None)
        ret = self.b2sip[bip] if bip else ""
        # crude bound on cache size; reset rather than evict
        if len(self.cache) > 9000:
            self.cache = {}
        self.cache[ip] = ret
        return ret
|
||
|
||
|
||
class UnrecvEOF(OSError):
    """raised by the Unrecv readers when the remote stops sending
    data in the middle of an expected read"""

    pass
|
||
|
||
|
||
class _Unrecv(object):
|
||
"""
|
||
undo any number of socket recv ops
|
||
"""
|
||
|
||
def __init__(self, s: socket.socket, log: Optional["NamedLogger"]) -> None:
|
||
self.s = s
|
||
self.log = log
|
||
self.buf: bytes = b""
|
||
|
||
def recv(self, nbytes: int, spins: int = 1) -> bytes:
|
||
if self.buf:
|
||
ret = self.buf[:nbytes]
|
||
self.buf = self.buf[nbytes:]
|
||
return ret
|
||
|
||
while True:
|
||
try:
|
||
ret = self.s.recv(nbytes)
|
||
break
|
||
except socket.timeout:
|
||
spins -= 1
|
||
if spins <= 0:
|
||
ret = b""
|
||
break
|
||
continue
|
||
except:
|
||
ret = b""
|
||
break
|
||
|
||
if not ret:
|
||
raise UnrecvEOF("client stopped sending data")
|
||
|
||
return ret
|
||
|
||
def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes:
|
||
"""read an exact number of bytes"""
|
||
ret = b""
|
||
try:
|
||
while nbytes > len(ret):
|
||
ret += self.recv(nbytes - len(ret))
|
||
except OSError:
|
||
t = "client stopped sending data; expected at least %d more bytes"
|
||
if not ret:
|
||
t = t % (nbytes,)
|
||
else:
|
||
t += ", only got %d"
|
||
t = t % (nbytes, len(ret))
|
||
if len(ret) <= 16:
|
||
t += "; %r" % (ret,)
|
||
|
||
if raise_on_trunc:
|
||
raise UnrecvEOF(5, t)
|
||
elif self.log:
|
||
self.log(t, 3)
|
||
|
||
return ret
|
||
|
||
def unrecv(self, buf: bytes) -> None:
|
||
self.buf = buf + self.buf
|
||
|
||
|
||
# !rm.yes>
|
||
class _LUnrecv(object):
    """
    with expensive debug logging
    (drop-in replacement for _Unrecv which prints every buffer op)
    """

    def __init__(self, s: socket.socket, log: Optional["NamedLogger"]) -> None:
        self.s = s
        self.log = log
        self.buf = b""  # pushback buffer, consumed before the socket

    def recv(self, nbytes: int, spins: int) -> bytes:
        # serve from the pushback buffer first, tracing pop + remainder
        if self.buf:
            ret = self.buf[:nbytes]
            self.buf = self.buf[nbytes:]
            t = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
            print(t.format(ret, self.buf))
            return ret

        # NOTE: unlike _Unrecv, spins is accepted but not used here
        ret = self.s.recv(nbytes)
        t = "\033[0;7mur:recv\033[0;1;33m {}\033[0m"
        print(t.format(ret))
        if not ret:
            raise UnrecvEOF("client stopped sending data")

        return ret

    def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes:
        """read an exact number of bytes"""
        try:
            ret = self.recv(nbytes, 1)
            err = False
        except:
            ret = b""
            err = True

        while not err and len(ret) < nbytes:
            try:
                ret += self.recv(nbytes - len(ret), 1)
            except OSError:
                err = True

        if err:
            t = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
            if raise_on_trunc:
                raise UnrecvEOF(t)
            elif self.log:
                self.log(t, 3)

        return ret

    def unrecv(self, buf: bytes) -> None:
        # push data back, tracing push + remainder
        self.buf = buf + self.buf
        t = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
        print(t.format(buf, self.buf))
|
||
|
||
|
||
# !rm.no>
|
||
|
||
|
||
Unrecv = _Unrecv
|
||
|
||
|
||
class CachedSet(object):
    """a set whose members expire maxage seconds after insertion"""

    def __init__(self, maxage: float) -> None:
        self.c: dict[Any, float] = {}  # member -> insertion time
        self.maxage = maxage
        self.oldest = 0.0

    def add(self, v: Any) -> None:
        self.c[v] = time.time()

    def cln(self) -> None:
        """drop expired members; cheap no-op until the oldest expires"""
        now = time.time()
        if now - self.oldest < self.maxage:
            return

        c = self.c = {k: t for k, t in self.c.items() if now - t < self.maxage}
        try:
            self.oldest = min(c.values())
        except:
            self.oldest = now
|
||
|
||
|
||
class CachedDict(object):
    """a dict whose entries expire maxage seconds after insertion"""

    def __init__(self, maxage: float) -> None:
        # key -> (insertion-time, value)
        self.c: dict[str, tuple[float, Any]] = {}
        self.maxage = maxage
        self.oldest = 0.0

    def set(self, k: str, v: Any) -> None:
        now = time.time()
        self.c[k] = (now, v)
        if now - self.oldest < self.maxage:
            return

        # oldest entry has expired; drop everything stale
        c = self.c = {k: v for k, v in self.c.items() if now - v[0] < self.maxage}
        try:
            self.oldest = min([x[0] for x in c.values()])
        except:
            self.oldest = now

    def get(self, k: str) -> Optional[Any]:
        """return the stored value for k, or None if missing/expired.
        (fix: the annotation previously claimed tuple[str, Any], but
        only the value itself is returned)"""
        try:
            ts, ret = self.c[k]
            now = time.time()
            if now - ts > self.maxage:
                del self.c[k]
                return None
            return ret
        except:
            return None
|
||
|
||
|
||
class FHC(object):
    """cache of open filehandles, grouped by filesystem path"""

    class CE(object):
        def __init__(self, fh: typing.BinaryIO) -> None:
            self.ts: float = 0  # last-use timestamp
            self.fhs = [fh]  # idle handles available to pop
            self.all_fhs = set([fh])  # every handle cached for this path

    def __init__(self) -> None:
        self.cache: dict[str, FHC.CE] = {}
        self.aps: dict[str, int] = {}

    def close(self, path: str) -> None:
        """close and forget all idle handles for path"""
        ce = self.cache.get(path)
        if ce is None:
            return

        for fh in ce.fhs:
            fh.close()

        del self.cache[path]
        del self.aps[path]

    def clean(self) -> None:
        """close handles which have been idle for 5+ seconds"""
        if not self.cache:
            return

        now = time.time()
        keep = {}
        for path, ce in self.cache.items():
            if now < ce.ts + 5:
                keep[path] = ce
            else:
                for fh in ce.fhs:
                    fh.close()

        self.cache = keep

    def pop(self, path: str) -> typing.BinaryIO:
        """take an idle handle for path (raises if none cached)"""
        return self.cache[path].fhs.pop()

    def put(self, path: str, fh: typing.BinaryIO) -> None:
        """give a handle (back) to the cache, creating the entry if needed"""
        self.aps.setdefault(path, 0)

        try:
            ce = self.cache[path]
            ce.all_fhs.add(fh)
            ce.fhs.append(fh)
        except:
            ce = self.CE(fh)
            self.cache[path] = ce

        ce.ts = time.time()
|
||
|
||
|
||
class ProgressPrinter(threading.Thread):
    """
    periodically print progress info without linefeeds
    (set .msg to update the line; set .end to stop the thread)
    """

    def __init__(self, log: "NamedLogger", args: argparse.Namespace) -> None:
        threading.Thread.__init__(self, name="pp")
        self.daemon = True
        self.log = log
        self.args = args
        self.msg = ""  # current progress text; written by other threads
        self.end = False  # set True to terminate the printer
        self.n = -1

    def run(self) -> None:
        sigblock()  # keep signal delivery on the main thread
        tp = 0  # last time the progress was mirrored into the log
        msg = None
        no_stdout = self.args.q
        # on non-VT100 terminals use " $" as an end-of-line marker
        # instead of the erase-to-eol escape
        fmt = " {}\033[K\r" if VT100 else " {} $\r"
        while not self.end:
            time.sleep(0.1)
            if msg == self.msg or self.end:
                continue

            msg = self.msg
            now = time.time()
            # mirror into the logfile at most every 10 seconds
            if msg and now - tp > 10:
                tp = now
                self.log("progress: %s" % (msg,), 6)

            if no_stdout:
                continue

            uprint(fmt.format(msg))
            if PY2:
                sys.stdout.flush()

        if no_stdout:
            return

        # final cleanup of the progress line
        if VT100:
            print("\033[K", end="")
        elif msg:
            print("------------------------")

        sys.stdout.flush()  # necessary on win10 even w/ stderr btw
|
||
|
||
|
||
class MTHash(object):
    """multithreaded file hasher: splits a file into chunks and
    sha512-hashes them on a pool of worker threads"""

    def __init__(self, cores: int):
        self.pp: Optional[ProgressPrinter] = None
        self.f: Optional[typing.BinaryIO] = None  # file being hashed
        self.sz = 0  # total filesize
        self.csz = 0  # chunksize
        self.stop = False
        # read in 2 MiB blocks on low-ram boxes (<1 GiB avail), else 12 MiB
        self.readsz = 1024 * 1024 * (2 if (RAM_AVAIL or 2) < 1 else 12)
        self.omutex = threading.Lock()  # serializes whole-file hash() calls
        self.imutex = threading.Lock()  # serializes seek+read pairs
        self.work_q: Queue[int] = Queue()  # chunk indices to hash
        self.done_q: Queue[tuple[int, str, int, int]] = Queue()
        self.thrs = []
        for n in range(cores):
            t = Daemon(self.worker, "mth-" + str(n))
            self.thrs.append(t)

    def hash(
        self,
        f: typing.BinaryIO,
        fsz: int,
        chunksz: int,
        pp: Optional[ProgressPrinter] = None,
        prefix: str = "",
        suffix: str = "",
    ) -> list[tuple[str, int, int]]:
        """hash all chunks of f; returns [(digest, offset, size), ...]
        in chunk order; pp (if given) gets remaining-MiB updates"""
        with self.omutex:
            self.f = f
            self.sz = fsz
            self.csz = chunksz

            chunks: dict[int, tuple[str, int, int]] = {}
            nchunks = int(math.ceil(fsz / chunksz))
            for nch in range(nchunks):
                self.work_q.put(nch)

            # collect one result per chunk; workers put an error
            # string instead of a tuple on failure
            ex = ""
            for nch in range(nchunks):
                qe = self.done_q.get()
                try:
                    nch, dig, ofs, csz = qe
                    chunks[nch] = (dig, ofs, csz)
                except:
                    ex = ex or str(qe)

                if pp:
                    mb = (fsz - nch * chunksz) // (1024 * 1024)
                    pp.msg = prefix + str(mb) + suffix

            if ex:
                raise Exception(ex)

            ret = []
            for n in range(nchunks):
                ret.append(chunks[n])

            self.f = None
            self.csz = 0
            self.sz = 0
            return ret

    def worker(self) -> None:
        # daemon loop: hash one chunk per work-queue item; exceptions
        # are stringified and reported through the done-queue
        while True:
            ofs = self.work_q.get()
            try:
                v = self.hash_at(ofs)
            except Exception as ex:
                v = str(ex)  # type: ignore

            self.done_q.put(v)

    def hash_at(self, nch: int) -> tuple[int, str, int, int]:
        """hash chunk nch; returns (nch, digest, offset, chunksize)
        where digest is 33 bytes of sha512, urlsafe-b64"""
        f = self.f
        ofs = ofs0 = nch * self.csz
        chunk_sz = chunk_rem = min(self.csz, self.sz - ofs)
        if self.stop:
            return nch, "", ofs0, chunk_sz

        assert f  # !rm
        hashobj = hashlib.sha512()
        while chunk_rem > 0:
            # seek+read must be atomic; all workers share one filehandle
            with self.imutex:
                f.seek(ofs)
                buf = f.read(min(chunk_rem, self.readsz))

            if not buf:
                raise Exception("EOF at " + str(ofs))

            hashobj.update(buf)
            chunk_rem -= len(buf)
            ofs += len(buf)

        bdig = hashobj.digest()[:33]
        udig = ub64enc(bdig).decode("ascii")
        return nch, udig, ofs0, chunk_sz
|
||
|
||
|
||
class HMaccas(object):
    """hmac-sha512 with a persistent 64-byte key and a bounded
    memoization cache; used for keyed short-hashes"""

    def __init__(self, keypath: str, retlen: int) -> None:
        self.retlen = retlen  # number of b64 chars to keep
        self.cache: dict[bytes, str] = {}
        try:
            # reuse the existing key if it is exactly 64 bytes
            with open(keypath, "rb") as f:
                self.key = f.read()
                if len(self.key) != 64:
                    raise Exception()
        except:
            # missing/invalid keyfile; generate and persist a new key
            self.key = os.urandom(64)
            with open(keypath, "wb") as f:
                f.write(self.key)

    def b(self, msg: bytes) -> str:
        """keyed hash of a bytestring, truncated to retlen b64 chars"""
        try:
            return self.cache[msg]
        except:
            # crude bound on cache size; reset rather than evict
            if len(self.cache) > 9000:
                self.cache = {}

            zb = hmac.new(self.key, msg, hashlib.sha512).digest()
            zs = ub64enc(zb)[: self.retlen].decode("ascii")
            self.cache[msg] = zs
            return zs

    def s(self, msg: str) -> str:
        """keyed hash of a text string (utf-8, lossy)"""
        return self.b(msg.encode("utf-8", "replace"))
|
||
|
||
|
||
class Magician(object):
    """file-extension guesser backed by libmagic, with fallback to
    mime-sniffing and the mimetypes module"""

    def __init__(self) -> None:
        self.bad_magic = False  # set once libmagic init fails; stays failed
        self.mutex = threading.Lock()
        self.magic: Optional["magic.Magic"] = None  # lazily created

    def ext(self, fpath: str) -> str:
        """guess a file extension for fpath; raises when every
        strategy fails"""
        import magic

        try:
            if self.bad_magic:
                raise Exception()

            # lazy init of the magic instance, guarded by the mutex
            if not self.magic:
                try:
                    with self.mutex:
                        if not self.magic:
                            self.magic = magic.Magic(uncompress=False, extension=True)
                except:
                    self.bad_magic = True
                    raise

            # libmagic is not thread-safe; serialize the lookup
            with self.mutex:
                ret = self.magic.from_file(fpath)
        except:
            ret = "?"

        # libmagic can return multiple slash-separated candidates;
        # take the first, then normalize (e.g. jpeg -> jpg)
        ret = ret.split("/")[0]
        ret = MAGIC_MAP.get(ret, ret)
        if "?" not in ret:
            return ret

        # extension lookup failed; sniff the mimetype instead and
        # map it back to an extension
        mime = magic.from_file(fpath, mime=True)
        mime = re.split("[; ]", mime, maxsplit=1)[0]
        try:
            return EXTS[mime]
        except:
            pass

        mg = mimetypes.guess_extension(mime)
        if mg:
            return mg[1:]
        else:
            raise Exception()
|
||
|
||
|
||
class Garda(object):
    """ban clients for repeated offenses"""

    def __init__(self, cfg: str, uniq: bool = True) -> None:
        self.uniq = uniq
        try:
            # cfg is "max-hits,window-minutes,penalty-minutes"
            zs1, zs2, zs3 = cfg.strip().split(",")
            self.lim = int(zs1)
            self.win = int(zs2) * 60
            self.pen = int(zs3) * 60
        except:
            # blank or invalid config disables banning entirely
            self.lim = self.win = self.pen = 0

        self.ct: dict[str, list[int]] = {}  # ip -> offense timestamps
        self.prev: dict[str, str] = {}  # ip -> most recent offense-id
        self.last_cln = 0

    def cln(self, ip: str) -> None:
        """forget offenses for ip which fell out of the window"""
        cutoff = int(time.time() - self.win)
        n = 0
        for ts in self.ct[ip]:
            if ts >= cutoff:
                break
            n += 1

        if not n:
            return

        rem = self.ct[ip][n:]
        if rem:
            self.ct[ip] = rem
        else:
            del self.ct[ip]
            try:
                del self.prev[ip]
            except:
                pass

    def allcln(self) -> None:
        """forget expired offenses for every known ip"""
        for k in list(self.ct):
            self.cln(k)

        self.last_cln = int(time.time())

    def bonk(self, ip: str, prev: str) -> tuple[int, str]:
        """register an offense; returns (ban-expiry-unixtime, normalized-ip)
        where the expiry is 0 while the client is not banned"""
        if not self.lim:
            return 0, ip

        if ":" in ip:
            # assume /64 clients; drop 4 groups
            ip = IPv6Address(ip).exploded[:-20]

        if prev and self.uniq:
            # count only distinct offenses per client
            if self.prev.get(ip) == prev:
                return 0, ip

            self.prev[ip] = prev

        now = int(time.time())
        try:
            self.ct[ip].append(now)
        except:
            self.ct[ip] = [now]

        # full sweep at most every 5 min; otherwise just this ip
        if now - self.last_cln > 300:
            self.allcln()
        else:
            self.cln(ip)

        if len(self.ct[ip]) >= self.lim:
            return now + self.pen, ip
        else:
            return 0, ip
|
||
|
||
|
||
if WINDOWS and sys.version_info < (3, 8):
    # windows+py37 and older cannot Popen bytestring argv;
    # monkeypatch Popen to decode args using the fs-encoding first
    _popen = sp.Popen

    def _spopen(c, *a, **ka):
        enc = sys.getfilesystemencoding()
        c = [x.decode(enc, "replace") if hasattr(x, "decode") else x for x in c]
        return _popen(c, *a, **ka)

    sp.Popen = _spopen
|
||
|
||
|
||
def uprint(msg: str) -> None:
    """print without trailing newline; survives terminals which
    cannot render unicode by degrading to utf-8 and then ascii"""
    try:
        print(msg, end="")
        return
    except UnicodeEncodeError:
        pass

    try:
        print(msg.encode("utf-8", "replace").decode(), end="")
    except:
        print(msg.encode("ascii", "replace").decode(), end="")
|
||
|
||
|
||
def nuprint(msg: str) -> None:
    """uprint with a trailing newline appended"""
    uprint(msg + "\n")
|
||
|
||
|
||
def dedent(txt: str) -> str:
    """strip the common leading indentation (capped at 64 cols)
    from every line of txt; \r is discarded"""
    lines = txt.replace("\r", "").split("\n")
    indents = [
        len(ln) - len(ln.lstrip()) for ln in lines if ln.lstrip()
    ]
    pad = min([64] + indents)
    return "\n".join(ln[pad:] for ln in lines)
|
||
|
||
|
||
def rice_tid() -> str:
    """current thread-id rendered as colorful hex, for log prefixes"""
    tid = threading.current_thread().ident
    octets = sunpack(b"B" * 5, spack(b">Q", tid)[-5:])
    body = "".join("\033[1;37;48;5;{0}m{0:02x}".format(b) for b in octets)
    return body + "\033[0m"
|
||
|
||
|
||
def trace(*args: Any, **kwargs: Any) -> None:
    """debug-print timestamp, thread-id, compact call-stack,
    plus repr of any positional/keyword args"""
    now = time.time()
    crumbs = "".join(
        "\033[36m%s\033[33m%s" % (fr[0].split(os.sep)[-1][:-3], fr[1])
        for fr in traceback.extract_stack()[3:-1]
    )

    parts = ["%.6f" % (now,), rice_tid(), crumbs]
    if args:
        parts.append(repr(args))
    if kwargs:
        parts.append(repr(kwargs))

    nuprint("\033[0m ".join(parts))
|
||
|
||
|
||
def alltrace() -> str:
    """returns a stack trace for every running thread; threads idling in
    Queue.get (detected by the not_empty.wait line) are moved to the
    bottom of the dump, indented, so busy threads stand out"""
    threads: dict[str, types.FrameType] = {}
    names = dict([(t.ident, t.name) for t in threading.enumerate()])
    for tid, stack in sys._current_frames().items():
        name = "%s (%x)" % (names.get(tid), tid)
        threads[name] = stack

    rret: list[str] = []  # interesting (busy) threads
    bret: list[str] = []  # boring (queue-idle) threads
    for name, stack in sorted(threads.items()):
        ret = ["\n\n# %s" % (name,)]
        pad = None
        for fn, lno, name, line in traceback.extract_stack(stack):
            # keep just the last 3 path components of the filename
            fn = os.sep.join(fn.split(os.sep)[-3:])
            ret.append('File: "%s", line %d, in %s' % (fn, lno, name))
            if line:
                ret.append(" " + str(line.strip()))
                if "self.not_empty.wait()" in line:
                    # thread is idle in Queue.get; demote it
                    pad = " " * 4

        if pad:
            bret += [ret[0]] + [pad + x for x in ret[1:]]
        else:
            rret.extend(ret)

    return "\n".join(rret + bret) + "\n"
|
||
|
||
|
||
def start_stackmon(arg_str: str, nid: int) -> None:
    """arg_str is 'outfile-path,interval-seconds'; spawns the stackmon daemon"""
    suffix = "-{}".format(nid) if nid else ""
    outfile, interval = arg_str.rsplit(",", 1)
    Daemon(stackmon, "stackmon" + suffix, (outfile, int(interval), suffix))
|
||
|
||
|
||
def stackmon(fp: str, ival: float, suffix: str) -> None:
    """daemon: every ival seconds, dump all thread stacks to file fp;
    a .gz/.xz extension enables compression, and %Y %m %d %H %M %S
    in the path expand to the current utc time"""
    ctr = 0
    fp0 = fp
    while True:
        ctr += 1
        fp = fp0
        time.sleep(ival)
        st = "{}, {}\n{}".format(ctr, time.time(), alltrace())
        buf = st.encode("utf-8", "replace")

        if fp.endswith(".gz"):
            import gzip

            # 2459b 2304b 2241b 2202b 2194b 2191b lv3..8
            # 0.06s 0.08s 0.11s 0.13s 0.16s 0.19s
            buf = gzip.compress(buf, compresslevel=6)

        elif fp.endswith(".xz"):
            import lzma

            # 2276b 2216b 2200b 2192b 2168b lv0..4
            # 0.04s 0.10s 0.22s 0.41s 0.70s
            buf = lzma.compress(buf, preset=0)

        if "%" in fp:
            # expand strftime-style placeholders in the filename
            dt = datetime.now(UTC)
            for fs in "YmdHMS":
                fs = "%" + fs
                if fs in fp:
                    fp = fp.replace(fs, dt.strftime(fs))

        if "/" in fp:
            # best-effort mkdir; probably exists already
            try:
                os.makedirs(fp.rsplit("/", 1)[0])
            except:
                pass

        with open(fp + suffix, "wb") as f:
            f.write(buf)
|
||
|
||
|
||
def start_log_thrs(
    logger: Callable[[str, str, int], None], ival: float, nid: int
) -> None:
    """spawn the thread-census logger (log_thrs) as a daemon thread"""
    if nid:
        tname = "logthr-n{}-i{:x}".format(nid, os.getpid())
        lname = tname[3:]
    else:
        tname = lname = "log-thrs"

    Daemon(log_thrs, tname, (logger, float(ival), lname))
|
||
|
||
|
||
def log_thrs(log: Callable[[str, str, int], None], ival: float, name: str) -> None:
    """daemon: every ival seconds, log a histogram of active thread names"""
    while True:
        time.sleep(ival)
        buckets = []
        for zs in [x.name for x in threading.enumerate()]:
            if zs.startswith("pydevd."):
                continue
            stem = zs.split("-")[0]
            if stem in ["httpconn", "thumb", "tagger"]:
                buckets.append(stem)
            elif "-listen-" in zs:
                buckets.append("listen")
            else:
                buckets.append(zs)

        hist = [
            "{}\033[36m{}".format(v, k) for k, v in sorted(Counter(buckets).items())
        ]
        log(name, "\033[0m \033[33m".join(hist), 3)
|
||
|
||
|
||
def sigblock():
    """block signals which the main thread handles exclusively;
    no-op on windows (unsupported) and py2 (no pthread_sigmask)"""
    if ANYWIN or PY2:
        return

    sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1]
    signal.pthread_sigmask(signal.SIG_BLOCK, sigs)
|
||
|
||
|
||
def vol_san(vols: list["VFS"], txt: bytes) -> bytes:
    """replace absolute filesystem paths in txt with volume urlpaths
    (and histpaths with a $hist(...) marker) before showing it to clients"""
    before = txt
    for vol in vols:
        ap = vol.realpath.encode("utf-8")
        hp = vol.histpath.encode("utf-8")
        vp = vol.vpath.encode("utf-8")
        hvp = b"$hist(/" + vp + b")"

        # plain and backslash-escaped variants of each path
        for needle, repl in (
            (ap, vp),
            (hp, hvp),
            (ap.replace(b"\\", b"\\\\"), vp),
            (hp.replace(b"\\", b"\\\\"), hvp),
        ):
            txt = txt.replace(needle, repl)

    if txt != before:
        txt += b"\r\nNOTE: filepaths sanitized; see serverlog for correct values"

    return txt
|
||
|
||
|
||
def min_ex(max_lines: int = 8, reverse: bool = False) -> str:
    """compact one-string traceback of the current exception
    (or, outside an except block, of the current call stack)"""
    et, ev, tb = sys.exc_info()
    frames = traceback.extract_tb(tb) if tb else traceback.extract_stack()[:-1]
    lines = [
        "%s:%d <%s>: %s" % (fp.split(os.sep)[-1], ln, fun, txt)
        for fp, ln, fun, txt in frames
    ]
    if et or ev or tb:
        lines.append("[%s] %s" % (et.__name__ if et else "(anonymous)", ev))

    lines = lines[-max_lines:]
    if reverse:
        lines.reverse()
    return "\n".join(lines)
|
||
|
||
|
||
def ren_open(fname: str, *args: Any, **kwargs: Any) -> tuple[typing.IO[Any], str]:
    """
    open fname, progressively shortening/sanitizing the filename if the
    filesystem rejects it; returns (filehandle, final-filename).
    kwargs: fun= alternative open function; fdir= containing directory;
    suffix= appended if the target already exists
    """
    fun = kwargs.pop("fun", open)
    fdir = kwargs.pop("fdir", None)
    suffix = kwargs.pop("suffix", None)

    if fname == os.devnull:
        return fun(fname, *args, **kwargs), fname

    if suffix:
        ext = fname.split(".")[-1]
        if len(ext) < 7:
            suffix += "." + ext

    orig_name = fname
    bname = fname
    ext = ""
    # peel off everything that looks like a chain of short extensions
    while True:
        ofs = bname.rfind(".")
        if ofs < 0 or ofs < len(bname) - 7:
            # doesn't look like an extension anymore
            break

        ext = bname[ofs:] + ext
        bname = bname[:ofs]

    asciified = False  # have we retried with an ascii-only name yet
    b64 = ""  # hash of original name; set once truncation begins
    while True:
        f = None
        try:
            if fdir:
                fpath = os.path.join(fdir, fname)
            else:
                fpath = fname

            if suffix and os.path.lexists(fsenc(fpath)):
                fpath += suffix
                fname += suffix
                ext += suffix

            f = fun(fsenc(fpath), *args, **kwargs)
            if b64:
                # name was truncated; write a sidecar with the original name
                assert fdir  # !rm
                fp2 = "fn-trunc.%s.txt" % (b64,)
                fp2 = os.path.join(fdir, fp2)
                with open(fsenc(fp2), "wb") as f2:
                    f2.write(orig_name.encode("utf-8"))

            return f, fname

        except OSError as ex_:
            ex = ex_
            if f:
                f.close()

            # EPERM: android13
            if ex.errno in (errno.EINVAL, errno.EPERM) and not asciified:
                # retry once with all exotic characters replaced by _
                asciified = True
                zsl = []
                for zs in (bname, fname):
                    zs = zs.encode("ascii", "replace").decode("ascii")
                    zs = re.sub(r"[^][a-zA-Z0-9(){}.,+=!-]", "_", zs)
                    zsl.append(zs)
                bname, fname = zsl
                continue

            # ENOTSUP: zfs on ubuntu 20.04
            if ex.errno not in (errno.ENAMETOOLONG, errno.ENOSR, errno.ENOTSUP) and (
                not WINDOWS or ex.errno != errno.EINVAL
            ):
                raise

            if not b64:
                zs = ("%s\n%s" % (orig_name, suffix)).encode("utf-8", "replace")
                b64 = ub64enc(hashlib.sha512(zs).digest()[:12]).decode("ascii")

            # shorten until the candidate is strictly shorter than before
            badlen = len(fname)
            while len(fname) >= badlen:
                if len(bname) < 8:
                    raise ex

                if len(bname) > len(ext):
                    # drop the last letter of the filename
                    bname = bname[:-1]
                else:
                    try:
                        # drop the leftmost sub-extension
                        _, ext = ext.split(".", 1)
                    except:
                        # okay do the first letter then
                        ext = "." + ext[2:]

                fname = "%s~%s%s" % (bname, b64, ext)
|
||
|
||
|
||
class MultipartParser(object):
    """streaming parser for multipart/form-data request bodies;
    call parse() once, then require()/drop() or iterate self.gen"""

    def __init__(
        self,
        log_func: "NamedLogger",
        args: argparse.Namespace,
        sr: Unrecv,
        http_headers: dict[str, str],
    ):
        self.sr = sr  # socket reader with pushback support
        self.log = log_func
        self.args = args
        self.headers = http_headers

        self.re_ctype = re.compile(r"^content-type: *([^; ]+)", re.IGNORECASE)
        self.re_cdisp = re.compile(r"^content-disposition: *([^; ]+)", re.IGNORECASE)
        self.re_cdisp_field = re.compile(
            r'^content-disposition:(?: *|.*; *)name="([^"]+)"', re.IGNORECASE
        )
        self.re_cdisp_file = re.compile(
            r'^content-disposition:(?: *|.*; *)filename="(.*)"', re.IGNORECASE
        )

        # set by parse(); "\r\n--<token>" once the preamble is consumed
        self.boundary = b""
        self.gen: Optional[
            Generator[
                tuple[str, Optional[str], Generator[bytes, None, None]], None, None
            ]
        ] = None

    def _read_header(self) -> tuple[str, Optional[str]]:
        """
        returns [fieldname, filename] after eating a block of multipart headers
        while doing a decent job at dealing with the absolute mess that is
        rfc1341/rfc1521/rfc2047/rfc2231/rfc2388/rfc6266/the-real-world
        (only the fallback non-js uploader relies on these filenames)
        """
        for ln in read_header(self.sr, 2, 2592000):
            self.log(ln)

            m = self.re_ctype.match(ln)
            if m:
                if m.group(1).lower() == "multipart/mixed":
                    # rfc-7578 overrides rfc-2388 so this is not-impl
                    # (opera >=9 <11.10 is the only thing i've ever seen use it)
                    raise Pebkac(
                        400,
                        "you can't use that browser to upload multiple files at once",
                    )

                continue

            # the only other header we care about is content-disposition
            m = self.re_cdisp.match(ln)
            if not m:
                continue

            if m.group(1).lower() != "form-data":
                raise Pebkac(400, "not form-data: {}".format(ln))

            try:
                field = self.re_cdisp_field.match(ln).group(1)  # type: ignore
            except:
                raise Pebkac(400, "missing field name: {}".format(ln))

            try:
                fn = self.re_cdisp_file.match(ln).group(1)  # type: ignore
            except:
                # this is not a file upload, we're done
                return field, None

            try:
                is_webkit = "applewebkit" in self.headers["user-agent"].lower()
            except:
                is_webkit = False

            # chromes ignore the spec and makes this real easy
            if is_webkit:
                # quotes become %22 but they don't escape the %
                # so unescaping the quotes could turn messi
                return field, fn.split('"')[0]

            # also ez if filename doesn't contain "
            if not fn.split('"')[0].endswith("\\"):
                return field, fn.split('"')[0]

            # this breaks on firefox uploads that contain \"
            # since firefox escapes " but forgets to escape \
            # so it'll truncate after the \
            ret = ""
            esc = False
            for ch in fn:
                if esc:
                    esc = False
                    if ch not in ['"', "\\"]:
                        ret += "\\"
                    ret += ch
                elif ch == "\\":
                    esc = True
                elif ch == '"':
                    break
                else:
                    ret += ch

            return field, ret

        raise Pebkac(400, "server expected a multipart header but you never sent one")

    def _read_data(self) -> Generator[bytes, None, None]:
        """yields the current field's value in chunks until the next boundary"""
        blen = len(self.boundary)
        bufsz = self.args.s_rd_sz
        while True:
            try:
                buf = self.sr.recv(bufsz)
            except:
                # abort: client disconnected
                raise Pebkac(400, "client d/c during multipart post")

            while True:
                ofs = buf.find(self.boundary)
                if ofs != -1:
                    # found the boundary; push back whatever follows it
                    self.sr.unrecv(buf[ofs + blen :])
                    yield buf[:ofs]
                    return

                d = len(buf) - blen
                if d > 0:
                    # buffer growing large; yield everything except
                    # the part at the end (maybe start of boundary)
                    yield buf[:d]
                    buf = buf[d:]

                # look for boundary near the end of the buffer
                n = 0
                for n in range(1, len(buf) + 1):
                    if not buf[-n:] in self.boundary:
                        n -= 1
                        break

                if n == 0 or not self.boundary.startswith(buf[-n:]):
                    # no boundary contents near the buffer edge
                    break

                if blen == n:
                    # EOF: found boundary
                    yield buf[:-n]
                    return

                try:
                    buf += self.sr.recv(bufsz)
                except:
                    # abort: client disconnected
                    raise Pebkac(400, "client d/c during multipart post")

            yield buf

    def _run_gen(
        self,
    ) -> Generator[tuple[str, Optional[str], Generator[bytes, None, None]], None, None]:
        """
        yields [fieldname, unsanitized_filename, fieldvalue]
        where fieldvalue yields chunks of data
        """
        run = True
        while run:
            fieldname, filename = self._read_header()
            yield (fieldname, filename, self._read_data())

            tail = self.sr.recv_ex(2, False)

            if tail == b"--":
                # EOF indicated by this immediately after final boundary
                tail = self.sr.recv_ex(2, False)
                run = False

            if tail != b"\r\n":
                t = "protocol error after field value: want b'\\r\\n', got {!r}"
                raise Pebkac(400, t.format(tail))

    def _read_value(self, iterable: Iterable[bytes], max_len: int) -> bytes:
        """collect a whole field value into memory, bounded by max_len"""
        ret = b""
        for buf in iterable:
            ret += buf
            if len(ret) > max_len:
                raise Pebkac(422, "field length is too long")

        return ret

    def parse(self) -> None:
        """consume the preamble and arm self.gen; must be called first"""
        boundary = get_boundary(self.headers)
        self.log("boundary=%r" % (boundary,))

        # spec says there might be junk before the first boundary,
        # can't have the leading \r\n if that's not the case
        self.boundary = b"--" + boundary.encode("utf-8")

        # discard junk before the first boundary
        for junk in self._read_data():
            if not junk:
                continue

            jtxt = junk.decode("utf-8", "replace")
            self.log("discarding preamble |%d| %r" % (len(junk), jtxt))

        # nice, now make it fast
        self.boundary = b"\r\n" + self.boundary
        self.gen = self._run_gen()

    def require(self, field_name: str, max_len: int) -> str:
        """
        returns the value of the next field in the multipart body,
        raises if the field name is not as expected
        """
        assert self.gen  # !rm
        p_field, p_fname, p_data = next(self.gen)
        if p_field != field_name:
            raise WrongPostKey(field_name, p_field, p_fname, p_data)

        return self._read_value(p_data, max_len).decode("utf-8", "surrogateescape")

    def drop(self) -> None:
        """discards the remaining multipart body"""
        assert self.gen  # !rm
        for _, _, data in self.gen:
            for _ in data:
                pass
|
||
|
||
|
||
def get_boundary(headers: dict[str, str]) -> str:
    """extract the boundary token from a multipart content-type header"""
    # boundaries contain a-z A-Z 0-9 ' ( ) + _ , - . / : = ?
    # (whitespace allowed except as the last char)
    ct = headers["content-type"]
    m = re.match(
        r"^multipart/form-data *; *(.*; *)?boundary=([^;]+)", ct, re.IGNORECASE
    )
    if m:
        return m.group(2)

    raise Pebkac(400, "invalid content-type for a multipart post: {}".format(ct))
|
||
|
||
|
||
def read_header(sr: Unrecv, t_idle: int, t_tot: int) -> list[str]:
    """read http-style headers from sr until the blank line;
    returns the header lines, or [] on timeout / clean disconnect.
    t_idle: per-read patience (seconds); t_tot: total deadline"""
    t0 = time.time()
    ret = b""
    while True:
        if time.time() - t0 >= t_tot:
            return []

        try:
            ret += sr.recv(1024, t_idle // 2)
        except:
            # read failed; [] if nothing arrived at all (clean d/c)
            if not ret:
                return []

            raise Pebkac(
                400,
                "protocol error while reading headers",
                log=ret.decode("utf-8", "replace"),
            )

        ofs = ret.find(b"\r\n\r\n")
        if ofs < 0:
            if len(ret) > 1024 * 32:
                raise Pebkac(400, "header 2big")
            else:
                continue

        if len(ret) > ofs + 4:
            # push back whatever trailed the header block
            sr.unrecv(ret[ofs + 4 :])

        return ret[:ofs].decode("utf-8", "surrogateescape").lstrip("\r\n").split("\r\n")
|
||
|
||
|
||
def rand_name(fdir: str, fn: str, rnd: int) -> str:
    """returns a random filename (keeping fn's extension) which does not
    exist in fdir; starts at rnd characters and grows the name by one per
    round of 16 failed attempts. the last candidate is returned even if
    all 256 attempts collided (matches previous behavior).

    fix: previously the outer loop kept spinning (cheaply) after a free
    name was found, and the name-length math was recomputed per attempt;
    now we return as soon as a free name is found and hoist the invariants.
    """
    try:
        ext = "." + fn.rsplit(".", 1)[1]
    except:
        ext = ""

    for extra in range(16):
        nc = rnd + extra  # name length for this round
        nb = (6 + 6 * nc) // 8  # random bytes needed for nc b64 chars
        for _ in range(16):
            zb = ub64enc(os.urandom(nb))
            fn = zb[:nc].decode("ascii") + ext
            if not os.path.exists(fsenc(os.path.join(fdir, fn))):
                return fn

    return fn
|
||
|
||
|
||
def gen_filekey(alg: int, salt: str, fspath: str, fsize: int, inode: int) -> str:
    """derive a per-file access key; alg 1 also mixes in size and inode"""
    if alg == 1:
        zs = "%s %s %s %s" % (salt, fspath, fsize, inode)
    else:
        zs = "%s %s" % (salt, fspath)

    digest = hashlib.sha512(zs.encode("utf-8", "replace")).digest()
    return ub64enc(digest).decode("ascii")
|
||
|
||
|
||
def gen_filekey_dbg(
    alg: int,
    salt: str,
    fspath: str,
    fsize: int,
    inode: int,
    log: "NamedLogger",
    log_ptn: Optional[Pattern[str]],
) -> str:
    """gen_filekey wrapper which logs its inputs (plus caller context and
    an abspath sanity-check) for any fspath matching log_ptn"""
    ret = gen_filekey(alg, salt, fspath, fsize, inode)

    assert log_ptn  # !rm
    if log_ptn.search(fspath):
        try:
            import inspect

            # names of the 3 nearest callers, for context
            ctx = ",".join(inspect.stack()[n].function for n in range(2, 5))
        except:
            ctx = ""

        p2 = "a"
        try:
            p2 = absreal(fspath)
            if p2 != fspath:
                # not canonical; warn since the key won't match the real path
                raise Exception()
        except:
            t = "maybe wrong abspath for filekey;\norig: {}\nreal: {}"
            log(t.format(fspath, p2), 1)

        t = "fk({}) salt({}) size({}) inode({}) fspath({}) at({})"
        log(t.format(ret[:8], salt, fsize, inode, fspath, ctx), 5)

    return ret
|
||
|
||
|
||
# english day/month names for RFC 2822 date formatting;
# hardcoded to stay independent of the process locale
WKDAYS = "Mon Tue Wed Thu Fri Sat Sun".split()
MONTHS = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
RFC2822 = "%s, %02d %s %04d %02d:%02d:%02d GMT"
|
||
|
||
|
||
def formatdate(ts: Optional[float] = None) -> str:
    """unixtime (default: now) -> RFC 2822 date-string in GMT"""
    # gmtime ~= datetime.fromtimestamp(ts, UTC).timetuple()
    tm = time.gmtime(ts)
    return RFC2822 % (
        WKDAYS[tm[6]],
        tm[2],
        MONTHS[tm[1] - 1],
        tm[0],
        tm[3],
        tm[4],
        tm[5],
    )
|
||
|
||
|
||
def gencookie(k: str, v: str, r: str, tls: bool, dur: int = 0, txt: str = "") -> str:
    """build a set-cookie header value; dur=0 deletes the cookie
    by giving it an expiry date in the past"""
    v = v.replace("%", "%25").replace(";", "%3B")
    exp = formatdate(time.time() + dur) if dur else "Fri, 15 Aug 1997 01:00:00 GMT"
    t = "%s=%s; Path=/%s; Expires=%s%s%s; SameSite=Lax"
    return t % (k, v, r, exp, "; Secure" if tls else "", txt)
|
||
|
||
|
||
def humansize(sz: float, terse: bool = False) -> str:
    """bytes -> human-readable string, e.g. '1.5 MiB' (terse: '1.5M')"""
    for unit in HUMANSIZE_UNITS:
        if sz < 1024:
            break

        sz /= 1024.0

    num = str(sz)[:4].rstrip(".")
    if terse:
        return "%s%s" % (num, unit[:1])

    return "%s %s" % (num, unit)
|
||
|
||
|
||
def unhumanize(sz: str) -> int:
    """'1.5m' -> 1572864; plain integer strings pass straight through"""
    try:
        return int(sz)
    except:
        pass

    mult = UNHUMANIZE_UNITS.get(sz[-1:].lower(), 1)
    return int(float(sz[:-1]) * mult)
|
||
|
||
|
||
def get_spd(nbyte: int, t0: float, t: Optional[float] = None) -> str:
    """colorized 'total transferred + speed' for terminal status lines"""
    if t is None:
        t = time.time()

    dt = (t - t0) or 0.001  # avoid div-by-zero on instant transfers
    s1 = humansize(nbyte).replace(" ", "\033[33m").replace("iB", "")
    s2 = humansize(nbyte / dt).replace(" ", "\033[35m").replace("iB", "")
    return "%s \033[0m%s/s\033[0m" % (s1, s2)
|
||
|
||
|
||
def s2hms(s: float, optional_h: bool = False) -> str:
    """seconds -> 'h:mm:ss', or 'm:ss' when optional_h and under an hour"""
    total = int(s)
    h = total // 3600
    m, sec = divmod(total % 3600, 60)
    if optional_h and not h:
        return "%d:%02d" % (m, sec)

    return "%d:%02d:%02d" % (h, m, sec)
|
||
|
||
|
||
def djoin(*paths: str) -> str:
    """joins without adding a trailing slash on blank args"""
    return os.path.join(*(p for p in paths if p))
|
||
|
||
|
||
def uncyg(path: str) -> str:
    """rewrite a cygwin-style drive prefix '/c/...' into 'c:\\...';
    anything not drive-letter shaped is returned unchanged"""
    drive_like = (
        len(path) >= 2
        and path.startswith("/")
        and (len(path) == 2 or path[2] == "/")
    )
    if not drive_like:
        return path

    return "%s:\\%s" % (path[1], path[3:])
|
||
|
||
|
||
def undot(path: str) -> str:
    """resolve '.' and '..' in a /-separated vpath, never escaping the root"""
    out: list[str] = []
    for part in path.split("/"):
        if not part or part == ".":
            continue

        if part == "..":
            if out:
                out.pop()
            continue

        out.append(part)

    return "/".join(out)
|
||
|
||
|
||
def sanitize_fn(fn: str, ok: str) -> str:
    """strip path components and (on windows) neutralize reserved
    characters and device names; characters listed in ok are kept as-is.

    fix: the character remap table had been corrupted (self-replacements
    and a broken string literal); restored the fullwidth-unicode
    lookalikes so reserved characters stay visually recognizable.
    """
    if "/" not in ok:
        fn = fn.replace("\\", "/").split("/")[-1]

    if ANYWIN:
        # windows-reserved characters -> fullwidth unicode lookalikes
        remap = [
            ["<", "\uff1c"],
            [">", "\uff1e"],
            [":", "\uff1a"],
            ['"', "\uff02"],
            ["/", "\uff0f"],
            ["\\", "\uff3c"],
            ["|", "\uff5c"],
            ["?", "\uff1f"],
            ["*", "\uff0a"],
        ]
        for a, b in [x for x in remap if x[0] not in ok]:
            fn = fn.replace(a, b)

        # device names are forbidden even with an extension appended
        bad = ["con", "prn", "aux", "nul"]
        for n in range(1, 10):
            bad += ("com%s lpt%s" % (n, n)).split(" ")

        if fn.lower().split(".")[0] in bad:
            fn = "_" + fn

    return fn.strip()
|
||
|
||
|
||
def sanitize_vpath(vp: str, ok: str) -> str:
    """sanitize each component of a vpath individually"""
    segs = vp.replace(os.sep, "/").split("/")
    return "/".join(sanitize_fn(seg, ok) for seg in segs)
|
||
|
||
|
||
def relchk(rp: str) -> str:
    """returns '' if rp is an acceptable relative path,
    otherwise a short hint of what's wrong with it"""
    if "\x00" in rp:
        return "[nul]"

    if not ANYWIN:
        return ""

    if "\n" in rp or "\r" in rp:
        return "x\nx"

    stripped = re.sub(r'[\\:*?"<>|]', "", rp)
    if stripped == rp:
        return ""

    return "[{}]".format(stripped)
|
||
|
||
|
||
def absreal(fpath: str) -> str:
    """abspath + realpath with a filesystem-encoding roundtrip;
    falls back to plain str paths on broken windows boxes"""
    try:
        return fsdec(os.path.abspath(os.path.realpath(afsenc(fpath))))
    except:
        if not WINDOWS:
            raise

        # cpython bug introduced in 3.8, still exists in 3.9.1,
        # some win7sp1 and win10:20H2 boxes cannot realpath a
        # networked drive letter such as b"n:" or b"n:\\"
        return os.path.abspath(os.path.realpath(fpath))
|
||
|
||
|
||
def u8safe(txt: str) -> str:
    """force txt into clean utf-8; unencodable characters become
    numeric charrefs when possible, replacement chars otherwise"""
    try:
        enc = txt.encode("utf-8", "xmlcharrefreplace")
    except:
        enc = txt.encode("utf-8", "replace")

    return enc.decode("utf-8", "replace")
|
||
|
||
|
||
def exclude_dotfiles(filepaths: list[str]) -> list[str]:
    """drop paths whose basename starts with a dot"""
    ret = []
    for fp in filepaths:
        if not fp.rsplit("/", 1)[-1].startswith("."):
            ret.append(fp)

    return ret
|
||
|
||
|
||
def odfusion(
    base: Union[ODict[str, bool], ODict["LiteralString", bool]], oth: str
) -> ODict[str, bool]:
    """merge an ordered-set (dict keys) with a comma-separated key list;
    '+...' adds the keys, '-...' or '/...' removes them,
    anything else replaces the whole set; base is never mutated"""
    mode = oth[:1]
    mod_keys = [zs for zs in oth[1:].split(",") if zs]

    ret = base.copy()
    if mode == "+":
        for zs in mod_keys:
            ret[zs] = True
    elif mode in ("-", "/"):
        for zs in mod_keys:
            ret.pop(zs, None)
    else:
        ret = ODict.fromkeys([zs for zs in oth.split(",") if zs], True)

    return ret
|
||
|
||
|
||
def ipnorm(ip: str) -> str:
    """normalize an ip for bookkeeping; ipv6 is truncated to its /64 prefix"""
    if ":" not in ip:
        return ip

    # assume /64 clients; drop 4 groups
    return IPv6Address(ip).exploded[:-20]
|
||
|
||
|
||
def find_prefix(ips: list[str], cidrs: list[str]) -> list[str]:
    """for each ip, keep the first cidr which is either the exact ip
    or a prefix entry of the form 'ip/...'"""
    hits = []
    for ip in ips:
        for cidr in cidrs:
            if cidr == ip or cidr.startswith(ip + "/"):
                hits.append(cidr)
                break

    return hits
|
||
|
||
|
||
def html_escape(s: str, quot: bool = False, crlf: bool = False) -> str:
    """html.escape but also newlines

    fix: the replacement strings had been corrupted (HTML entities
    unescaped into no-op self-replacements and invalid literals);
    restored the actual entity escapes.
    """
    # ampersand first so existing text isn't double-escaped
    s = s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    if quot:
        s = s.replace('"', "&quot;").replace("'", "&#x27;")
    if crlf:
        s = s.replace("\r", "&#13;").replace("\n", "&#10;")

    return s
|
||
|
||
|
||
def html_bescape(s: bytes, quot: bool = False, crlf: bool = False) -> bytes:
    """html.escape but bytestrings

    fix: the replacement bytestrings had been corrupted (HTML entities
    unescaped into no-op self-replacements); restored the actual escapes.
    """
    # ampersand first so existing text isn't double-escaped
    s = s.replace(b"&", b"&amp;").replace(b"<", b"&lt;").replace(b">", b"&gt;")
    if quot:
        s = s.replace(b'"', b"&quot;").replace(b"'", b"&#x27;")
    if crlf:
        s = s.replace(b"\r", b"&#13;").replace(b"\n", b"&#10;")

    return s
|
||
|
||
|
||
def _quotep2(txt: str) -> str:
    """url quoter which deals with bytes correctly"""
    if not txt:
        return ""

    quoted = quote(w8enc(txt), safe=b"/")
    return w8dec(quoted.replace(b" ", b"+"))  # type: ignore
|
||
|
||
|
||
def _quotep3(txt: str) -> str:
    """url quoter which deals with bytes correctly"""
    if not txt:
        return ""

    quoted = quote(w8enc(txt), safe=b"/").encode("utf-8")
    return w8dec(quoted.replace(b" ", b"+"))
|
||
|
||
|
||
if not PY2:
    # bytes which may appear in a url without quoting
    _uqsb = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-~/"
    # lookup table: byte value -> its quoted representation
    _uqtl = {
        n: ("%%%02X" % (n,) if n not in _uqsb else chr(n)).encode("utf-8")
        for n in range(256)
    }
    # NOTE(review): iterating a bytestring yields ints, so this bytes-key
    # is never hit by the lut[ch] lookup below; spaces therefore encode as
    # %20 rather than + -- confirm whether that is intended
    _uqtl[b" "] = b"+"

    def _quotep3b(txt: str) -> str:
        """url quoter which deals with bytes correctly"""
        if not txt:
            return ""
        btxt = w8enc(txt)
        if btxt.rstrip(_uqsb):
            # contains at least one byte which needs quoting
            lut = _uqtl
            btxt = b"".join([lut[ch] for ch in btxt])
        return w8dec(btxt)

    quotep = _quotep3b

    _hexd = "0123456789ABCDEFabcdef"
    # "XX" hex-pair (as bytes) -> the raw byte it encodes
    _hex2b = {(a + b).encode(): bytes.fromhex(a + b) for a in _hexd for b in _hexd}

    def unquote(btxt: bytes) -> bytes:
        """bytes-level url unquoting; invalid %-sequences pass through as-is"""
        h2b = _hex2b
        parts = iter(btxt.split(b"%"))
        ret = [next(parts)]
        for item in parts:
            c = h2b.get(item[:2])
            if c is None:
                # not a valid hex pair; keep the literal percent
                ret.append(b"%")
                ret.append(item)
            else:
                ret.append(c)
                ret.append(item[2:])
        return b"".join(ret)

    from urllib.parse import quote_from_bytes as quote
else:
    from urllib import quote  # type: ignore # pylint: disable=no-name-in-module
    from urllib import unquote  # type: ignore # pylint: disable=no-name-in-module

    quotep = _quotep2
|
||
|
||
|
||
def unquotep(txt: str) -> str:
    """url unquoter which deals with bytes correctly"""
    return w8dec(unquote(w8enc(txt)))
|
||
|
||
|
||
def vroots(vp1: str, vp2: str) -> tuple[str, str]:
    """
    input("q/w/e/r","a/s/d/e/r") output("/q/w/","/a/s/d/")
    """
    # peel identical trailing components off both vpaths
    while vp1 and vp2:
        head1, _, tail1 = vp1.rpartition("/")
        head2, _, tail2 = vp2.rpartition("/")
        if tail1 != tail2:
            break

        vp1, vp2 = head1, head2

    return (
        "/%s/" % (vp1,) if vp1 else "/",
        "/%s/" % (vp2,) if vp2 else "/",
    )
|
||
|
||
|
||
def vsplit(vpath: str) -> tuple[str, str]:
    """split a vpath into (parent-dir, basename); parent is '' at the root"""
    if "/" in vpath:
        return vpath.rsplit("/", 1)  # type: ignore

    return "", vpath
|
||
|
||
|
||
# vpath-join
|
||
# vpath-join
def vjoin(rd: str, fn: str) -> str:
    """join two vpath fragments; blank args vanish instead of adding slashes"""
    return rd + "/" + fn if rd and fn else (rd or fn)
|
||
|
||
|
||
# url-join
|
||
# url-join
def ujoin(rd: str, fn: str) -> str:
    """join two url fragments, trimming redundant slashes at the seam"""
    if not rd or not fn:
        return rd or fn

    return rd.rstrip("/") + "/" + fn.lstrip("/")
|
||
|
||
|
||
def log_reloc(
    log: "NamedLogger",
    re: dict[str, str],  # the reloc-instruction dict (shadows the re module here)
    pm: tuple[str, str, str, tuple["VFS", str]],  # result from pathmod()
    ap: str,
    vp: str,
    fn: str,
    vn: "VFS",
    rem: str,
) -> None:
    """dump the before/after state of a file relocation to the log"""
    nap, nvp, nfn, (nvn, nrem) = pm
    t = "reloc %s:\nold ap [%s]\nnew ap [%s\033[36m/%s\033[0m]\nold vp [%s]\nnew vp [%s\033[36m/%s\033[0m]\nold fn [%s]\nnew fn [%s]\nold vfs [%s]\nnew vfs [%s]\nold rem [%s]\nnew rem [%s]"
    log(t % (re, ap, nap, nfn, vp, nvp, nfn, fn, nfn, vn.vpath, nvn.vpath, rem, nrem))
|
||
|
||
|
||
def pathmod(
    vfs: "VFS", ap: str, vp: str, mod: dict[str, str]
) -> Optional[tuple[str, str, str, tuple["VFS", str]]]:
    """apply a relocation instruction to a file's paths;
    returns (new-abspath, new-vpath, new-filename, (new-vfs, new-rem))
    or None if the result cannot be mapped to a volume"""
    # vfs: authsrv.vfs
    # ap: original abspath to a file
    # vp: original urlpath to a file
    # mod: modification (ap/vp/fn)

    nvp = "\n"  # new vpath; "\n" is the not-yet-resolved sentinel
    ap = os.path.dirname(ap)
    vp, fn = vsplit(vp)
    if mod.get("fn"):
        fn = mod["fn"]
        nvp = vp

    for ref, k in ((ap, "ap"), (vp, "vp")):
        if k not in mod:
            continue

        ms = mod[k].replace(os.sep, "/")
        if ms.startswith("/"):
            # absolute: replaces the path outright
            np = ms
        elif k == "vp":
            # relative vpath: join and normalize
            np = undot(vjoin(ref, ms))
        else:
            # relative abspath: resolve against the original directory
            np = os.path.abspath(os.path.join(ref, ms))

        if k == "vp":
            nvp = np.lstrip("/")
            continue

        # try to map abspath to vpath
        np = np.replace("/", os.sep)
        for vn_ap, vn in vfs.all_aps:
            if not np.startswith(vn_ap):
                continue
            zs = np[len(vn_ap) :].replace(os.sep, "/")
            nvp = vjoin(vn.vpath, zs)
            break

    if nvp == "\n":
        # nothing resolved to a vpath
        return None

    vn, rem = vfs.get(nvp, "*", False, False)
    if not vn.realpath:
        raise Exception("unmapped vfs")

    ap = vn.canonical(rem)
    return ap, nvp, fn, (vn, rem)
|
||
|
||
|
||
def _w8dec2(txt: bytes) -> str:
    """decodes filesystem-bytes to wtf8 (py2: bundled surrogateescape)"""
    ret = surrogateescape.decodefilename(txt)
    return ret
|
||
|
||
|
||
def _w8enc2(txt: str) -> bytes:
    """encodes wtf8 to filesystem-bytes (py2: bundled surrogateescape)"""
    ret = surrogateescape.encodefilename(txt)
    return ret
|
||
|
||
|
||
def _w8dec3(txt: bytes) -> str:
    """decodes filesystem-bytes to wtf8"""
    return str(txt, FS_ENCODING, "surrogateescape")
|
||
|
||
|
||
def _w8enc3(txt: str) -> bytes:
    """encodes wtf8 to filesystem-bytes"""
    return bytes(txt, FS_ENCODING, "surrogateescape")
|
||
|
||
|
||
def _msdec(txt: bytes) -> str:
    """decode windows filesystem-bytes, dropping any \\\\?\\ long-path prefix"""
    zs = txt.decode(FS_ENCODING, "surrogateescape")
    if zs.startswith("\\\\?\\"):
        return zs[4:]

    return zs
|
||
|
||
|
||
def _msaenc(txt: str) -> bytes:
    """encode for windows, forcing backslash separators"""
    zs = txt.replace("/", "\\")
    return zs.encode(FS_ENCODING, "surrogateescape")
|
||
|
||
|
||
def _uncify(txt: str) -> str:
    """normalize a path to windows UNC / long-path form (str variant)"""
    txt = txt.replace("/", "\\")
    if ":" not in txt and not txt.startswith("\\\\"):
        # relative path; make it absolute first
        txt = absreal(txt)

    if txt.startswith("\\\\"):
        return txt

    return "\\\\?\\" + txt
|
||
|
||
|
||
def _msenc(txt: str) -> bytes:
    """encode for windows with the \\\\?\\ long-path prefix applied"""
    txt = txt.replace("/", "\\")
    if ":" not in txt and not txt.startswith("\\\\"):
        # relative path; make it absolute first
        txt = absreal(txt)

    ret = txt.encode(FS_ENCODING, "surrogateescape")
    if ret.startswith(b"\\\\"):
        return ret

    return b"\\\\?\\" + ret
|
||
|
||
|
||
# filesystem-encoding codecs; py3 uses the real surrogateescape
# error handler, py2 falls back to the bundled polyfill
w8dec = _w8dec3 if not PY2 else _w8dec2
w8enc = _w8enc3 if not PY2 else _w8enc2
|
||
|
||
|
||
def w8b64dec(txt: str) -> str:
    """decodes base64(filesystem-bytes) to wtf8"""
    zb = ub64dec(txt.encode("ascii"))
    return w8dec(zb)
|
||
|
||
|
||
def w8b64enc(txt: str) -> str:
    """encodes wtf8 to base64(filesystem-bytes)"""
    zb = w8enc(txt)
    return ub64enc(zb).decode("ascii")
|
||
|
||
|
||
if not PY2 and WINDOWS:
    # modern windows: backslash fixup + \\?\ long-path prefixing
    sfsenc = w8enc
    afsenc = _msaenc
    fsenc = _msenc
    fsdec = _msdec
    uncify = _uncify
elif not PY2 or not WINDOWS:
    # everything else sane: plain wtf8 roundtrip
    fsenc = afsenc = sfsenc = w8enc
    fsdec = w8dec
    uncify = str
else:
    # py2 on windows: pass strings through untouched
    # moonrunes become \x3f with bytestrings,
    # losing mojibake support is worth
    def _not_actually_mbcs_enc(txt: str) -> bytes:
        return txt  # type: ignore

    def _not_actually_mbcs_dec(txt: bytes) -> str:
        return txt  # type: ignore

    fsenc = afsenc = sfsenc = _not_actually_mbcs_enc
    fsdec = _not_actually_mbcs_dec
    uncify = str
|
||
|
||
|
||
def s3enc(mem_cur: "sqlite3.Cursor", rd: str, fn: str) -> tuple[str, str]:
    """make (directory, filename) safe to store in sqlite; values which
    sqlite refuses to bind are stored base64 with a '//' marker"""
    out: list[str] = []
    for zs in (rd, fn):
        try:
            mem_cur.execute("select * from a where b = ?", (zs,))
            out.append(zs)
        except:
            out.append("//" + w8b64enc(zs))
            # self.log("mojien [{}] {}".format(zs, out[-1][2:]))

    return out[0], out[1]
|
||
|
||
|
||
def s3dec(rd: str, fn: str) -> tuple[str, str]:
    """undo s3enc; '//'-prefixed values are base64-encoded"""

    def un(zs: str) -> str:
        return w8b64dec(zs[2:]) if zs.startswith("//") else zs

    return un(rd), un(fn)
|
||
|
||
|
||
def db_ex_chk(log: "NamedLogger", ex: Exception, db_path: str) -> bool:
    """if ex is a sqlite lock error, kick off a background lsof of the
    db file and return True; otherwise False"""
    if str(ex) == "database is locked":
        Daemon(lsof, "dbex", (log, db_path))
        return True

    return False
|
||
|
||
|
||
def lsof(log: "NamedLogger", abspath: str) -> None:
    """best-effort: log which processes have abspath open"""
    try:
        rc, so, se = runcmd([b"lsof", b"-R", fsenc(abspath)], timeout=45)
        output = (so.strip() + "\n" + se.strip()).strip()
        log("lsof {} = {}\n{}".format(abspath, rc, output), 3)
    except:
        log("lsof failed; " + min_ex(), 3)
|
||
|
||
|
||
def _fs_mvrm(
    log: "NamedLogger", src: str, dst: str, atomic: bool, flags: dict[str, Any]
) -> bool:
    """
    rename/replace/delete with retries, for filesystems that transiently
    fail (smb,某些 fuse) -- configured per-volume via the mv_re_*/rm_re_*
    flags; returns True on success, False if src vanished or was replaced
    """
    bsrc = fsenc(src)
    bdst = fsenc(dst)
    if atomic:
        k = "mv_re_"
        act = "atomic-rename"
        osfun = os.replace
        args = [bsrc, bdst]
    elif dst:
        k = "mv_re_"
        act = "rename"
        osfun = os.rename
        args = [bsrc, bdst]
    else:
        k = "rm_re_"
        act = "delete"
        osfun = os.unlink
        args = [bsrc]

    # total retry budget (sec) and delay between attempts (sec)
    maxtime = flags.get(k + "t", 0.0)
    chill = flags.get(k + "r", 0.0)
    if chill < 0.001:
        chill = 0.1

    ino = 0
    t0 = now = time.time()
    for attempt in range(90210):
        try:
            # abort if another process swapped the src file under us
            if ino and os.stat(bsrc).st_ino != ino:
                t = "src inode changed; aborting %s %s"
                log(t % (act, src), 1)
                return False
            if (dst and not atomic) and os.path.exists(bdst):
                t = "something appeared at dst; aborting rename [%s] ==> [%s]"
                log(t % (src, dst), 1)
                return False
            osfun(*args)
            if attempt:
                now = time.time()
                t = "%sd in %.2f sec, attempt %d: %s"
                log(t % (act, now - t0, attempt + 1, src))
            return True
        except OSError as ex:
            now = time.time()
            if ex.errno == errno.ENOENT:
                return False
            if now - t0 > maxtime or attempt == 90209:
                raise
            if not attempt:
                if not PY2:
                    # remember the inode so later attempts can detect swaps
                    ino = os.stat(bsrc).st_ino
                t = "%s failed (err.%d); retrying for %d sec: [%s]"
                log(t % (act, ex.errno, maxtime + 0.99, src))

        time.sleep(chill)

    return False  # makes pylance happy
|
||
|
||
|
||
def atomic_move(log: "NamedLogger", src: str, dst: str, flags: dict[str, Any]) -> None:
    """
    move src onto dst, atomically where the platform allows it; uses the
    retrying _fs_mvrm when the volume has mv_re_t (retry-timeout) configured
    """
    bsrc = fsenc(src)
    bdst = fsenc(dst)
    if PY2:
        # py2 has no os.replace; emulate with unlink + rename (not atomic)
        if os.path.exists(bdst):
            _fs_mvrm(log, dst, "", False, flags)  # unlink

        _fs_mvrm(log, src, dst, False, flags)  # rename
    elif flags.get("mv_re_t"):
        _fs_mvrm(log, src, dst, True, flags)
    else:
        os.replace(bsrc, bdst)
|
||
|
||
|
||
def wrename(log: "NamedLogger", src: str, dst: str, flags: dict[str, Any]) -> bool:
    """rename src to dst; retried in a loop if the volume has mv_re_t configured"""
    if flags.get("mv_re_t"):
        return _fs_mvrm(log, src, dst, False, flags)

    os.rename(fsenc(src), fsenc(dst))
    return True
|
||
|
||
|
||
def wunlink(log: "NamedLogger", abspath: str, flags: dict[str, Any]) -> bool:
    """delete abspath; retried in a loop if the volume has rm_re_t configured"""
    if flags.get("rm_re_t"):
        return _fs_mvrm(log, abspath, "", False, flags)

    os.unlink(fsenc(abspath))
    return True
|
||
|
||
|
||
def get_df(abspath: str, prune: bool) -> tuple[Optional[int], Optional[int], str]:
    """
    disk capacity at abspath as (free-bytes, total-bytes, error-text);
    total is None on windows, both are None on failure; if prune,
    walk up to the nearest existing folder first
    """
    try:
        ap = fsenc(abspath)
        while prune and not os.path.isdir(ap) and BOS_SEP in ap:
            # strip leafs until it hits an existing folder
            ap = ap.rsplit(BOS_SEP, 1)[0]

        if ANYWIN:
            assert ctypes  # type: ignore # !rm
            abspath = fsdec(ap)
            bfree = ctypes.c_ulonglong(0)
            ctypes.windll.kernel32.GetDiskFreeSpaceExW(  # type: ignore
                ctypes.c_wchar_p(abspath), None, None, ctypes.pointer(bfree)
            )
            return (bfree.value, None, "")
        else:
            sv = os.statvfs(ap)
            free = sv.f_frsize * sv.f_bfree
            total = sv.f_frsize * sv.f_blocks
            return (free, total, "")
    except Exception as ex:
        return (None, None, repr(ex))
|
||
|
||
|
||
# siocoutq: number of bytes still queued in the kernel tx-buffer,
# used to decide when a socket has been fully flushed before close
if not ANYWIN and not MACOS:

    def siocoutq(sck: socket.socket) -> int:
        # SIOCOUTQ^sockios.h == TIOCOUTQ^ioctl.h
        try:
            zb = fcntl.ioctl(sck.fileno(), termios.TIOCOUTQ, b"AAAA")
            return sunpack(b"I", zb)[0]  # type: ignore
        except:
            # pretend data remains so callers keep draining
            return 1

else:
    # macos: getsockopt(fd, SOL_SOCKET, SO_NWRITE, ...)
    # windows: TcpConnectionEstatsSendBuff

    def siocoutq(sck: socket.socket) -> int:
        # not implemented on these platforms; claim nonempty
        return 1
|
||
|
||
|
||
def shut_socket(log: "NamedLogger", sck: socket.socket, timeout: int = 3) -> None:
    """
    graceful tcp shutdown: stop sending, wait (up to timeout sec) for the
    kernel tx-queue to drain and the client to hang up, then close
    """
    t0 = time.time()
    fd = sck.fileno()
    if fd == -1:
        # already shut down; just release the object
        sck.close()
        return

    try:
        sck.settimeout(timeout)
        sck.shutdown(socket.SHUT_WR)
        try:
            while time.time() - t0 < timeout:
                if not siocoutq(sck):
                    # kernel says tx queue empty, we good
                    break

                # on windows in particular, drain rx until client shuts
                if not sck.recv(32 * 1024):
                    break

            sck.shutdown(socket.SHUT_RDWR)
        except:
            pass
    except Exception as ex:
        log("shut({}): {}".format(fd, ex), "90")
    finally:
        td = time.time() - t0
        if td >= 1:
            log("shut({}) in {:.3f} sec".format(fd, td), "90")

        sck.close()
|
||
|
||
|
||
def read_socket(
    sr: Unrecv, bufsz: int, total_size: int
) -> Generator[bytes, None, None]:
    """yield exactly total_size bytes from sr, in chunks of at most bufsz"""
    left = total_size
    while left > 0:
        try:
            buf = sr.recv(min(bufsz, left))
        except OSError:
            t = "client d/c during binary post after {} bytes, {} bytes remaining"
            raise Pebkac(400, t.format(total_size - left, left))

        left -= len(buf)
        yield buf
|
||
|
||
|
||
def read_socket_unbounded(sr: Unrecv, bufsz: int) -> Generator[bytes, None, None]:
    """relay everything from sr until the connection dies"""
    while True:
        try:
            buf = sr.recv(bufsz)
        except:
            return

        yield buf
|
||
|
||
|
||
def read_socket_chunked(
    sr: Unrecv, bufsz: int, log: Optional["NamedLogger"] = None
) -> Generator[bytes, None, None]:
    """
    decode an http chunked-transfer-encoded request body from sr,
    yielding the raw payload; raises Pebkac(400) on framing errors
    """
    err = "upload aborted: expected chunk length, got [{}] |{}| instead"
    while True:
        # read the hex chunk-size line, 2 bytes at a time
        buf = b""
        while b"\r" not in buf:
            try:
                buf += sr.recv(2)
                if len(buf) > 16:
                    # size lines are short hex digits; longer is garbage
                    raise Exception()
            except:
                err = err.format(buf.decode("utf-8", "replace"), len(buf))
                raise Pebkac(400, err)

        if not buf.endswith(b"\n"):
            # 2-byte reads can leave the \n of the \r\n pair unread
            sr.recv(1)

        try:
            chunklen = int(buf.rstrip(b"\r\n"), 16)
        except:
            err = err.format(buf.decode("utf-8", "replace"), len(buf))
            raise Pebkac(400, err)

        if chunklen == 0:
            # final (zero-length) chunk must be followed by a bare \r\n
            x = sr.recv_ex(2, False)
            if x == b"\r\n":
                return

            t = "protocol error after final chunk: want b'\\r\\n', got {!r}"
            raise Pebkac(400, t.format(x))

        if log:
            log("receiving %d byte chunk" % (chunklen,))

        for chunk in read_socket(sr, bufsz, chunklen):
            yield chunk

        # every chunk is terminated by \r\n
        x = sr.recv_ex(2, False)
        if x != b"\r\n":
            t = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
            raise Pebkac(400, t.format(x))
|
||
|
||
|
||
def list_ips() -> list[str]:
    """all ip-addresses of all local network adapters, deduplicated"""
    from .stolen.ifaddr import get_adapters

    seen: set[str] = set()
    for nic in get_adapters():
        for ipo in nic.ips:
            ip = ipo.ip
            # ipv6 entries are (ip, flowinfo, scope-id) tuples
            seen.add(ip if len(ip) >= 7 else ip[0])

    return list(seen)
|
||
|
||
|
||
def build_netmap(csv: str, defer_mutex: bool = False):
    """
    build a NetMap from a comma-separated list of IPs / CIDR ranges;
    returns None if csv means "match everything"; understands the
    shorthands lan/local/private/prvt and the legacy "172.19." syntax
    """
    csv = csv.lower().strip()

    if csv in ("any", "all", "no", ",", ""):
        return None

    srcs = [x.strip() for x in csv.split(",") if x.strip()]

    expanded_shorthands = False
    for shorthand in ("lan", "local", "private", "prvt"):
        if shorthand in srcs:
            if not expanded_shorthands:
                # only append the ranges once even if several aliases appear
                srcs += [
                    # lan:
                    "10.0.0.0/8",
                    "172.16.0.0/12",
                    "192.168.0.0/16",
                    "fd00::/8",
                    # link-local:
                    "169.254.0.0/16",
                    "fe80::/10",
                    # loopback:
                    "127.0.0.0/8",
                    "::1/128",
                ]
                expanded_shorthands = True

            srcs.remove(shorthand)

    if not HAVE_IPV6:
        srcs = [x for x in srcs if ":" not in x]

    cidrs = []
    for zs in srcs:
        if not zs.endswith("."):
            cidrs.append(zs)
            continue

        # translate old syntax "172.19." => "172.19.0.0/16"
        words = len(zs.rstrip(".").split("."))
        if words == 1:
            zs += "0.0.0/8"
        elif words == 2:
            zs += "0.0/16"
        elif words == 3:
            zs += "0/24"
        else:
            raise Exception("invalid config value [%s]" % (zs,))

        cidrs.append(zs)

    ips = [x.split("/")[0] for x in cidrs]
    return NetMap(ips, cidrs, True, False, defer_mutex)
|
||
|
||
|
||
def load_ipu(
    log: "RootLogger", ipus: list[str], defer_mutex: bool = False
) -> tuple[dict[str, str], NetMap]:
    """
    parse the --ipu arguments ("CIDR=UNAME" mappings) into an
    ip->username table plus the NetMap used to match client addresses
    """
    ip_u = {"": "*"}  # fallback entry: unmatched clients map to "*"
    cidr_u = {}
    for ipu in ipus:
        try:
            cidr, uname = ipu.split("=")
            cip, csz = cidr.split("/")
        except:
            t = "\n invalid value %r for argument --ipu; must be CIDR=UNAME (192.168.0.0/16=amelia)"
            raise Exception(t % (ipu,))
        uname2 = cidr_u.get(cidr)
        if uname2 is not None:
            # same cidr listed twice with different users is ambiguous
            t = "\n invalid value %r for argument --ipu; cidr %s already mapped to %r"
            raise Exception(t % (ipu, cidr, uname2))
        cidr_u[cidr] = uname
        ip_u[cip] = uname
    try:
        nm = NetMap(["::"], list(cidr_u.keys()), True, True, defer_mutex)
    except Exception as ex:
        t = "failed to translate --ipu into netmap, probably due to invalid config: %r"
        log("root", t % (ex,), 1)
        raise
    return ip_u, nm
|
||
|
||
|
||
def yieldfile(fn: str, bufsz: int) -> Generator[bytes, None, None]:
    """stream the contents of file fn as chunks of at most 128 KiB"""
    readsz = min(bufsz, 128 * 1024)
    with open(fsenc(fn), "rb", bufsz) as f:
        for buf in iter(lambda: f.read(readsz), b""):
            yield buf
|
||
|
||
|
||
def justcopy(
    fin: Generator[bytes, None, None],
    fout: Union[typing.BinaryIO, typing.IO[Any]],
    hashobj: Optional["hashlib._Hash"],
    max_sz: int,
    slp: float,
) -> tuple[int, str, str]:
    """
    drain fin into fout without checksumming; hashobj is accepted (and
    ignored) for signature-parity with hashcopy; bytes past max_sz are
    consumed but discarded; returns (total-bytes, placeholder, placeholder)
    """
    total = 0
    for chunk in fin:
        total += len(chunk)
        if not (max_sz and total > max_sz):
            fout.write(chunk)
            if slp:
                time.sleep(slp)

    return total, "checksum-disabled", "checksum-disabled"
|
||
|
||
|
||
def hashcopy(
    fin: Generator[bytes, None, None],
    fout: Union[typing.BinaryIO, typing.IO[Any]],
    hashobj: Optional["hashlib._Hash"],
    max_sz: int,
    slp: float,
) -> tuple[int, str, str]:
    """
    copy fin into fout while hashing (sha512 unless hashobj is given);
    returns (total-bytes, hex-digest, truncated-b64-digest); bytes past
    max_sz are consumed but neither hashed nor written
    """
    if not hashobj:
        hashobj = hashlib.sha512()
    tlen = 0
    for buf in fin:
        tlen += len(buf)
        if max_sz and tlen > max_sz:
            continue

        hashobj.update(buf)
        fout.write(buf)
        if slp:
            time.sleep(slp)

    # 33 bytes = 44 base64 chars, matching the up2k wark format
    digest_b64 = ub64enc(hashobj.digest()[:33]).decode("ascii")

    return tlen, hashobj.hexdigest(), digest_b64
|
||
|
||
|
||
def sendfile_py(
    log: "NamedLogger",
    lower: int,
    upper: int,
    f: typing.BinaryIO,
    s: socket.socket,
    bufsz: int,
    slp: float,
    use_poll: bool,
    dls: dict[str, tuple[float, int]],
    dl_id: str,
) -> int:
    """
    userspace copy of f[lower:upper] into socket s; when dl_id is set,
    progress is published into dls; returns the number of bytes that
    could NOT be sent (0 on success)
    """
    sent = 0
    left = upper - lower
    f.seek(lower)
    while left > 0:
        if slp:
            time.sleep(slp)

        buf = f.read(min(bufsz, left))
        if not buf:
            return left  # unexpected eof

        try:
            s.sendall(buf)
            left -= len(buf)
        except:
            return left  # client hung up

        if dl_id:
            sent += len(buf)
            dls[dl_id] = (time.time(), sent)

    return 0
|
||
|
||
|
||
def sendfile_kern(
    log: "NamedLogger",
    lower: int,
    upper: int,
    f: typing.BinaryIO,
    s: socket.socket,
    bufsz: int,
    slp: float,
    use_poll: bool,
    dls: dict[str, tuple[float, int]],
    dl_id: str,
) -> int:
    """
    zero-copy sendfile(2) of f[lower:upper] into socket s, waiting for
    writability via poll/select; returns unsent byte count (0 on success)
    """
    out_fd = s.fileno()
    in_fd = f.fileno()
    ofs = lower
    stuck = 0.0  # timestamp of when the client stopped reading
    if use_poll:
        poll = select.poll()
        poll.register(out_fd, select.POLLOUT)

    while ofs < upper:
        stuck = stuck or time.time()
        try:
            req = min(0x2000000, upper - ofs)  # 32 MiB
            if use_poll:
                poll.poll(10000)
            else:
                select.select([], [out_fd], [], 10)
            n = os.sendfile(out_fd, in_fd, ofs, req)
            stuck = 0
        except OSError as ex:
            # client stopped reading; do another select
            d = time.time() - stuck
            if d < 3600 and ex.errno == errno.EWOULDBLOCK:
                time.sleep(0.02)
                continue

            n = 0
        except Exception as ex:
            n = 0
            d = time.time() - stuck
            log("sendfile failed after {:.3f} sec: {!r}".format(d, ex))

        if n <= 0:
            # gave up; report how much never made it out
            return upper - ofs

        ofs += n
        if dl_id:
            dls[dl_id] = (time.time(), ofs - lower)

        # print("sendfile: ok, sent {} now, {} total, {} remains".format(n, ofs - lower, upper - ofs))

    return 0
|
||
|
||
|
||
def statdir(
    logger: Optional["RootLogger"], scandir: bool, lstat: bool, top: str, throw: bool
) -> Generator[tuple[str, os.stat_result], None, None]:
    """
    yield (name, stat_result) for each entry in folder top;
    per-entry stat errors are logged and skipped; a failed listing is
    raised (as Pebkac 404 for missing folders) when throw, else logged
    """
    if lstat and ANYWIN:
        # no-follow is disabled on windows
        lstat = False

    if lstat and (PY2 or os.stat not in os.supports_follow_symlinks):
        # scandir can't do lstat here; fall back to listdir+os.lstat
        scandir = False

    src = "statdir"
    try:
        btop = fsenc(top)
        if scandir and hasattr(os, "scandir"):
            src = "scandir"
            with os.scandir(btop) as dh:
                for fh in dh:
                    try:
                        yield (fsdec(fh.name), fh.stat(follow_symlinks=not lstat))
                    except Exception as ex:
                        if not logger:
                            continue

                        logger(src, "[s] {} @ {}".format(repr(ex), fsdec(fh.path)), 6)
        else:
            src = "listdir"
            fun: Any = os.lstat if lstat else os.stat
            for name in os.listdir(btop):
                abspath = os.path.join(btop, name)
                try:
                    yield (fsdec(name), fun(abspath))
                except Exception as ex:
                    if not logger:
                        continue

                    logger(src, "[s] {} @ {}".format(repr(ex), fsdec(abspath)), 6)

    except Exception as ex:
        if throw:
            zi = getattr(ex, "errno", 0)
            if zi == errno.ENOENT:
                raise Pebkac(404, str(ex))
            raise

        t = "{} @ {}".format(repr(ex), top)
        if logger:
            logger(src, t, 1)
        else:
            print(t)
|
||
|
||
|
||
def dir_is_empty(logger: "RootLogger", scandir: bool, top: str):
    """true if folder top has no entries; stops listing at the first hit"""
    return not any(True for _ in statdir(logger, scandir, False, top, False))
|
||
|
||
|
||
def rmdirs(
    logger: "RootLogger", scandir: bool, lstat: bool, top: str, depth: int
) -> tuple[list[str], list[str]]:
    """rmdir all descendants, then self; returns (removed, failed) paths"""
    if not os.path.isdir(fsenc(top)):
        # given a file; operate on its parent instead
        top = os.path.dirname(top)
        depth -= 1

    stats = statdir(logger, scandir, lstat, top, False)
    dirs = [x[0] for x in stats if stat.S_ISDIR(x[1].st_mode)]
    dirs = [os.path.join(top, x) for x in dirs]
    ok = []
    ng = []
    for d in reversed(dirs):
        a, b = rmdirs(logger, scandir, lstat, d, depth + 1)
        ok += a
        ng += b

    if depth:
        # depth 0 is the starting folder itself; keep it
        try:
            os.rmdir(fsenc(top))
            ok.append(top)
        except:
            ng.append(top)

    return ok, ng
|
||
|
||
|
||
def rmdirs_up(top: str, stop: str) -> tuple[list[str], list[str]]:
    """rmdir on self, then all parents; never removes stop; returns (removed, failed)"""
    if top == stop:
        return [], [top]

    try:
        os.rmdir(fsenc(top))
    except:
        # not empty (or gone); stop climbing
        return [], [top]

    par = os.path.dirname(top)
    if not par or par == stop:
        return [top], []

    ok, ng = rmdirs_up(par, stop)
    return [top] + ok, ng
|
||
|
||
|
||
def unescape_cookie(orig: str) -> str:
    # mw=idk; doot=qwe%2Crty%3Basd+fgh%2Bjkl%25zxc%26vbn # qwe,rty;asd fgh+jkl%zxc&vbn
    """best-effort decode of %xx escapes; malformed sequences are kept verbatim"""
    out = []
    pend = ""  # the partial "%xx" escape currently being collected
    for ch in orig:
        if ch == "%":
            # a new escape begins; flush any half-finished one as plain text
            if pend:
                out.append(pend)
            pend = ch
        elif not pend:
            out.append(ch)
        else:
            pend += ch
            if len(pend) == 3:
                try:
                    out.append(chr(int(pend[1:], 16)))
                except:
                    out.append(pend)
                pend = ""

    if pend:
        out.append(pend)

    return "".join(out)
|
||
|
||
|
||
def guess_mime(url: str, fallback: str = "application/octet-stream") -> str:
    """
    mimetype for url based on its file extension; consults the local
    MIMES table first, then the mimetypes module, then fallback;
    text types get an explicit utf-8 charset appended
    """
    try:
        ext = url.rsplit(".", 1)[1].lower()
    except:
        return fallback

    ret = MIMES.get(ext)

    if not ret:
        zt = mimetypes.guess_type(url)
        # a detected encoding (gz, bz2, ...) wins over the content type
        ret = "application/{}".format(zt[1]) if zt[1] else zt[0]

    if not ret:
        ret = fallback

    if ";" not in ret and (ret.startswith("text/") or ret.endswith("/javascript")):
        ret += "; charset=utf-8"

    return ret
|
||
|
||
|
||
def getalive(pids: list[int], pgid: int) -> list[int]:
    """filter pids down to processes that are still alive (and, when pgid is known, still in our process-group)"""
    alive = []
    for pid in pids:
        try:
            if pgid:
                # check if still one of ours
                if os.getpgid(pid) == pgid:
                    alive.append(pid)
            else:
                # windows doesn't have pgroups; assume
                assert psutil  # type: ignore # !rm
                psutil.Process(pid)  # raises if the pid is gone
                alive.append(pid)
        except:
            # process exited (or not ours); drop it
            pass

    return alive
|
||
|
||
|
||
def killtree(root: int) -> None:
    """still racy but i tried

    terminate process `root` and all its descendants: SIGTERM first,
    then SIGKILL whatever survives a ~1 sec grace period
    """
    try:
        # limit the damage where possible (unixes)
        pgid = os.getpgid(os.getpid())
    except:
        pgid = 0

    if HAVE_PSUTIL:
        assert psutil  # type: ignore # !rm
        pids = [root]
        parent = psutil.Process(root)
        for child in parent.children(recursive=True):
            pids.append(child.pid)
            child.terminate()
        parent.terminate()
        parent = None
    elif pgid:
        # linux-only
        pids = []
        chk = [root]
        while chk:
            # breadth-first expansion of the process tree via pgrep
            pid = chk[0]
            chk = chk[1:]
            pids.append(pid)
            _, t, _ = runcmd(["pgrep", "-P", str(pid)])
            chk += [int(x) for x in t.strip().split("\n") if x]

        pids = getalive(pids, pgid)  # filter to our pgroup
        for pid in pids:
            os.kill(pid, signal.SIGTERM)
    else:
        # windows gets minimal effort sorry
        os.kill(root, signal.SIGTERM)
        return

    for n in range(10):
        time.sleep(0.1)
        pids = getalive(pids, pgid)
        if not pids or n > 3 and pids == [root]:
            break

    # escalate: force-kill anything that ignored SIGTERM
    for pid in pids:
        try:
            os.kill(pid, signal.SIGKILL)
        except:
            pass
|
||
|
||
|
||
def _find_nice() -> str:
    """absolute path to the nice(1) binary, or empty string if unavailable"""
    if WINDOWS:
        return ""  # use creationflags

    try:
        hit = shutil.which("nice")
        if hit:
            return hit
    except:
        pass

    # busted PATHs and/or py2
    for prefix in ("/bin", "/sbin", "/usr/bin", "/usr/sbin"):
        cand = prefix + "/nice"
        if os.path.exists(cand):
            return cand

    return ""
|
||
|
||
|
||
# path to nice(1) as str and bytes; empty when unavailable (windows / stripped PATH)
NICES = _find_nice()
NICEB = NICES.encode("utf-8")
|
||
|
||
|
||
def runcmd(
    argv: Union[list[bytes], list[str]], timeout: Optional[float] = None, **ka: Any
) -> tuple[int, str, str]:
    """
    run a subprocess and return (returncode, stdout, stderr) as text;
    extra kwargs: sin (stdin bytes), oom (oom_score_adj 0..1000),
    nice (low prio), kill ([t]ree/[m]ain/[n]one on timeout),
    capture (0=none 1=stdout 2=stderr 3=both); remaining ka goes to Popen
    """
    isbytes = isinstance(argv[0], (bytes, bytearray))
    oom = ka.pop("oom", 0)  # 0..1000
    kill = ka.pop("kill", "t")  # [t]ree [m]ain [n]one
    capture = ka.pop("capture", 3)  # 0=none 1=stdout 2=stderr 3=both

    sin: Optional[bytes] = ka.pop("sin", None)
    if sin:
        ka["stdin"] = sp.PIPE

    cout = sp.PIPE if capture in [1, 3] else None
    cerr = sp.PIPE if capture in [2, 3] else None
    bout: bytes
    berr: bytes

    if ANYWIN:
        # windows needs the .exe suffix for some known commands
        if isbytes:
            if argv[0] in CMD_EXEB:
                argv[0] += b".exe"
        else:
            if argv[0] in CMD_EXES:
                argv[0] += ".exe"

    if ka.pop("nice", None):
        if WINDOWS:
            ka["creationflags"] = 0x4000  # BELOW_NORMAL_PRIORITY_CLASS
        elif NICEB:
            if isbytes:
                argv = [NICEB] + argv
            else:
                argv = [NICES] + argv

    p = sp.Popen(argv, stdout=cout, stderr=cerr, **ka)

    if oom and not ANYWIN and not MACOS:
        # make the child a preferred oom-killer target
        try:
            with open("/proc/%d/oom_score_adj" % (p.pid,), "wb") as f:
                f.write(("%d\n" % (oom,)).encode("utf-8"))
        except:
            pass

    if not timeout or PY2:
        bout, berr = p.communicate(sin)
    else:
        try:
            bout, berr = p.communicate(sin, timeout=timeout)
        except sp.TimeoutExpired:
            if kill == "n":
                return -18, "", ""  # SIGCONT; leave it be
            elif kill == "m":
                p.kill()
            else:
                killtree(p.pid)

            try:
                bout, berr = p.communicate(timeout=1)
            except:
                bout = b""
                berr = b""

    stdout = bout.decode("utf-8", "replace") if cout else ""
    stderr = berr.decode("utf-8", "replace") if cerr else ""

    rc: int = p.returncode
    if rc is None:
        rc = -14  # SIGALRM; failed to kill

    return rc, stdout, stderr
|
||
|
||
|
||
def chkcmd(argv: Union[list[bytes], list[str]], **ka: Any) -> tuple[str, str]:
    """run argv and return (stdout, stderr); raises on nonzero exit"""
    rc, sout, serr = runcmd(argv, **ka)
    if rc:
        retchk(rc, argv, serr)
        raise Exception(serr)

    return sout, serr
|
||
|
||
|
||
def mchkcmd(argv: Union[list[bytes], list[str]], timeout: float = 10) -> None:
    """run argv with all output silenced; raise CalledProcessError on nonzero exit"""
    if PY2:
        # py2 sp has neither DEVNULL nor timeout
        with open(os.devnull, "wb") as devnull:
            rv = sp.call(argv, stdout=devnull, stderr=devnull)
    else:
        rv = sp.call(argv, stdout=sp.DEVNULL, stderr=sp.DEVNULL, timeout=timeout)

    if rv:
        raise sp.CalledProcessError(rv, (argv[0], b"...", argv[-1]))
|
||
|
||
|
||
def retchk(
    rc: int,
    cmd: Union[list[bytes], list[str]],
    serr: str,
    logger: Optional["NamedLogger"] = None,
    color: Union[int, str] = 0,
    verbose: bool = False,
) -> None:
    """
    explain a nonzero subprocess returncode (signal names, 126/127);
    logs the message when a logger is given, otherwise raises Exception;
    silently returns for benign codes unless verbose
    """
    if rc < 0:
        # killed by signal; translate to shell-style 128+signum
        rc = 128 - rc

    if not rc or rc < 126 and not verbose:
        return

    s = None
    if rc > 128:
        try:
            s = str(signal.Signals(rc - 128))
        except:
            pass
    elif rc == 126:
        s = "invalid program"
    elif rc == 127:
        s = "program not found"
    elif verbose:
        s = "unknown"
    else:
        s = "invalid retcode"

    if s:
        t = "{} <{}>".format(rc, s)
    else:
        t = str(rc)

    try:
        c = " ".join([fsdec(x) for x in cmd])  # type: ignore
    except:
        c = str(cmd)

    t = "error {} from [{}]".format(t, c)
    if serr:
        t += "\n" + serr

    if logger:
        logger(t, color)
    else:
        raise Exception(t)
|
||
|
||
|
||
def _parsehook(
    log: Optional["NamedLogger"], cmd: str
) -> tuple[str, bool, bool, bool, float, dict[str, Any], list[str]]:
    """
    parse an event-hook config string of the form "f,j,t10,...,CMD,ARGS";
    returns (required-perms, check-rc, fork, json-stdin, wait-sec,
    runcmd-kwargs, argv)
    """
    areq = ""
    chk = False
    fork = False
    jtxt = False
    wait = 0.0
    tout = 0.0
    kill = "t"
    cap = 0
    ocmd = cmd
    # flags are single letters (plus value) at the start, comma-separated
    while "," in cmd[:6]:
        arg, cmd = cmd.split(",", 1)
        if arg == "c":
            chk = True
        elif arg == "f":
            fork = True
        elif arg == "j":
            jtxt = True
        elif arg.startswith("w"):
            wait = float(arg[1:])
        elif arg.startswith("t"):
            tout = float(arg[1:])
        elif arg.startswith("c"):
            cap = int(arg[1:])  # 0=none 1=stdout 2=stderr 3=both
        elif arg.startswith("k"):
            kill = arg[1:]  # [t]ree [m]ain [n]one
        elif arg.startswith("a"):
            areq = arg[1:]  # required perms
        elif arg.startswith("i"):
            pass
        elif not arg:
            # empty flag "," terminates flag parsing explicitly
            break
        else:
            t = "hook: invalid flag {} in {}"
            (log or print)(t.format(arg, ocmd))

    env = os.environ.copy()
    try:
        if EXE:
            # frozen builds cannot hand their modules to a child python
            raise Exception()

        # let hook scripts import copyparty modules
        pypath = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
        zsl = [str(pypath)] + [str(x) for x in sys.path if x]
        pypath = str(os.pathsep.join(zsl))
        env["PYTHONPATH"] = pypath
    except:
        if not EXE:
            raise

    sp_ka = {
        "env": env,
        "nice": True,
        "oom": 300,
        "timeout": tout,
        "kill": kill,
        "capture": cap,
    }

    argv = cmd.split(",") if "," in cmd else [cmd]

    argv[0] = os.path.expandvars(os.path.expanduser(argv[0]))

    return areq, chk, fork, jtxt, wait, sp_ka, argv
|
||
|
||
|
||
def runihook(
    log: Optional["NamedLogger"],
    cmd: str,
    vol: "VFS",
    ups: list[tuple[str, int, int, str, str, str, int]],
) -> bool:
    """
    run an idle/index hook for a batch of uploads into volume vol;
    each entry of ups is (wark, mt, sz, rd, fn, ip, at); the hook gets
    the affected paths on stdin (json if the "j" flag is set);
    returns False if the hook has the "c" (check) flag and failed
    """
    _, chk, fork, jtxt, wait, sp_ka, acmd = _parsehook(log, cmd)
    bcmd = [sfsenc(x) for x in acmd]
    if acmd[0].endswith(".py"):
        bcmd = [sfsenc(pybin)] + bcmd

    vps = [vjoin(*list(s3dec(x[3], x[4]))) for x in ups]
    aps = [djoin(vol.realpath, x) for x in vps]
    if jtxt:
        # 0w 1mt 2sz 3rd 4fn 5ip 6at
        ja = [
            {
                "ap": uncify(ap),  # utf8 for json
                "vp": vp,
                "wark": x[0][:16],
                "mt": x[1],
                "sz": x[2],
                "ip": x[5],
                "at": x[6],
            }
            for x, vp, ap in zip(ups, vps, aps)
        ]
        sp_ka["sin"] = json.dumps(ja).encode("utf-8", "replace")
    else:
        sp_ka["sin"] = b"\n".join(fsenc(x) for x in aps)

    t0 = time.time()
    if fork:
        # fix: Daemon splats its args-tuple into the target, so the argv
        # must be wrapped as [bcmd] to arrive as runcmd's single first
        # argument (matches the fork branch in _runhook)
        Daemon(runcmd, cmd, [bcmd], ka=sp_ka)
    else:
        rc, v, err = runcmd(bcmd, **sp_ka)  # type: ignore
        if chk and rc:
            retchk(rc, bcmd, err, log, 5)
            return False

    # honor the "w" flag: pad total runtime up to the requested delay
    wait -= time.time() - t0
    if wait > 0:
        time.sleep(wait)

    return True
|
||
|
||
|
||
def _runhook(
    log: Optional["NamedLogger"],
    src: str,
    cmd: str,
    ap: str,
    vp: str,
    host: str,
    uname: str,
    perms: str,
    mt: float,
    sz: int,
    ip: str,
    at: float,
    txt: str,
) -> dict[str, Any]:
    """
    run one event-hook for a single file/event; returns the hook's
    json response (at minimum {"rc": returncode}), or {"rc": 0}
    if it was skipped / forked
    """
    ret = {"rc": 0}
    areq, chk, fork, jtxt, wait, sp_ka, acmd = _parsehook(log, cmd)
    if areq:
        # hook demands specific volume-permissions from the user
        for ch in areq:
            if ch not in perms:
                t = "user %s not allowed to run hook %s; need perms %s, have %s"
                if log:
                    log(t % (uname, cmd, areq, perms))
                return ret  # fallthrough to next hook
    if jtxt:
        # "j" flag: pass the event as one json argument
        ja = {
            "ap": ap,
            "vp": vp,
            "mt": mt,
            "sz": sz,
            "ip": ip,
            "at": at or time.time(),
            "host": host,
            "user": uname,
            "perms": perms,
            "src": src,
            "txt": txt,
        }
        arg = json.dumps(ja)
    else:
        arg = txt or ap

    acmd += [arg]
    if acmd[0].endswith(".py"):
        acmd = [pybin] + acmd

    # filesystem-encode the path argument, safe-encode the rest
    bcmd = [fsenc(x) if x == ap else sfsenc(x) for x in acmd]

    t0 = time.time()
    if fork:
        Daemon(runcmd, cmd, [bcmd], ka=sp_ka)
    else:
        rc, v, err = runcmd(bcmd, **sp_ka)  # type: ignore
        if chk and rc:
            ret["rc"] = rc
            retchk(rc, bcmd, err, log, 5)
        else:
            # a hook may reply with json on stdout to alter behavior
            try:
                ret = json.loads(v)
            except:
                ret = {}

            try:
                if "stdout" not in ret:
                    ret["stdout"] = v
                if "rc" not in ret:
                    ret["rc"] = rc
            except:
                # hook returned a non-dict json value
                ret = {"rc": rc, "stdout": v}

    # honor the "w" flag: pad total runtime up to the requested delay
    wait -= time.time() - t0
    if wait > 0:
        time.sleep(wait)

    return ret
|
||
|
||
|
||
def runhook(
    log: Optional["NamedLogger"],
    broker: Optional["BrokerCli"],
    up2k: Optional["Up2k"],
    src: str,
    cmds: list[str],
    ap: str,
    vp: str,
    host: str,
    uname: str,
    perms: str,
    mt: float,
    sz: int,
    ip: str,
    at: float,
    txt: str,
) -> dict[str, Any]:
    """
    run a chain of event-hooks, merging their json replies;
    returns {} if any hook rejected the event (empty reply, or a
    critical ",c," hook crashed), otherwise the merged response
    """
    assert broker or up2k  # !rm
    args = (broker or up2k).args
    vp = vp.replace("\\", "/")
    ret = {"rc": 0}
    for cmd in cmds:
        try:
            hr = _runhook(
                log, src, cmd, ap, vp, host, uname, perms, mt, sz, ip, at, txt
            )
            if log and args.hook_v:
                log("hook(%s) [%s] => \033[32m%s" % (src, cmd, hr), 6)
            if not hr:
                # empty reply means the hook rejected the event
                return {}
            for k, v in hr.items():
                if k in ("idx", "del") and v:
                    # hook requested a reindex/deletion of files
                    if broker:
                        broker.say("up2k.hook_fx", k, v, vp)
                    else:
                        up2k.fx_backlog.append((k, v, vp))
                elif k == "reloc" and v:
                    # idk, just take the last one ig
                    ret["reloc"] = v
                elif k in ret:
                    if k == "rc" and v:
                        ret[k] = v
                else:
                    ret[k] = v
        except Exception as ex:
            (log or print)("hook: {}".format(ex))
            if ",c," in "," + cmd:
                # critical hook failed; reject the event
                return {}
            break

    return ret
|
||
|
||
|
||
def loadpy(ap: str, hot: bool) -> Any:
    """
    import (and optionally hot-reload) the python module at filepath ap;
    a nice can of worms capable of causing all sorts of bugs
    depending on what other inconveniently named files happen
    to be in the same folder
    """
    ap = os.path.expandvars(os.path.expanduser(ap))
    mdir, mfile = os.path.split(absreal(ap))
    mname = mfile.rsplit(".", 1)[0]
    # temporarily put the module's folder first on the import path
    sys.path.insert(0, mdir)

    if PY2:
        mod = __import__(mname)
        if hot:
            reload(mod)  # type: ignore
    else:
        import importlib

        mod = importlib.import_module(mname)
        if hot:
            importlib.reload(mod)

    sys.path.remove(mdir)
    return mod
|
||
|
||
|
||
def gzip_orig_sz(fn: str) -> int:
    """uncompressed size of gzip file fn (per its ISIZE trailer)"""
    with open(fsenc(fn), "rb") as fo:
        return gzip_file_orig_sz(fo)
|
||
|
||
|
||
def gzip_file_orig_sz(f) -> int:
    """
    read the ISIZE trailer (final 4 bytes, little-endian u32) of an
    open gzip file; NOTE: this is the uncompressed size mod 2**32
    """
    start = f.tell()
    f.seek(-4, 2)  # 4 bytes before EOF
    rv = f.read(4)
    f.seek(start, 0)  # restore the caller's file position
    return sunpack(b"I", rv)[0]  # type: ignore
|
||
|
||
|
||
def align_tab(lines: list[str]) -> list[str]:
    """left-align whitespace-separated columns across all lines (2-space gutter)"""
    rows = []
    ncols = 0
    for ln in lines:
        cells = [zs for zs in ln.split(" ") if zs]
        rows.append(cells)
        if len(cells) > ncols:
            ncols = len(cells)

    widths = [0] * ncols
    for cells in rows:
        for n, cell in enumerate(cells):
            if len(cell) > widths[n]:
                widths[n] = len(cell)

    return ["".join(c.ljust(w + 2) for c, w in zip(cells, widths)) for cells in rows]
|
||
|
||
|
||
def visual_length(txt: str) -> int:
    """
    number of terminal columns txt occupies when printed; skips ansi
    escape sequences ("\\033[...X") and counts cjk/fullwidth chars as 2;
    from r0c
    """
    # letters terminate an escape sequence ("end of code")
    eoc = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    clen = 0
    pend = None  # a possibly-starting escape sequence ("\033" or "\033[")
    counting = True  # False while inside an escape sequence
    for ch in txt:

        # escape sequences can never contain ESC;
        # treat pend as regular text if so
        if ch == "\033" and pend:
            clen += len(pend)
            counting = True
            pend = None

        if not counting:
            if ch in eoc:
                counting = True
        else:
            if pend:
                pend += ch
                if pend.startswith("\033["):
                    counting = False
                else:
                    # lone ESC not followed by '['; count it as text
                    clen += len(pend)
                    counting = True
                pend = None
            else:
                if ch == "\033":
                    pend = "%s" % (ch,)
                else:
                    co = ord(ch)
                    # the safe parts of latin1 and cp437 (no greek stuff)
                    if (
                        co < 0x100  # ascii + lower half of latin1
                        or (co >= 0x2500 and co <= 0x25A0)  # box drawings
                        or (co >= 0x2800 and co <= 0x28FF)  # braille
                    ):
                        clen += 1
                    else:
                        # assume moonrunes or other double-width
                        clen += 2
    return clen
|
||
|
||
|
||
def wrap(txt: str, maxlen: int, maxlen2: int) -> list[str]:
    """
    word-wrap txt into lines of at most maxlen visible columns;
    continuation lines wrap at maxlen2 and are indented by the
    difference; from r0c
    """
    # split after each comma/space so separators stick to their word
    words = re.sub(r"([, ])", r"\1\n", txt.rstrip()).split("\n")
    pad = maxlen - maxlen2
    ret = []
    for word in words:
        if len(word) * 2 < maxlen or visual_length(word) < maxlen:
            ret.append(word)
        else:
            # hard-split words wider than a whole line, hyphenating
            while visual_length(word) >= maxlen:
                ret.append(word[: maxlen - 1] + "-")
                word = word[maxlen - 1 :]
            if word:
                ret.append(word)

    words = ret
    ret = []
    ln = ""
    spent = 0  # visible columns used on the current line
    for word in words:
        wl = visual_length(word)
        if spent + wl > maxlen:
            ret.append(ln)
            maxlen = maxlen2  # continuation lines use the narrower limit
            spent = 0
            ln = " " * pad
        ln += word
        spent += wl
    if ln:
        ret.append(ln)

    return ret
|
||
|
||
|
||
def termsize() -> tuple[int, int]:
    """
    best-effort (columns, rows) of the controlling terminal;
    tries TIOCGWINSZ on stdio, then the ctty, then COLUMNS/LINES,
    finally defaulting to 80x25; from hashwalk
    """
    env = os.environ

    def ioctl_GWINSZ(fd: int) -> Optional[tuple[int, int]]:
        try:
            cr = sunpack(b"hh", fcntl.ioctl(fd, termios.TIOCGWINSZ, b"AAAA"))
            return cr[::-1]  # ioctl reports (rows, cols); flip it
        except:
            return None

    cr = ioctl_GWINSZ(0) or ioctl_GWINSZ(1) or ioctl_GWINSZ(2)
    if not cr:
        try:
            # stdio may be redirected; ask the controlling terminal directly
            fd = os.open(os.ctermid(), os.O_RDONLY)
            cr = ioctl_GWINSZ(fd)
            os.close(fd)
        except:
            pass

    try:
        return cr or (int(env["COLUMNS"]), int(env["LINES"]))
    except:
        return 80, 25
|
||
|
||
|
||
def hidedir(dp) -> None:
    """best-effort: set the hidden attribute on folder dp (windows only; noop elsewhere)"""
    if ANYWIN:
        try:
            assert ctypes  # type: ignore # !rm
            k32 = ctypes.WinDLL("kernel32")
            attrs = k32.GetFileAttributesW(dp)
            if attrs >= 0:
                # 2 == FILE_ATTRIBUTE_HIDDEN
                k32.SetFileAttributesW(dp, attrs | 2)
        except:
            pass
|
||
|
||
|
||
# pick the bundled-resource access backend:
# importlib.resources on py3.10+, the importlib_resources backport if
# installed, else the legacy pkg_resources; either may end up as None
try:
    if sys.version_info < (3, 10):
        # py3.8 doesn't have .files
        # py3.9 has broken .is_file
        raise ImportError()
    import importlib.resources as impresources
except ImportError:
    try:
        import importlib_resources as impresources
    except ImportError:
        impresources = None
try:
    if sys.version_info > (3, 10):
        # modern pythons use importlib.resources instead
        raise ImportError()
    import pkg_resources
except ImportError:
    pkg_resources = None
|
||
|
||
|
||
def _pkg_resource_exists(pkg: str, name: str) -> bool:
    """true if pkg_resources is available and reports resource name in pkg"""
    if pkg_resources:
        try:
            return pkg_resources.resource_exists(pkg, name)
        except NotImplementedError:
            pass

    return False
|
||
|
||
|
||
def stat_resource(E: EnvParams, name: str):
    """stat of resource name if it exists on-disk next to the module, else None"""
    ap = os.path.join(E.mod, name)
    return os.stat(fsenc(ap)) if os.path.exists(ap) else None
|
||
|
||
|
||
def _find_impresource(pkg: types.ModuleType, name: str):
    """locate bundled resource name inside pkg via importlib.resources; None on failure"""
    assert impresources  # !rm
    try:
        root = impresources.files(pkg)
    except ImportError:
        return None

    return root.joinpath(name)
|
||
|
||
|
||
# memoization cache for _has_resource; wiped when it exceeds ~1000 entries
_rescache_has = {}
|
||
|
||
|
||
def _has_resource(name: str):
|
||
try:
|
||
return _rescache_has[name]
|
||
except:
|
||
pass
|
||
|
||
if len(_rescache_has) > 999:
|
||
_rescache_has.clear()
|
||
|
||
assert __package__ # !rm
|
||
pkg = sys.modules[__package__]
|
||
|
||
if impresources:
|
||
res = _find_impresource(pkg, name)
|
||
if res and res.is_file():
|
||
_rescache_has[name] = True
|
||
return True
|
||
|
||
if pkg_resources:
|
||
if _pkg_resource_exists(pkg.__name__, name):
|
||
_rescache_has[name] = True
|
||
return True
|
||
|
||
_rescache_has[name] = False
|
||
return False
|
||
|
||
|
||
def has_resource(E: EnvParams, name: str):
    """true if `name` exists either bundled in the package or on-disk next to the module"""
    if _has_resource(name):
        return True

    return os.path.exists(os.path.join(E.mod, name))


def load_resource(E: EnvParams, name: str, mode="rb") -> IO[bytes]:
    """open a resource for reading; a bundled copy wins over the on-disk one"""
    enc = "utf-8" if "b" not in mode else None

    if impresources:
        assert __package__  # !rm
        res = _find_impresource(sys.modules[__package__], name)
        if res and res.is_file():
            if not enc:
                # throws if encoding= is mentioned at all
                return res.open(mode)
            return res.open(mode, encoding=enc)

    if pkg_resources:
        assert __package__  # !rm
        pkgname = sys.modules[__package__].__name__
        if _pkg_resource_exists(pkgname, name):
            stream = pkg_resources.resource_stream(pkgname, name)
            if not enc:
                return stream
            # wrap the byte stream in a utf-8 text reader
            return codecs.getreader(enc)(stream)

    # fall back to the plain file next to the module
    return open(os.path.join(E.mod, name), mode, encoding=enc)


class Pebkac(Exception):
    """http-layer error; `code` is the status to reply with, `log` is optional server-side detail"""

    def __init__(
        self, code: int, msg: Optional[str] = None, log: Optional[str] = None
    ) -> None:
        # default the message to the standard reason-phrase for the status code
        text = msg or HTTPCODE[code]
        super(Pebkac, self).__init__(text)
        self.code = code
        self.log = log

    def __repr__(self) -> str:
        return "Pebkac({}, {})".format(self.code, repr(self.args))


class WrongPostKey(Pebkac):
    """multipart POST whose first field had an unexpected name; replies 422"""

    def __init__(
        self,
        expected: str,
        got: str,
        fname: Optional[str],
        datagen: Generator[bytes, None, None],
    ) -> None:
        msg = 'expected field "{}", got "{}"'.format(expected, got)
        super(WrongPostKey, self).__init__(422, msg)

        # keep the parsed state so the caller can resume reading the body
        self.expected, self.got = expected, got
        self.fname = fname
        self.datagen = datagen


# re-exported names; the throwaway tuple keeps linters from
# flagging the corresponding imports as unused
_: Any = (mp, BytesIO, quote, unquote, SQLITE_VER, JINJA_VER, PYFTPD_VER, PARTFTPY_VER)
__all__ = [
    "mp",
    "BytesIO",
    "quote",
    "unquote",
    "SQLITE_VER",
    "JINJA_VER",
    "PYFTPD_VER",
    "PARTFTPY_VER",
]