From 87c60a1ec91bf02673b7e69ea44fadfe3e148bd9 Mon Sep 17 00:00:00 2001 From: ed Date: Thu, 9 May 2024 22:28:16 +0000 Subject: [PATCH] ensure OS signals hit main-thread as intended; use sigmasks to block SIGINT, SIGTERM, SIGUSR1 from all other threads also initiate shutdown by calling sighandler directly, in case this misses anything and that is still unreliable (discovered by `--exit=idx` being noop once in a blue moon) --- bin/u2c.py | 19 ++++++++++++++----- copyparty/__main__.py | 17 +++++++++++++++++ copyparty/broker_mp.py | 7 ++----- copyparty/httpsrv.py | 5 +---- copyparty/smbd.py | 2 +- copyparty/svchub.py | 5 +++-- copyparty/up2k.py | 3 +-- copyparty/util.py | 26 +++++++++++++++++++++----- scripts/sfx.py | 17 ++--------------- scripts/ziploader.py | 5 +---- 10 files changed, 63 insertions(+), 43 deletions(-) diff --git a/bin/u2c.py b/bin/u2c.py index 4a4aef9f..dc278280 100755 --- a/bin/u2c.py +++ b/bin/u2c.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 from __future__ import print_function, unicode_literals -S_VERSION = "1.16" -S_BUILD_DT = "2024-04-20" +S_VERSION = "1.17" +S_BUILD_DT = "2024-05-09" """ u2c.py: upload to copyparty @@ -79,12 +79,21 @@ req_ses = requests.Session() class Daemon(threading.Thread): - def __init__(self, target, name=None, a=None): - # type: (Any, Any, Any) -> None - threading.Thread.__init__(self, target=target, args=a or (), name=name) + def __init__(self, target, name = None, a = None): + threading.Thread.__init__(self, name=name) + self.a = a or () + self.fun = target self.daemon = True self.start() + def run(self): + try: + signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGTERM]) + except: + pass + + self.fun(*self.a) + class File(object): """an up2k upload task; represents a single file""" diff --git a/copyparty/__main__.py b/copyparty/__main__.py index 5df76aee..902c322a 100644 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -49,6 +49,7 @@ from .util import ( PYFTPD_VER, SQLITE_VER, UNPLICATIONS, + Daemon, align_tab, ansi_re, dedent, @@ -471,6 +472,16 @@ def disable_quickedit() -> None: cmode(True, mode | 4) +def sfx_tpoke(top: str): + files = [os.path.join(dp, p) for dp, dd, df in os.walk(top) for p in dd + df] + while True: + t = int(time.time()) + for f in [top] + files: + os.utime(f, (t, t)) + + time.sleep(78123) + + def showlic() -> None: p = os.path.join(E.mod, "res", "COPYING.txt") if not os.path.exists(p): @@ -1454,6 +1465,12 @@ def main(argv: Optional[list[str]] = None, rsrc: Optional[str] = None) -> None: if EXE: print("pybin: {}\n".format(pybin), end="") + for n, zs in enumerate(argv): + if zs.startswith("--sfx-tpoke="): + Daemon(sfx_tpoke, "sfx-tpoke", (zs.split("=", 1)[1],)) + argv.pop(n) + break + ensure_locale() ensure_webdeps() diff --git a/copyparty/broker_mp.py b/copyparty/broker_mp.py index 848b07ee..b09f6ce3 100644 --- a/copyparty/broker_mp.py +++ b/copyparty/broker_mp.py @@ -57,11 +57,8 @@ class BrokerMp(object): def shutdown(self) -> None: self.log("broker", "shutting down") for n, proc in enumerate(self.procs): - thr = threading.Thread( - target=proc.q_pend.put((0, "shutdown", [])), - name="mp-shutdown-{}-{}".format(n, len(self.procs)), - ) - thr.start() + name = "mp-shut-%d-%d" % (n, len(self.procs)) + Daemon(proc.q_pend.put, name, ((0, "shutdown", []),)) with self.mutex: procs = self.procs diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py index adef2d4a..c0e69ee5 100644 --- a/copyparty/httpsrv.py +++ b/copyparty/httpsrv.py @@ -266,10 +266,7 @@ class HttpSrv(object): msg = "subscribed @ {}:{} f{} p{}".format(hip, port, fno, os.getpid()) self.log(self.name, msg) - def fun() -> None: - self.broker.say("cb_httpsrv_up") - - threading.Thread(target=fun, name="sig-hsrv-up1").start() + Daemon(self.broker.say, "sig-hsrv-up1", ("cb_httpsrv_up",)) while not self.stopping: if self.args.log_conn: diff --git a/copyparty/smbd.py b/copyparty/smbd.py index 979c11df..6a096265 100644 --- a/copyparty/smbd.py +++ b/copyparty/smbd.py @@ -127,7 +127,7 @@ class SMB(object): self.log("smb", msg, c) def start(self) -> None: - Daemon(self.srv.start) + Daemon(self.srv.start, "smbd") def _auth_cb(self, *a, **ka): debug("auth-result: %s %s", a, ka) diff --git a/copyparty/svchub.py b/copyparty/svchub.py index 0a7083ce..e69c3d42 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -293,13 +293,14 @@ class SvcHub(object): from .ftpd import Ftpd self.ftpd: Optional[Ftpd] = None - Daemon(self.start_ftpd, "start_ftpd") zms += "f" if args.ftp else "F" if args.tftp: from .tftpd import Tftpd self.tftpd: Optional[Tftpd] = None + + if args.ftp or args.ftps or args.tftp: Daemon(self.start_ftpd, "start_tftpd") if args.smb: @@ -388,7 +389,7 @@ class SvcHub(object): self.sigterm() def sigterm(self) -> None: - os.kill(os.getpid(), signal.SIGTERM) + self.signal_handler(signal.SIGTERM, None) def cb_httpsrv_up(self) -> None: self.httpsrv_up += 1 diff --git a/copyparty/up2k.py b/copyparty/up2k.py index f1ec0009..6786fcd9 100644 --- a/copyparty/up2k.py +++ b/copyparty/up2k.py @@ -10,7 +10,6 @@ import math import os import re import shutil -import signal import stat import subprocess as sp import tempfile @@ -1659,7 +1658,7 @@ class Up2k(object): if e2vp and rewark: self.hub.retcode = 1 - os.kill(os.getpid(), signal.SIGTERM) + Daemon(self.hub.sigterm) raise Exception("{} files have incorrect hashes".format(len(rewark))) if not e2vu or not rewark: diff --git a/copyparty/util.py b/copyparty/util.py index 4e344f88..5f4be32c 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -463,13 +463,20 @@ class Daemon(threading.Thread): r: bool = True, ka: Optional[dict[Any, Any]] = None, ) -> None: - threading.Thread.__init__( - self, target=target, name=name, args=a or (), kwargs=ka - ) + threading.Thread.__init__(self, name=name) + self.a = a or () + self.ka = ka or {} + self.fun = target self.daemon = True if r: self.start() + def run(self): + if not ANYWIN and not PY2: + signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1]) + + self.fun(*self.a, **self.ka) + class Netdev(object): def __init__(self, ip: str, idx: int, name: str, desc: str): @@ -864,6 +871,7 @@ class ProgressPrinter(threading.Thread): self.start() def run(self) -> None: + sigblock() tp = 0 msg = None no_stdout = self.args.q @@ -1308,6 +1316,13 @@ def log_thrs(log: Callable[[str, str, int], None], ival: float, name: str) -> No log(name, "\033[0m \033[33m".join(tv), 3) +def sigblock(): + if ANYWIN or PY2: + return + + signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1]) + + def vol_san(vols: list["VFS"], txt: bytes) -> bytes: txt0 = txt for vol in vols: @@ -1329,10 +1344,11 @@ def vol_san(vols: list["VFS"], txt: bytes) -> bytes: def min_ex(max_lines: int = 8, reverse: bool = False) -> str: et, ev, tb = sys.exc_info() - stb = traceback.extract_tb(tb) + stb = traceback.extract_tb(tb) if tb else traceback.extract_stack()[:-1] fmt = "%s @ %d <%s>: %s" ex = [fmt % (fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in stb] - ex.append("[%s] %s" % (et.__name__ if et else "(anonymous)", ev)) + if et or ev or tb: + ex.append("[%s] %s" % (et.__name__ if et else "(anonymous)", ev)) return "\n".join(ex[-max_lines:][:: -1 if reverse else 1]) diff --git a/scripts/sfx.py b/scripts/sfx.py index 0203571c..a1b6477f 100644 --- a/scripts/sfx.py +++ b/scripts/sfx.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # coding: latin-1 from __future__ import print_function, unicode_literals -import re, os, sys, time, shutil, signal, threading, tarfile, hashlib, platform, tempfile, traceback +import re, os, sys, time, shutil, signal, tarfile, hashlib, platform, tempfile, traceback import subprocess as sp @@ -368,17 +368,6 @@ def get_payload(): p = a -def utime(top): - # avoid cleaners - files = [os.path.join(dp, p) for dp, dd, df in os.walk(top) for p in dd + df] - while True: - t = int(time.time()) - for f in [top] + files: - os.utime(f, (t, t)) - - time.sleep(78123) - - def confirm(rv): msg() msg("retcode", rv if rv else traceback.format_exc()) @@ -398,9 +387,7 @@ def run(tmp, j2, ftp): msg("sfxdir:", tmp) msg() - t = threading.Thread(target=utime, args=(tmp,), name="utime") - t.daemon = True - t.start() + sys.argv.append("--sfx-tpoke=" + tmp) ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2"), (PY37, "py37")) ld = [os.path.join(tmp, b) for a, b in ld if not a] diff --git a/scripts/ziploader.py b/scripts/ziploader.py index bbf4152e..0f41f29d 100644 --- a/scripts/ziploader.py +++ b/scripts/ziploader.py @@ -6,7 +6,6 @@ import platform import sys import tarfile import tempfile -import threading import time import traceback @@ -80,9 +79,7 @@ def run(): msg(" rsrc dir:", rsrc) msg() - t = threading.Thread(target=utime, args=(rsrc,), name="utime") - t.daemon = True - t.start() + sys.argv.append("--sfx-tpoke=" + rsrc) cm(rsrc=rsrc)