From 8fef9e363ea1b07f53c5efda9ab91d0bc3a4f5b2 Mon Sep 17 00:00:00 2001 From: ed Date: Sun, 3 Jul 2022 04:57:15 +0200 Subject: [PATCH] recursive kill mtp on timeout --- copyparty/util.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/copyparty/util.py b/copyparty/util.py index a6408732..42af5d93 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -30,6 +30,12 @@ try: except: HAVE_SQLITE3 = False +try: + HAVE_PSUTIL = True + import psutil +except: + HAVE_PSUTIL = False + try: import types from collections.abc import Callable, Iterable @@ -1489,6 +1495,78 @@ def guess_mime(url: str, fallback: str = "application/octet-stream") -> str: return ret +def getalive(pids: list[int], pgid: int) -> list[int]: + alive = [] + for pid in pids: + if pgid: + try: + # check if still one of ours + if os.getpgid(pid) == pgid: + alive.append(pid) + except: + pass + else: + try: + # windows doesn't have pgroups; assume + psutil.Process(pid) + alive.append(pid) + except: + pass + + return alive + + +def killtree(root: int) -> None: + """still racy but i tried""" + try: + # limit the damage where possible (unixes) + pgid = os.getpgid(os.getpid()) + except: + pgid = 0 + + if HAVE_PSUTIL: + pids = [root] + parent = psutil.Process(root) + for child in parent.children(recursive=True): + pids.append(child.pid) + child.terminate() + parent.terminate() + parent = None + elif pgid: + # linux-only + pids = [] + chk = [root] + while chk: + pid = chk[0] + chk = chk[1:] + pids.append(pid) + _, t, _ = runcmd(["pgrep", "-P", str(pid)]) + chk += [int(x) for x in t.strip().split("\n") if x] + + pids = getalive(pids, pgid) # filter to our pgroup + for pid in pids: + os.kill(pid, signal.SIGTERM) + else: + # windows gets minimal effort sorry + os.kill(pid, signal.SIGTERM) + return + + if not pids: + return # yay + + for n in range(10): + time.sleep(0.1) + pids = getalive(pids, pgid) + if not pids or n > 3 and pids == [root]: + break + + for pid in pids: + try: + os.kill(pid, signal.SIGKILL) + except: + pass + + def runcmd( argv: Union[list[bytes], list[str]], timeout: Optional[int] = None, **ka: Any ) -> tuple[int, str, str]: @@ -1499,7 +1577,7 @@ def runcmd( try: stdout, stderr = p.communicate(timeout=timeout) except sp.TimeoutExpired: - p.kill() + killtree(p.pid) stdout, stderr = p.communicate() stdout = stdout.decode("utf-8", "replace")