diff --git a/copyparty/__main__.py b/copyparty/__main__.py index 2d1c63ad..b15ec3e7 100644 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -8,14 +8,11 @@ __copyright__ = 2019 __license__ = "MIT" __url__ = "https://github.com/9001/copyparty/" -import time import argparse -import threading from textwrap import dedent -import multiprocessing as mp from .__version__ import S_VERSION, S_BUILD_DT -from .tcpsrv import TcpSrv +from .svchub import SvcHub class RiceFormatter(argparse.HelpFormatter): @@ -38,13 +35,6 @@ class RiceFormatter(argparse.HelpFormatter): def main(): - try: - # support vscode debugger (bonus: same behavior as on windows) - mp.set_start_method("spawn", True) - except AttributeError: - # py2.7 probably - pass - ap = argparse.ArgumentParser( formatter_class=RiceFormatter, prog="copyparty", @@ -84,19 +74,7 @@ def main(): ap.add_argument("-nw", action="store_true", help="benchmark: disable writing") al = ap.parse_args() - tcpsrv = TcpSrv(al) - thr = threading.Thread(target=tcpsrv.run) - thr.daemon = True - thr.start() - - # winxp/py2.7 support: thr.join() kills signals - try: - while True: - time.sleep(9001) - except KeyboardInterrupt: - print("OPYTHAT") - tcpsrv.shutdown() - print("nailed it") + SvcHub(al).run() if __name__ == "__main__": diff --git a/copyparty/broker_mp.py b/copyparty/broker_mp.py new file mode 100644 index 00000000..d8278b10 --- /dev/null +++ b/copyparty/broker_mp.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import time +import threading +import multiprocessing as mp + +from .__init__ import PY2, WINDOWS +from .broker_mpw import MpWorker + + +if PY2 and not WINDOWS: + from multiprocessing.reduction import ForkingPickler + from StringIO import StringIO as MemesIO # pylint: disable=import-error + + +class BrokerMp(object): + """external api; manages MpWorkers""" + + def __init__(self, hub): + self.hub = hub + self.log = hub.log + self.args = hub.args + + self.mutex = threading.Lock() + + self.procs = [] + + cores = self.args.j + if cores is None: + cores = mp.cpu_count() + + self.log("broker", "booting {} subprocesses".format(cores)) + for n in range(cores): + q_pend = mp.Queue(1) + q_yield = mp.Queue(64) + + proc = mp.Process(target=MpWorker, args=(q_pend, q_yield, self.args, n)) + proc.q_pend = q_pend + proc.q_yield = q_yield + proc.nid = n + proc.clients = {} + proc.workload = 0 + + thr = threading.Thread(target=self.collector, args=(proc,)) + thr.daemon = True + thr.start() + + self.procs.append(proc) + proc.start() + + if True: + thr = threading.Thread(target=self.debug_load_balancer) + thr.daemon = True + thr.start() + + def shutdown(self): + self.log("broker", "shutting down") + for proc in self.procs: + thr = threading.Thread(target=proc.q_pend.put(["shutdown"])) + thr.start() + + with self.mutex: + procs = self.procs + self.procs = [] + + while procs: + if procs[-1].is_alive(): + time.sleep(0.1) + continue + + procs.pop() + + def collector(self, proc): + while True: + msg = proc.q_yield.get() + k = msg[0] + + if k == "log": + self.log(*msg[1:]) + + elif k == "workload": + with self.mutex: + proc.workload = msg[1] + + elif k == "httpdrop": + addr = msg[1] + + with self.mutex: + del proc.clients[addr] + if not proc.clients: + proc.workload = 0 + + self.hub.tcpsrv.num_clients.add(-1) + + def put(self, retq, act, *args): + if act == "httpconn": + sck, addr = args + sck2 = sck + if PY2: + buf = MemesIO() + ForkingPickler(buf).dump(sck) + sck2 = buf.getvalue() + + proc = sorted(self.procs, key=lambda x: x.workload)[0] + proc.q_pend.put(["httpconn", sck2, addr]) + + with self.mutex: + proc.clients[addr] = 50 + proc.workload += 50 + else: + raise Exception("what is " + str(act)) + + def debug_load_balancer(self): + last = "" + while self.procs: + msg = "" + for proc in self.procs: + msg += "\033[1m{}\033[0;36m{:4}\033[0m ".format( + len(proc.clients), proc.workload + ) + + if msg != last: + last = msg + print(msg) + + time.sleep(0.1) diff --git a/copyparty/broker_mpw.py b/copyparty/broker_mpw.py new file mode 100644 index 00000000..522cf10b --- /dev/null +++ b/copyparty/broker_mpw.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import sys +import time +import signal +import threading + +from .__init__ import PY2, WINDOWS +from .httpsrv import HttpSrv + +if PY2 and not WINDOWS: + import pickle # nosec + + +class MpWorker(object): + """one single mp instance""" + + def __init__(self, q_pend, q_yield, args, n): + self.q_pend = q_pend + self.q_yield = q_yield + self.args = args + self.n = n + + self.mutex = threading.Lock() + self.workload_thr_active = False + + # we inherited signal_handler from parent, + # replace it with something harmless + signal.signal(signal.SIGINT, self.signal_handler) + + self.httpsrv = HttpSrv(self.args, self.log) + self.httpsrv.disconnect_func = self.httpdrop + + # on winxp and some other platforms, + # use thr.join() to block all signals + thr = threading.Thread(target=self.main) + thr.daemon = True + thr.start() + thr.join() + + def signal_handler(self, signal, frame): + # print('k') + pass + + def log(self, src, msg): + self.q_yield.put(["log", src, msg]) + + def logw(self, msg): + self.log("mp{}".format(self.n), msg) + + def httpdrop(self, addr): + self.q_yield.put(["httpdrop", addr]) + + def main(self): + while True: + d = self.q_pend.get() + + # self.logw("work: [{}]".format(d[0])) + if d[0] == "shutdown": + self.logw("ok bye") + sys.exit(0) + return + + elif d[0] == "httpconn": + sck = d[1] + if PY2: + sck = pickle.loads(sck) # nosec + + self.httpsrv.accept(sck, d[2]) + + with self.mutex: + if not self.workload_thr_active: + self.workload_thr_alive = True + thr = threading.Thread(target=self.thr_workload) + thr.daemon = True + thr.start() + + else: + raise Exception("what is " + str(d[0])) + + def thr_workload(self): + """announce workloads to MpSrv (the mp controller / loadbalancer)""" + # avoid locking in extract_filedata by tracking difference here + while True: + time.sleep(0.2) + with self.mutex: + if self.httpsrv.num_clients() == 0: + # no clients rn, termiante thread + self.workload_thr_alive = False + return + + self.q_yield.put(["workload", self.httpsrv.workload]) diff --git a/copyparty/broker_thr.py b/copyparty/broker_thr.py new file mode 100644 index 00000000..250069b9 --- /dev/null +++ b/copyparty/broker_thr.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import threading + +from .httpsrv import HttpSrv + + +class BrokerThr(object): + """external api; behaves like BrokerMP but using plain threads""" + + def __init__(self, hub): + self.hub = hub + self.log = hub.log + self.args = hub.args + + self.mutex = threading.Lock() + + self.httpsrv = HttpSrv(self.args, self.log) + self.httpsrv.disconnect_func = self.httpdrop + + def shutdown(self): + # self.log("broker", "shutting down") + pass + + def put(self, retq, act, *args): + if act == "httpconn": + sck, addr = args + self.httpsrv.accept(sck, addr) + + else: + raise Exception("what is " + str(act)) + + def httpdrop(self, addr): + self.hub.tcpsrv.num_clients.add(-1) diff --git a/copyparty/mpsrv.py b/copyparty/mpsrv.py deleted file mode 100644 index 84c04fa5..00000000 --- a/copyparty/mpsrv.py +++ /dev/null @@ -1,216 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 -from __future__ import print_function, unicode_literals - -import sys -import time -import signal -import threading -import multiprocessing as mp - -from .__init__ import PY2, WINDOWS -from .httpsrv import HttpSrv - -if PY2 and not WINDOWS: - from multiprocessing.reduction import ForkingPickler - from StringIO import StringIO as MemesIO # pylint: disable=import-error - import pickle # nosec - - -class MpWorker(object): - """ - one single mp instance, wraps one HttpSrv, - the HttpSrv api exposed to TcpSrv proxies like - MpSrv -> (this) -> HttpSrv - """ - - def __init__(self, q_pend, q_yield, args, n): - self.q_pend = q_pend - self.q_yield = q_yield - self.args = args - self.n = n - - self.mutex = threading.Lock() - self.workload_thr_active = False - - # we inherited signal_handler from parent, - # replace it with something harmless - signal.signal(signal.SIGINT, self.signal_handler) - - # on winxp and some other platforms, - # use thr.join() to block all signals - thr = threading.Thread(target=self.main) - thr.daemon = True - thr.start() - thr.join() - - def signal_handler(self, signal, frame): - # print('k') - pass - - def log(self, src, msg): - self.q_yield.put(["log", src, msg]) - - def logw(self, msg): - self.log("mp{}".format(self.n), msg) - - def disconnect_cb(self, addr): - self.q_yield.put(["dropclient", addr]) - - def main(self): - self.httpsrv = HttpSrv(self.args, self.log) - self.httpsrv.disconnect_func = self.disconnect_cb - - while True: - d = self.q_pend.get() - - # self.logw("work: [{}]".format(d[0])) - if d[0] == "shutdown": - self.logw("ok bye") - sys.exit(0) - return - - sck = d[1] - if PY2: - sck = pickle.loads(sck) # nosec - - self.httpsrv.accept(sck, d[2]) - - with self.mutex: - if not self.workload_thr_active: - self.workload_thr_alive = True - thr = threading.Thread(target=self.thr_workload) - thr.daemon = True - thr.start() - - def thr_workload(self): - """announce workloads to MpSrv (the mp controller / loadbalancer)""" - # avoid locking in extract_filedata by tracking difference here - while True: - time.sleep(0.2) - with self.mutex: - if self.httpsrv.num_clients() == 0: - # no clients rn, termiante thread - self.workload_thr_alive = False - return - - self.q_yield.put(["workload", self.httpsrv.workload]) - - -class MpSrv(object): - """ - same api as HttpSrv except uses multiprocessing to dodge gil, - a collection of MpWorkers are made (one per subprocess) - and each MpWorker creates one actual HttpSrv - """ - - def __init__(self, args, log_func): - self.log = log_func - self.args = args - - self.disconnect_func = None - self.mutex = threading.Lock() - - self.procs = [] - - cores = args.j - if cores is None: - cores = mp.cpu_count() - - self.log("mpsrv", "booting {} subprocesses".format(cores)) - for n in range(cores): - q_pend = mp.Queue(1) - q_yield = mp.Queue(64) - - proc = mp.Process(target=MpWorker, args=(q_pend, q_yield, args, n)) - proc.q_pend = q_pend - proc.q_yield = q_yield - proc.nid = n - proc.clients = {} - proc.workload = 0 - - thr = threading.Thread(target=self.collector, args=(proc,)) - thr.daemon = True - thr.start() - - self.procs.append(proc) - proc.start() - - if True: - thr = threading.Thread(target=self.debug_load_balancer) - thr.daemon = True - thr.start() - - def num_clients(self): - with self.mutex: - return sum(len(x.clients) for x in self.procs) - - def shutdown(self): - self.log("mpsrv", "shutting down") - for proc in self.procs: - thr = threading.Thread(target=proc.q_pend.put(["shutdown"])) - thr.start() - - with self.mutex: - procs = self.procs - self.procs = [] - - while procs: - if procs[-1].is_alive(): - time.sleep(0.1) - continue - - procs.pop() - - def collector(self, proc): - while True: - msg = proc.q_yield.get() - k = msg[0] - - if k == "log": - self.log(*msg[1:]) - - if k == "workload": - with self.mutex: - proc.workload = msg[1] - - if k == "dropclient": - addr = msg[1] - - with self.mutex: - del proc.clients[addr] - if not proc.clients: - proc.workload = 0 - - if self.disconnect_func: - self.disconnect_func(addr) # pylint: disable=not-callable - - def accept(self, sck, addr): - proc = sorted(self.procs, key=lambda x: x.workload)[0] - - sck2 = sck - if PY2: - buf = MemesIO() - ForkingPickler(buf).dump(sck) - sck2 = buf.getvalue() - - proc.q_pend.put(["socket", sck2, addr]) - - with self.mutex: - proc.clients[addr] = 50 - proc.workload += 50 - - def debug_load_balancer(self): - last = "" - while self.procs: - msg = "" - for proc in self.procs: - msg += "\033[1m{}\033[0;36m{:4}\033[0m ".format( - len(proc.clients), proc.workload - ) - - if msg != last: - last = msg - print(msg) - - time.sleep(0.1) diff --git a/copyparty/msgsvc.py b/copyparty/msgsvc.py deleted file mode 100644 index d18554f1..00000000 --- a/copyparty/msgsvc.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 -from __future__ import print_function, unicode_literals - - -class MsgSvc(object): - def __init__(self, log_func): - self.log_func = log_func - print("hi") - - def put(self, msg): - if msg[0] == "log": - return self.log_func(*msg[1:]) - - raise Exception("bad msg type: " + str(msg)) diff --git a/copyparty/svchub.py b/copyparty/svchub.py new file mode 100644 index 00000000..af3c7a30 --- /dev/null +++ b/copyparty/svchub.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import sys +import time +import threading +import multiprocessing as mp +from datetime import datetime, timedelta +import calendar + +from .__init__ import PY2, WINDOWS +from .tcpsrv import TcpSrv + + +class SvcHub(object): + """ + Hosts all services which cannot be parallelized due to reliance on monolithic resources. + Creates a Broker which does most of the heavy stuff; hosted services can use this to perform work: + hub.broker.put(retq, action, arg1, argN). + + Either BrokerThr (plain threads) or BrokerMP (multiprocessing) is used depending on configuration. + To receive any output returned by action, provide a queue-object for retq, else None. + """ + + def __init__(self, args): + self.args = args + + self.log_mutex = threading.Lock() + self.next_day = 0 + + # initiate all services to manage + self.tcpsrv = TcpSrv(self) + + # decide which worker impl to use + if self.check_mp_enable(): + from .broker_mp import BrokerMp as Broker + else: + self.log("root", "cannot efficiently use multiple CPU cores") + from .broker_thr import BrokerThr as Broker + + self.broker = Broker(self) + + def run(self): + thr = threading.Thread(target=self.tcpsrv.run) + thr.daemon = True + thr.start() + + # winxp/py2.7 support: thr.join() kills signals + try: + while True: + time.sleep(9001) + except KeyboardInterrupt: + print("OPYTHAT") + self.tcpsrv.shutdown() + self.broker.shutdown() + print("nailed it") + + def log(self, src, msg): + """handles logging from all components""" + now = time.time() + if now >= self.next_day: + dt = datetime.utcfromtimestamp(now) + print("\033[36m{}\033[0m".format(dt.strftime("%Y-%m-%d"))) + + # unix timestamp of next 00:00:00 (leap-seconds safe) + day_now = dt.day + while dt.day == day_now: + dt += timedelta(hours=12) + + dt = dt.replace(hour=0, minute=0, second=0) + self.next_day = calendar.timegm(dt.utctimetuple()) + + with self.log_mutex: + ts = datetime.utcfromtimestamp(now).strftime("%H:%M:%S") + print("\033[36m{} \033[33m{:21} \033[0m{}".format(ts, src, msg)) + + def check_mp_support(self): + vmin = sys.version_info[1] + if WINDOWS: + msg = "need python 3.3 or newer for multiprocessing;" + if PY2: + # py2 pickler doesn't support winsock + return msg + elif vmin < 3: + return msg + else: + msg = "need python 2.7 or 3.3+ for multiprocessing;" + if not PY2 and vmin < 3: + return msg + + try: + x = mp.Queue(1) + x.put(["foo", "bar"]) + if x.get()[0] != "foo": + raise Exception() + except: + return "multiprocessing is not supported on your platform;" + + return None + + def check_mp_enable(self): + if self.args.j == 0: + self.log("root", "multiprocessing disabled by argument -j 0;") + return False + + try: + # support vscode debugger (bonus: same behavior as on windows) + mp.set_start_method("spawn", True) + except AttributeError: + # py2.7 probably, anyways dontcare + pass + + err = self.check_mp_support() + if not err: + return True + else: + self.log("root", err) + return False diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py index a65706ff..0473808d 100644 --- a/copyparty/tcpsrv.py +++ b/copyparty/tcpsrv.py @@ -3,30 +3,24 @@ from __future__ import print_function, unicode_literals import re -import sys import time import socket -import threading -import multiprocessing as mp -from datetime import datetime, timedelta -import calendar -from .__init__ import PY2, WINDOWS -from .util import chkcmd +from .util import chkcmd, Counter class TcpSrv(object): """ - toplevel component starting everything else, - tcplistener which forwards clients to httpsrv - (through mpsrv if platform provides support) + tcplistener which forwards clients to Hub + which then uses the least busy HttpSrv to handle it """ - def __init__(self, args): - self.args = args + def __init__(self, hub): + self.hub = hub + self.args = hub.args + self.log = hub.log - self.log_mutex = threading.Lock() - self.next_day = 0 + self.num_clients = Counter() ip = "127.0.0.1" if self.args.i == ip: @@ -36,7 +30,7 @@ class TcpSrv(object): for ip, desc in sorted(eps.items(), key=lambda x: x[1]): self.log( - "root", + "tcpsrv", "available @ http://{}:{}/ (\033[33m{}\033[0m)".format( ip, self.args.p, desc ), @@ -59,84 +53,19 @@ class TcpSrv(object): def run(self): self.srv.listen(self.args.nc) - self.log("root", "listening @ {0}:{1}".format(self.args.i, self.args.p)) + self.log("tcpsrv", "listening @ {0}:{1}".format(self.args.i, self.args.p)) - self.httpsrv = self.create_server() while True: - if self.httpsrv.num_clients() >= self.args.nc: + if self.num_clients.v >= self.args.nc: time.sleep(0.1) continue sck, addr = self.srv.accept() - self.httpsrv.accept(sck, addr) + self.num_clients.add() + self.hub.broker.put(None, "httpconn", sck, addr) def shutdown(self): - self.httpsrv.shutdown() - - def check_mp_support(self): - vmin = sys.version_info[1] - if WINDOWS: - msg = "need python 3.3 or newer for multiprocessing;" - if PY2: - # py2 pickler doesn't support winsock - return msg - elif vmin < 3: - return msg - else: - msg = "need python 2.7 or 3.3+ for multiprocessing;" - if not PY2 and vmin < 3: - return msg - - try: - x = mp.Queue(1) - x.put(["foo", "bar"]) - if x.get()[0] != "foo": - raise Exception() - except: - return "multiprocessing is not supported on your platform;" - - return "" - - def create_server(self): - if self.args.j == 0: - self.log("root", "multiprocessing disabled by argument -j 0;") - return self.create_threading_server() - - err = self.check_mp_support() - if err: - self.log("root", err) - return self.create_threading_server() - - return self.create_multiprocessing_server() - - def create_threading_server(self): - from .httpsrv import HttpSrv - - self.log("root", "cannot efficiently use multiple CPU cores") - return HttpSrv(self.args, self.log) - - def create_multiprocessing_server(self): - from .mpsrv import MpSrv - - return MpSrv(self.args, self.log) - - def log(self, src, msg): - now = time.time() - if now >= self.next_day: - dt = datetime.utcfromtimestamp(now) - print("\033[36m{}\033[0m".format(dt.strftime("%Y-%m-%d"))) - - # unix timestamp of next 00:00:00 (leap-seconds safe) - day_now = dt.day - while dt.day == day_now: - dt += timedelta(hours=12) - - dt = dt.replace(hour=0, minute=0, second=0) - self.next_day = calendar.timegm(dt.utctimetuple()) - - with self.log_mutex: - ts = datetime.utcfromtimestamp(now).strftime("%H:%M:%S") - print("\033[36m{} \033[33m{:21} \033[0m{}".format(ts, src, msg)) + self.log("tcpsrv", "ok bye") def detect_interfaces(self, ext_ip): eps = {} @@ -148,7 +77,7 @@ class TcpSrv(object): ip_addr = None if ip_addr: - r = re.compile("^\s+inet ([^ ]+)/.* (.*)") + r = re.compile(r"^\s+inet ([^ ]+)/.* (.*)") for ln in ip_addr.split("\n"): try: ip, dev = r.match(ln.rstrip()).groups() diff --git a/copyparty/util.py b/copyparty/util.py index 0b8f44c5..df7fb9cc 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -5,23 +5,38 @@ from __future__ import print_function, unicode_literals import re import sys import hashlib +import threading import subprocess as sp # nosec from .__init__ import PY2 +from .stolen import surrogateescape if not PY2: from urllib.parse import unquote_to_bytes as unquote from urllib.parse import quote_from_bytes as quote else: from urllib import unquote # pylint: disable=no-name-in-module - from urllib import quote + from urllib import quote # pylint: disable=no-name-in-module -from .stolen import surrogateescape surrogateescape.register_surrogateescape() FS_ENCODING = sys.getfilesystemencoding() +class Counter(object): + def __init__(self, v=0): + self.v = v + self.mutex = threading.Lock() + + def add(self, delta=1): + with self.mutex: + self.v += delta + + def set(self, absval): + with self.mutex: + self.v = absval + + class Unrecv(object): """ undo any number of socket recv ops