From 5c7debd9005fc1ee16e7ec5d1ef3f9022dd2c0ba Mon Sep 17 00:00:00 2001 From: ed Date: Sat, 17 Jul 2021 04:15:07 +0200 Subject: [PATCH] improve signal handling + emit sd-notify on start --- contrib/systemd/copyparty.service | 1 + copyparty/broker_mpw.py | 5 ++- copyparty/svchub.py | 74 +++++++++++++++++++++++++++---- tests/util.py | 2 +- 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/contrib/systemd/copyparty.service b/contrib/systemd/copyparty.service index d966e9ce..52fe28fb 100644 --- a/contrib/systemd/copyparty.service +++ b/contrib/systemd/copyparty.service @@ -12,6 +12,7 @@ Description=copyparty file server [Service] +Type=notify ExecStart=/usr/bin/python3 /usr/local/bin/copyparty-sfx.py -q -v /mnt::a ExecStartPre=/bin/bash -c 'mkdir -p /run/tmpfiles.d/ && echo "x /tmp/pe-copyparty*" > /run/tmpfiles.d/copyparty.conf' diff --git a/copyparty/broker_mpw.py b/copyparty/broker_mpw.py index c5c63bcb..658ce56b 100644 --- a/copyparty/broker_mpw.py +++ b/copyparty/broker_mpw.py @@ -29,7 +29,8 @@ class MpWorker(object): # we inherited signal_handler from parent, # replace it with something harmless if not FAKE_MP: - signal.signal(signal.SIGINT, self.signal_handler) + for sig in [signal.SIGINT, signal.SIGTERM]: + signal.signal(sig, self.signal_handler) # starting to look like a good idea self.asrv = AuthSrv(args, None, False) @@ -44,7 +45,7 @@ class MpWorker(object): thr.start() thr.join() - def signal_handler(self, signal, frame): + def signal_handler(self, sig, frame): # print('k') pass diff --git a/copyparty/svchub.py b/copyparty/svchub.py index 5c35b553..61b6e024 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -6,12 +6,15 @@ import os import sys import time import shlex +import string +import signal +import socket import threading from datetime import datetime, timedelta import calendar -from .__init__ import E, PY2, WINDOWS, MACOS, VT100 -from .util import mp, start_log_thrs, start_stackmon +from .__init__ import E, PY2, WINDOWS, MACOS, VT100, unicode +from .util import mp, start_log_thrs, start_stackmon, min_ex from .authsrv import AuthSrv from .tcpsrv import TcpSrv from .up2k import Up2k @@ -33,6 +36,9 @@ class SvcHub(object): self.args = args self.argv = argv self.logf = None + self.stop_req = False + self.stopping = False + self.stop_cond = threading.Condition() self.ansi_re = re.compile("\033\\[[^m]*m") self.log_mutex = threading.Lock() @@ -127,16 +133,49 @@ class SvcHub(object): print(msg, end="") def run(self): - thr = threading.Thread(target=self.tcpsrv.run, name="svchub-main") + self.tcpsrv.run() + + thr = threading.Thread(target=self.sd_notify, name="sd-notify") thr.daemon = True thr.start() - # winxp/py2.7 support: thr.join() kills signals - try: - while True: - time.sleep(9001) + thr = threading.Thread(target=self.stop_thr, name="svchub-sig") + thr.daemon = True + thr.start() - except KeyboardInterrupt: + for sig in [signal.SIGINT, signal.SIGTERM]: + signal.signal(sig, self.signal_handler) + + try: + while not self.stop_req: + time.sleep(9001) + except: + pass + + self.shutdown() + + def stop_thr(self): + while not self.stop_req: + with self.stop_cond: + self.stop_cond.wait(9001) + + self.shutdown() + + def signal_handler(self): + if self.stopping: + return + + self.stop_req = True + with self.stop_cond: + self.stop_cond.notify_all() + + def shutdown(self): + if self.stopping: + return + + self.stopping = True + self.stop_req = True + try: with self.log_mutex: print("OPYTHAT") @@ -268,3 +307,22 @@ class SvcHub(object): else: self.log("svchub", err) return False + + def sd_notify(self): + try: + addr = os.getenv("NOTIFY_SOCKET") + if not addr: + return + + addr = unicode(addr) + if addr.startswith("@"): + addr = "\0" + addr[1:] + + m = "".join(x for x in addr if x in string.printable) + self.log("sd_notify", m) + + sck = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + sck.connect(addr) + sck.sendall(b"READY=1") + except: + self.log("sd_notify", min_ex()) diff --git a/tests/util.py b/tests/util.py index 490112cc..6ab3c6fe 100644 --- a/tests/util.py +++ b/tests/util.py @@ -119,7 +119,7 @@ class VHttpConn(object): self.addr = ("127.0.0.1", "42069") self.args = args self.asrv = asrv - self.is_mp = False + self.nid = None self.log_func = log self.log_src = "a" self.lf_url = None