improve signal handling + emit sd-notify on start

This commit is contained in:
ed 2021-07-17 04:15:07 +02:00
parent 7fa5b23ce3
commit 5c7debd900
4 changed files with 71 additions and 11 deletions

View file

@ -12,6 +12,7 @@
Description=copyparty file server Description=copyparty file server
[Service] [Service]
Type=notify
ExecStart=/usr/bin/python3 /usr/local/bin/copyparty-sfx.py -q -v /mnt::a ExecStart=/usr/bin/python3 /usr/local/bin/copyparty-sfx.py -q -v /mnt::a
ExecStartPre=/bin/bash -c 'mkdir -p /run/tmpfiles.d/ && echo "x /tmp/pe-copyparty*" > /run/tmpfiles.d/copyparty.conf' ExecStartPre=/bin/bash -c 'mkdir -p /run/tmpfiles.d/ && echo "x /tmp/pe-copyparty*" > /run/tmpfiles.d/copyparty.conf'

View file

@ -29,7 +29,8 @@ class MpWorker(object):
# we inherited signal_handler from parent, # we inherited signal_handler from parent,
# replace it with something harmless # replace it with something harmless
if not FAKE_MP: if not FAKE_MP:
signal.signal(signal.SIGINT, self.signal_handler) for sig in [signal.SIGINT, signal.SIGTERM]:
signal.signal(sig, self.signal_handler)
# starting to look like a good idea # starting to look like a good idea
self.asrv = AuthSrv(args, None, False) self.asrv = AuthSrv(args, None, False)
@ -44,7 +45,7 @@ class MpWorker(object):
thr.start() thr.start()
thr.join() thr.join()
def signal_handler(self, signal, frame): def signal_handler(self, sig, frame):
# print('k') # print('k')
pass pass

View file

@ -6,12 +6,15 @@ import os
import sys import sys
import time import time
import shlex import shlex
import string
import signal
import socket
import threading import threading
from datetime import datetime, timedelta from datetime import datetime, timedelta
import calendar import calendar
from .__init__ import E, PY2, WINDOWS, MACOS, VT100 from .__init__ import E, PY2, WINDOWS, MACOS, VT100, unicode
from .util import mp, start_log_thrs, start_stackmon from .util import mp, start_log_thrs, start_stackmon, min_ex
from .authsrv import AuthSrv from .authsrv import AuthSrv
from .tcpsrv import TcpSrv from .tcpsrv import TcpSrv
from .up2k import Up2k from .up2k import Up2k
@ -33,6 +36,9 @@ class SvcHub(object):
self.args = args self.args = args
self.argv = argv self.argv = argv
self.logf = None self.logf = None
self.stop_req = False
self.stopping = False
self.stop_cond = threading.Condition()
self.ansi_re = re.compile("\033\\[[^m]*m") self.ansi_re = re.compile("\033\\[[^m]*m")
self.log_mutex = threading.Lock() self.log_mutex = threading.Lock()
@ -127,16 +133,49 @@ class SvcHub(object):
print(msg, end="") print(msg, end="")
def run(self): def run(self):
thr = threading.Thread(target=self.tcpsrv.run, name="svchub-main") self.tcpsrv.run()
thr = threading.Thread(target=self.sd_notify, name="sd-notify")
thr.daemon = True thr.daemon = True
thr.start() thr.start()
# winxp/py2.7 support: thr.join() kills signals thr = threading.Thread(target=self.stop_thr, name="svchub-sig")
try: thr.daemon = True
while True: thr.start()
time.sleep(9001)
except KeyboardInterrupt: for sig in [signal.SIGINT, signal.SIGTERM]:
signal.signal(sig, self.signal_handler)
try:
while not self.stop_req:
time.sleep(9001)
except:
pass
self.shutdown()
def stop_thr(self):
while not self.stop_req:
with self.stop_cond:
self.stop_cond.wait(9001)
self.shutdown()
def signal_handler(self):
if self.stopping:
return
self.stop_req = True
with self.stop_cond:
self.stop_cond.notify_all()
def shutdown(self):
if self.stopping:
return
self.stopping = True
self.stop_req = True
try:
with self.log_mutex: with self.log_mutex:
print("OPYTHAT") print("OPYTHAT")
@ -268,3 +307,22 @@ class SvcHub(object):
else: else:
self.log("svchub", err) self.log("svchub", err)
return False return False
def sd_notify(self):
try:
addr = os.getenv("NOTIFY_SOCKET")
if not addr:
return
addr = unicode(addr)
if addr.startswith("@"):
addr = "\0" + addr[1:]
m = "".join(x for x in addr if x in string.printable)
self.log("sd_notify", m)
sck = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sck.connect(addr)
sck.sendall(b"READY=1")
except:
self.log("sd_notify", min_ex())

View file

@ -119,7 +119,7 @@ class VHttpConn(object):
self.addr = ("127.0.0.1", "42069") self.addr = ("127.0.0.1", "42069")
self.args = args self.args = args
self.asrv = asrv self.asrv = asrv
self.is_mp = False self.nid = None
self.log_func = log self.log_func = log
self.log_src = "a" self.log_src = "a"
self.lf_url = None self.lf_url = None