ensure OS signals hit main-thread as intended;

use sigmasks to block SIGINT, SIGTERM, SIGUSR1 from all other threads

also initiate shutdown by calling sighandler directly,
in case this misses anything and that is still unreliable
(discovered by `--exit=idx` being noop once in a blue moon)
This commit is contained in:
ed 2024-05-09 22:28:16 +00:00
parent 2c92dab165
commit 87c60a1ec9
10 changed files with 63 additions and 43 deletions

View file

@ -1,8 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
S_VERSION = "1.16" S_VERSION = "1.17"
S_BUILD_DT = "2024-04-20" S_BUILD_DT = "2024-05-09"
""" """
u2c.py: upload to copyparty u2c.py: upload to copyparty
@ -79,12 +79,21 @@ req_ses = requests.Session()
class Daemon(threading.Thread): class Daemon(threading.Thread):
def __init__(self, target, name=None, a=None): def __init__(self, target, name = None, a = None):
# type: (Any, Any, Any) -> None threading.Thread.__init__(self, name=name)
threading.Thread.__init__(self, target=target, args=a or (), name=name) self.a = a or ()
self.fun = target
self.daemon = True self.daemon = True
self.start() self.start()
def run(self):
try:
signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGTERM])
except:
pass
self.fun(*self.a)
class File(object): class File(object):
"""an up2k upload task; represents a single file""" """an up2k upload task; represents a single file"""

View file

@ -49,6 +49,7 @@ from .util import (
PYFTPD_VER, PYFTPD_VER,
SQLITE_VER, SQLITE_VER,
UNPLICATIONS, UNPLICATIONS,
Daemon,
align_tab, align_tab,
ansi_re, ansi_re,
dedent, dedent,
@ -471,6 +472,16 @@ def disable_quickedit() -> None:
cmode(True, mode | 4) cmode(True, mode | 4)
def sfx_tpoke(top: str):
files = [os.path.join(dp, p) for dp, dd, df in os.walk(top) for p in dd + df]
while True:
t = int(time.time())
for f in [top] + files:
os.utime(f, (t, t))
time.sleep(78123)
def showlic() -> None: def showlic() -> None:
p = os.path.join(E.mod, "res", "COPYING.txt") p = os.path.join(E.mod, "res", "COPYING.txt")
if not os.path.exists(p): if not os.path.exists(p):
@ -1454,6 +1465,12 @@ def main(argv: Optional[list[str]] = None, rsrc: Optional[str] = None) -> None:
if EXE: if EXE:
print("pybin: {}\n".format(pybin), end="") print("pybin: {}\n".format(pybin), end="")
for n, zs in enumerate(argv):
if zs.startswith("--sfx-tpoke="):
Daemon(sfx_tpoke, "sfx-tpoke", (zs.split("=", 1)[1],))
argv.pop(n)
break
ensure_locale() ensure_locale()
ensure_webdeps() ensure_webdeps()

View file

@ -57,11 +57,8 @@ class BrokerMp(object):
def shutdown(self) -> None: def shutdown(self) -> None:
self.log("broker", "shutting down") self.log("broker", "shutting down")
for n, proc in enumerate(self.procs): for n, proc in enumerate(self.procs):
thr = threading.Thread( name = "mp-shut-%d-%d" % (n, len(self.procs))
target=proc.q_pend.put((0, "shutdown", [])), Daemon(proc.q_pend.put, name, ((0, "shutdown", []),))
name="mp-shutdown-{}-{}".format(n, len(self.procs)),
)
thr.start()
with self.mutex: with self.mutex:
procs = self.procs procs = self.procs

View file

@ -266,10 +266,7 @@ class HttpSrv(object):
msg = "subscribed @ {}:{} f{} p{}".format(hip, port, fno, os.getpid()) msg = "subscribed @ {}:{} f{} p{}".format(hip, port, fno, os.getpid())
self.log(self.name, msg) self.log(self.name, msg)
def fun() -> None: Daemon(self.broker.say, "sig-hsrv-up1", ("cb_httpsrv_up",))
self.broker.say("cb_httpsrv_up")
threading.Thread(target=fun, name="sig-hsrv-up1").start()
while not self.stopping: while not self.stopping:
if self.args.log_conn: if self.args.log_conn:

View file

@ -127,7 +127,7 @@ class SMB(object):
self.log("smb", msg, c) self.log("smb", msg, c)
def start(self) -> None: def start(self) -> None:
Daemon(self.srv.start) Daemon(self.srv.start, "smbd")
def _auth_cb(self, *a, **ka): def _auth_cb(self, *a, **ka):
debug("auth-result: %s %s", a, ka) debug("auth-result: %s %s", a, ka)

View file

@ -293,13 +293,14 @@ class SvcHub(object):
from .ftpd import Ftpd from .ftpd import Ftpd
self.ftpd: Optional[Ftpd] = None self.ftpd: Optional[Ftpd] = None
Daemon(self.start_ftpd, "start_ftpd")
zms += "f" if args.ftp else "F" zms += "f" if args.ftp else "F"
if args.tftp: if args.tftp:
from .tftpd import Tftpd from .tftpd import Tftpd
self.tftpd: Optional[Tftpd] = None self.tftpd: Optional[Tftpd] = None
if args.ftp or args.ftps or args.tftp:
Daemon(self.start_ftpd, "start_tftpd") Daemon(self.start_ftpd, "start_tftpd")
if args.smb: if args.smb:
@ -388,7 +389,7 @@ class SvcHub(object):
self.sigterm() self.sigterm()
def sigterm(self) -> None: def sigterm(self) -> None:
os.kill(os.getpid(), signal.SIGTERM) self.signal_handler(signal.SIGTERM, None)
def cb_httpsrv_up(self) -> None: def cb_httpsrv_up(self) -> None:
self.httpsrv_up += 1 self.httpsrv_up += 1

View file

@ -10,7 +10,6 @@ import math
import os import os
import re import re
import shutil import shutil
import signal
import stat import stat
import subprocess as sp import subprocess as sp
import tempfile import tempfile
@ -1659,7 +1658,7 @@ class Up2k(object):
if e2vp and rewark: if e2vp and rewark:
self.hub.retcode = 1 self.hub.retcode = 1
os.kill(os.getpid(), signal.SIGTERM) Daemon(self.hub.sigterm)
raise Exception("{} files have incorrect hashes".format(len(rewark))) raise Exception("{} files have incorrect hashes".format(len(rewark)))
if not e2vu or not rewark: if not e2vu or not rewark:

View file

@ -463,13 +463,20 @@ class Daemon(threading.Thread):
r: bool = True, r: bool = True,
ka: Optional[dict[Any, Any]] = None, ka: Optional[dict[Any, Any]] = None,
) -> None: ) -> None:
threading.Thread.__init__( threading.Thread.__init__(self, name=name)
self, target=target, name=name, args=a or (), kwargs=ka self.a = a or ()
) self.ka = ka or {}
self.fun = target
self.daemon = True self.daemon = True
if r: if r:
self.start() self.start()
def run(self):
if not ANYWIN and not PY2:
signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1])
self.fun(*self.a, **self.ka)
class Netdev(object): class Netdev(object):
def __init__(self, ip: str, idx: int, name: str, desc: str): def __init__(self, ip: str, idx: int, name: str, desc: str):
@ -864,6 +871,7 @@ class ProgressPrinter(threading.Thread):
self.start() self.start()
def run(self) -> None: def run(self) -> None:
sigblock()
tp = 0 tp = 0
msg = None msg = None
no_stdout = self.args.q no_stdout = self.args.q
@ -1308,6 +1316,13 @@ def log_thrs(log: Callable[[str, str, int], None], ival: float, name: str) -> No
log(name, "\033[0m \033[33m".join(tv), 3) log(name, "\033[0m \033[33m".join(tv), 3)
def sigblock():
if ANYWIN or PY2:
return
signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1])
def vol_san(vols: list["VFS"], txt: bytes) -> bytes: def vol_san(vols: list["VFS"], txt: bytes) -> bytes:
txt0 = txt txt0 = txt
for vol in vols: for vol in vols:
@ -1329,10 +1344,11 @@ def vol_san(vols: list["VFS"], txt: bytes) -> bytes:
def min_ex(max_lines: int = 8, reverse: bool = False) -> str: def min_ex(max_lines: int = 8, reverse: bool = False) -> str:
et, ev, tb = sys.exc_info() et, ev, tb = sys.exc_info()
stb = traceback.extract_tb(tb) stb = traceback.extract_tb(tb) if tb else traceback.extract_stack()[:-1]
fmt = "%s @ %d <%s>: %s" fmt = "%s @ %d <%s>: %s"
ex = [fmt % (fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in stb] ex = [fmt % (fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in stb]
ex.append("[%s] %s" % (et.__name__ if et else "(anonymous)", ev)) if et or ev or tb:
ex.append("[%s] %s" % (et.__name__ if et else "(anonymous)", ev))
return "\n".join(ex[-max_lines:][:: -1 if reverse else 1]) return "\n".join(ex[-max_lines:][:: -1 if reverse else 1])

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# coding: latin-1 # coding: latin-1
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import re, os, sys, time, shutil, signal, threading, tarfile, hashlib, platform, tempfile, traceback import re, os, sys, time, shutil, signal, tarfile, hashlib, platform, tempfile, traceback
import subprocess as sp import subprocess as sp
@ -368,17 +368,6 @@ def get_payload():
p = a p = a
def utime(top):
# avoid cleaners
files = [os.path.join(dp, p) for dp, dd, df in os.walk(top) for p in dd + df]
while True:
t = int(time.time())
for f in [top] + files:
os.utime(f, (t, t))
time.sleep(78123)
def confirm(rv): def confirm(rv):
msg() msg()
msg("retcode", rv if rv else traceback.format_exc()) msg("retcode", rv if rv else traceback.format_exc())
@ -398,9 +387,7 @@ def run(tmp, j2, ftp):
msg("sfxdir:", tmp) msg("sfxdir:", tmp)
msg() msg()
t = threading.Thread(target=utime, args=(tmp,), name="utime") sys.argv.append("--sfx-tpoke=" + tmp)
t.daemon = True
t.start()
ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2"), (PY37, "py37")) ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2"), (PY37, "py37"))
ld = [os.path.join(tmp, b) for a, b in ld if not a] ld = [os.path.join(tmp, b) for a, b in ld if not a]

View file

@ -6,7 +6,6 @@ import platform
import sys import sys
import tarfile import tarfile
import tempfile import tempfile
import threading
import time import time
import traceback import traceback
@ -80,9 +79,7 @@ def run():
msg(" rsrc dir:", rsrc) msg(" rsrc dir:", rsrc)
msg() msg()
t = threading.Thread(target=utime, args=(rsrc,), name="utime") sys.argv.append("--sfx-tpoke=" + rsrc)
t.daemon = True
t.start()
cm(rsrc=rsrc) cm(rsrc=rsrc)