diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..90eaec19 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +* text eol=lf + +*.png binary diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..aab04bc8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +# python +__pycache__/ +*.py[cod] +*$py.class +MANIFEST.in +MANIFEST +copyparty.egg-info/ +buildenv/ +build/ +dist/ +*.rst + +# sublime +*.sublime-workspace + +# vscode +.vscode + +# winmerge +*.bak + +# other licenses +contrib/ diff --git a/README.md b/README.md index baa13143..d25a40c7 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,18 @@ -# copyparty -⇆🎉 http file sharing hub (py2/py3) +# ⇆🎉 copyparty + +* http file sharing hub (py2/py3) +* MIT-Licensed, 2019-05-26, ed @ irc.rizon.net + +## status + +* [x] sanic multipart parser +* [x] load balancer (multiprocessing) +* [ ] upload +* [ ] download +* [ ] browser +* [ ] thumbnails +* [ ] download as zip +* [ ] volumes +* [ ] accounts + +conclusion: don't bother diff --git a/copyparty/__init__.py b/copyparty/__init__.py new file mode 100644 index 00000000..60bd1360 --- /dev/null +++ b/copyparty/__init__.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# coding: utf-8 +from __future__ import print_function + +import platform +import sys +import os + +WINDOWS = platform.system() == "Windows" +PY2 = sys.version_info[0] == 2 +if PY2: + sys.dont_write_bytecode = True + diff --git a/copyparty/__main__.py b/copyparty/__main__.py new file mode 100644 index 00000000..ee2e2dde --- /dev/null +++ b/copyparty/__main__.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# coding: utf-8 +from __future__ import print_function + +"""copyparty: http file sharing hub (py2/py3)""" +__author__ = "ed " +__copyright__ = 2019 +__license__ = "MIT" +__url__ = "https://github.com/9001/copyparty/" + +import time +import argparse +import threading +from textwrap import dedent +import multiprocessing as mp + +from .__version__ import * +from .tcpsrv import 
# ---- copyparty/__version__.py ----

VERSION = (0, 0, 1)
BUILD_DT = (2019, 5, 26)

S_VERSION = ".".join(map(str, VERSION))
S_BUILD_DT = "{0:04d}-{1:02d}-{2:02d}".format(*BUILD_DT)

__version__ = S_VERSION
__build_dt__ = S_BUILD_DT


# ---- copyparty/__main__.py ----


class RiceFormatter(argparse.HelpFormatter):
    """Help formatter that colorizes defaults and keeps the epilog layout."""

    def _get_help_string(self, action):
        """
        same as ArgumentDefaultsHelpFormatter(HelpFormatter)
        except the appended "(default: ...)" text has ansi colors

        an action without help text is treated as having empty help
        (the previous code raised TypeError on `in None`)
        """
        text = action.help
        if text is None:
            text = ""

        if "%(default)" in text or action.default is argparse.SUPPRESS:
            return text

        defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE]
        if action.option_strings or action.nargs in defaulting_nargs:
            text += "\033[36m (default: \033[35m%(default)s\033[36m)\033[0m"

        return text

    def _fill_text(self, text, width, indent):
        """same as RawDescriptionHelpFormatter(HelpFormatter)"""
        return "".join(indent + line + "\n" for line in text.splitlines())


def _build_argparser():
    """Create the command-line parser used by main()."""
    ap = argparse.ArgumentParser(
        formatter_class=RiceFormatter,
        prog="copyparty",
        description="http file sharing hub v{} ({})".format(S_VERSION, S_BUILD_DT),
        epilog=dedent(
            """
            -a takes username:password,
            -v takes path:permset:permset:... where "permset" is
               accesslevel followed by username (no separator)

            example:\033[35m
              -a ed:hunter2  -v .:r:aed  -v ../inc:w:aed  \033[36m
              share current directory with
               * r (read-only) for everyone
               * a (read+write) for ed
              share ../inc with
               * w (write-only) for everyone
               * a (read+write) for ed  \033[0m

            if no accounts or volumes are configured,
            current folder will be read/write for everyone

            consider the config file for more flexible account/volume management,
            including dynamic reload at runtime (and being more readable w)
            """
        ),
    )
    ap.add_argument("-c", metavar="PATH", type=str, help="config file")
    ap.add_argument("-i", metavar="IP", type=str, default="0.0.0.0", help="ip to bind")
    ap.add_argument("-p", metavar="PORT", type=int, default=1234, help="port to bind")
    ap.add_argument("-nc", metavar="NUM", type=int, default=16, help="max num clients")
    ap.add_argument("-j", metavar="CORES", type=int, help="max num cpu cores")
    # the epilog example passes -v twice, so both options must accumulate;
    # without action="append" only the last occurrence would survive
    ap.add_argument("-a", metavar="ACCT", type=str, action="append", help="add account")
    ap.add_argument("-v", metavar="VOL", type=str, action="append", help="add volume")
    ap.add_argument("-nw", action="store_true", help="DEBUG: disable writing")
    return ap


def main():
    """Parse argv, start TcpSrv in a daemon thread and wait for it."""
    try:
        # support vscode debugger (bonus: same behavior as on windows)
        mp.set_start_method("spawn", True)
    except (AttributeError, RuntimeError, ValueError):
        # AttributeError: py2.7 has no set_start_method
        # RuntimeError/ValueError: start method unavailable or already fixed
        pass

    al = _build_argparser().parse_args()

    # imported here so this block does not depend on a module-level star-import
    from .tcpsrv import TcpSrv

    thr = threading.Thread(target=TcpSrv, args=(al,))
    thr.daemon = True
    thr.start()

    # short sleeps keep ctrl-c responsive; stop waiting once the server
    # thread is gone (for example when bind() failed) instead of hanging
    while thr.is_alive():
        time.sleep(0.5)


if __name__ == "__main__":
    main()
# ---- copyparty/httpcli.py ----

import hashlib
import os
import pickle
import sys
import threading
import time
import multiprocessing as mp
from multiprocessing.reduction import ForkingPickler

# mirrors the PY2 flag defined in copyparty/__init__.py
PY2 = sys.version_info[0] == 2
if PY2:
    from StringIO import StringIO as MemesIO
else:
    from io import BytesIO as MemesIO

    unicode = str


def _html_escape(txt):
    """Return txt with the html-significant characters & < > escaped."""
    return txt.replace(u"&", u"&amp;").replace(u"<", u"&lt;").replace(u">", u"&gt;")


class HttpCli(object):
    """
    one http client connection;
    parses the request and handles GET (upload form) and POST (multipart upload)
    """

    def __init__(self, sck, addr, args, log_func):
        # imported here so the module can be loaded without the package context
        from .util import Unrecv

        self.s = sck
        self.addr = addr
        self.args = args

        self.sr = Unrecv(sck)
        self.bufsz = 1024 * 32
        self.workload = 0  # number of recv calls; sampled by HttpSrv.thr_workload
        self.ok = True  # cleared by panic() when the client disconnects

        self.log_func = log_func
        self.log_src = "{} \033[36m{}".format(addr[0], addr[1]).ljust(26)

    def log(self, msg):
        """forward msg to the log function, tagged with this client's address"""
        self.log_func(self.log_src, msg)

    def run(self):
        """read the request header and dispatch on the http method"""
        headerlines = self.read_header()
        if not self.ok:
            return

        self.headers = {}
        try:
            mode, self.req, _ = headerlines[0].split(" ")
            for header_line in headerlines[1:]:
                k, v = header_line.split(":", 1)
                self.headers[k.lower()] = v.strip()
        except ValueError:
            # request line without exactly three fields, or header without ":"
            return self.loud_reply(u"malformed request")

        if mode == "GET":
            self.handle_get()
        elif mode == "POST":
            self.handle_post()
        else:
            self.loud_reply(u'invalid HTTP mode "{0}"'.format(mode))

    def panic(self, msg):
        """log the disconnect, mark the client as dead and close the socket"""
        self.log("client disconnected ({0})".format(msg).upper())
        self.ok = False
        self.s.close()

    def read_header(self):
        """
        read until an empty line (CRLF CRLF) and return the lines before it;
        never requests more bytes than could still be part of the terminator,
        so nothing after the header is consumed
        """
        ret = b""
        while True:
            if ret.endswith(b"\r\n\r\n"):
                break
            elif ret.endswith(b"\r\n\r"):
                n = 1
            elif ret.endswith(b"\r\n"):
                n = 2
            elif ret.endswith(b"\r"):
                n = 3
            else:
                n = 4

            buf = self.sr.recv(n)
            if not buf:
                self.panic("headers")
                break

            ret += buf

        return ret[:-4].decode("utf-8", "replace").split("\r\n")

    def reply(self, body):
        """send body (bytes) as a 200 text/html response, close the socket, return body"""
        header = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {0}\r\n\r\n".format(
            len(body)
        ).encode(
            "utf-8"
        )
        try:
            if self.ok:
                # sendall: plain send() may transmit only part of the buffer
                self.s.sendall(header + body)
        finally:
            self.s.close()

        return body

    def loud_reply(self, body):
        """log body (text) and send it to the client as preformatted text"""
        self.log(body.rstrip())
        # NOTE(review): the markup literal was stripped from SOURCE; "<pre>" is
        # a reconstruction. body may echo request data, so it is html-escaped
        self.reply(b"<pre>" + _html_escape(body).encode("utf-8"))

    def handle_get(self):
        """reply with an upload form"""
        self.log("")
        self.log("GET  {0} {1}".format(self.addr[0], self.req))
        # NOTE(review): the original html literal was stripped from SOURCE;
        # this is a minimal multipart upload form standing in for it
        self.reply(
            b'<!DOCTYPE html><html lang="en"><head><meta charset="utf-8">'
            b"<title>copyparty</title></head><body>"
            b'<form method="post" enctype="multipart/form-data" accept-charset="utf-8">'
            b'<input type="file" name="f" multiple>'
            b'<input type="submit" value="start upload">'
            b"</form></body></html>"
        )

    def handle_post(self):
        """
        receive a multipart upload; each part is written to inc.<time>.<n>
        (or the null device with -nw) and a summary is sent back and
        saved as inc.<time>.txt
        """
        self.log("")
        self.log("POST {0} {1}".format(self.addr[0], self.req))

        nullwrite = self.args.nw

        if self.headers.get("expect", "").lower() == "100-continue":
            try:
                self.s.sendall(b"HTTP/1.1 100 Continue\r\n\r\n")
            except EnvironmentError:
                # best-effort; a dead socket is detected by the next recv
                pass

        form_segm = self.read_header()
        if not self.ok:
            return

        self.boundary = b"\r\n" + form_segm[0].encode("utf-8")
        for ln in form_segm[1:]:
            self.log(ln)

        fn = os.devnull
        fn0 = "inc.{0:.6f}".format(time.time())

        files = []
        t0 = time.time()
        for nfile in range(99):
            if not nullwrite:
                fn = "{0}.{1}".format(fn0, nfile)

            with open(fn, "wb") as f:
                self.log("writing to {0}".format(fn))
                sz, sha512 = self.handle_multipart(f)
                if sz == 0:
                    break

                files.append([sz, sha512])

            buf = self.sr.recv(2)

            if buf == b"--":
                # end of multipart
                break

            if buf != b"\r\n":
                return self.loud_reply(u"protocol error")

            header = self.read_header()
            if not self.ok:
                break

            form_segm += header
            for ln in header:
                self.log(ln)

        # clamp: a coarse clock can report zero elapsed time for tiny uploads
        td = max(time.time() - t0, 1e-9)
        sz_total = sum(x[0] for x in files)
        spd = (sz_total / td) / (1024 * 1024)

        status = "OK"
        if not self.ok:
            status = "ERROR"

        msg = u"{0} // {1} bytes // {2:.3f} MiB/s\n".format(status, sz_total, spd)

        for sz, sha512 in files:
            # only the first 56 hex digits (224 bits) of the digest are shown
            msg += u"sha512: {0} // {1} bytes\n".format(sha512[:56], sz)

        self.loud_reply(msg)

        if not nullwrite:
            report = [
                u":".join(unicode(x) for x in self.addr),
                u"\n".join(form_segm),
                msg.rstrip(),
            ]
            with open(fn0 + ".txt", "wb") as f:
                f.write((u"\n".join(report) + u"\n").encode("utf-8"))

    def handle_multipart(self, ofd):
        """write one multipart body to ofd; returns (num_bytes, sha512 hexdigest)"""
        tlen = 0
        hashobj = hashlib.sha512()
        for buf in self.extract_filedata():
            tlen += len(buf)
            hashobj.update(buf)
            ofd.write(buf)

        return tlen, hashobj.hexdigest()

    def extract_filedata(self):
        """
        generator yielding the bytes up to (not including) the next boundary;
        whatever was received after the boundary is pushed back with unrecv.
        on disconnect, panic() is called and the generator ends
        """
        u32_lim = int((2 ** 31) * 0.9)
        blen = len(self.boundary)
        bufsz = self.bufsz
        while True:
            if self.workload > u32_lim:
                # reset to prevent overflow
                self.workload = 100

            buf = self.sr.recv(bufsz)
            self.workload += 1
            if not buf:
                # abort: client disconnected
                self.panic("outer")
                return

            while True:
                ofs = buf.find(self.boundary)
                if ofs != -1:
                    self.sr.unrecv(buf[ofs + blen :])
                    yield buf[:ofs]
                    return

                d = len(buf) - blen
                if d > 0:
                    # buffer growing large; yield everything except
                    # the part at the end (maybe start of boundary)
                    yield buf[:d]
                    buf = buf[d:]

                # look for boundary near the end of the buffer:
                # n becomes the longest tail which occurs inside the boundary
                for n in range(1, len(buf) + 1):
                    if not buf[-n:] in self.boundary:
                        n -= 1
                        break

                if n == 0 or not self.boundary.startswith(buf[-n:]):
                    # no boundary contents near the buffer edge
                    break

                if blen == n:
                    # EOF: found boundary
                    yield buf[:-n]
                    return

                buf2 = self.sr.recv(bufsz)
                self.workload += 1
                if not buf2:
                    # abort: client disconnected
                    self.panic("inner")
                    return

                buf += buf2

            yield buf


# ---- copyparty/httpsrv.py ----


class HttpSrv(object):
    """
    handles incoming connections (parses http and produces responses)
    relying on MpSrv for performance (HttpSrv is just plain threads)
    """

    def __init__(self, args, log_func):
        self.log = log_func
        self.args = args

        self.disconnect_func = None
        self.clients = {}  # HttpCli -> workload counter at the previous sample
        self.workload = 0
        self.workload_thr_alive = False

        self.mutex = threading.Lock()

    def accept(self, sck, addr):
        """takes an incoming tcp connection and creates a thread to handle it"""
        thr = threading.Thread(target=self.thr_client, args=(sck, addr, self.log))
        thr.daemon = True
        thr.start()

    def num_clients(self):
        """number of clients currently being served"""
        with self.mutex:
            return len(self.clients)

    def thr_client(self, sck, addr, log):
        """thread managing one tcp client"""
        # bound before the try so the cleanup below also works when the
        # HttpCli constructor raises (previously: NameError in finally,
        # hiding the real error and skipping disconnect_func)
        cli = None
        try:
            cli = HttpCli(sck, addr, self.args, log)
            with self.mutex:
                self.clients[cli] = 0
                self.workload += 50

                if not self.workload_thr_alive:
                    self.workload_thr_alive = True
                    thr = threading.Thread(target=self.thr_workload)
                    thr.daemon = True
                    thr.start()

            cli.run()

        finally:
            with self.mutex:
                self.clients.pop(cli, None)

            if self.disconnect_func:
                self.disconnect_func(addr)

    def thr_workload(self):
        """indicates the python interpreter workload caused by this HttpSrv"""
        # avoid locking in extract_filedata by tracking difference here
        while True:
            time.sleep(0.2)
            with self.mutex:
                if not self.clients:
                    # no clients rn, terminate thread
                    self.workload_thr_alive = False
                    self.workload = 0
                    return

                total = 0
                for cli, before in list(self.clients.items()):
                    now = cli.workload
                    delta = now - before
                    if delta < 0:
                        # was reset in HttpCli to prevent overflow
                        delta = now

                    total += delta
                    self.clients[cli] = now

                self.workload = total


# ---- copyparty/mpsrv.py ----


class MpWorker(object):
    """
    one single mp instance, wraps one HttpSrv,
    the HttpSrv api exposed to TcpSrv proxies like
    MpSrv -> (this) -> HttpSrv
    """

    def __init__(self, q_pend, q_yield, args, n):
        self.q_pend = q_pend
        self.q_yield = q_yield
        self.args = args
        self.n = n

        self.mutex = threading.Lock()
        self.workload_thr_active = False

        thr = threading.Thread(target=self.main)
        thr.daemon = True
        thr.start()

        # returning ends the subprocess, so wait for main() to finish;
        # a sys.exit inside main() would only end that one thread
        while thr.is_alive():
            time.sleep(0.5)

    def log(self, src, msg):
        """send a log message to the parent process"""
        self.q_yield.put(["log", src, msg])

    def logw(self, msg):
        """log a message on behalf of this worker"""
        self.log("mp{}".format(self.n), msg)

    def disconnect_cb(self, addr):
        """tell the parent process that the client at addr is gone"""
        self.q_yield.put(["dropclient", addr])

    def main(self):
        """receive sockets from the parent until told to terminate"""
        self.httpsrv = HttpSrv(self.args, self.log)
        self.httpsrv.disconnect_func = self.disconnect_cb

        while True:
            self.logw("awaiting work")
            d = self.q_pend.get()

            self.logw("work: [{}]".format(d[0]))
            if d[0] == "terminate":
                self.logw("bye")
                return

            # NOTE: unpickling is only acceptable because the payload
            # was produced by MpSrv.accept in the parent process
            sck = pickle.loads(d[1])
            self.httpsrv.accept(sck, d[2])

            with self.mutex:
                if not self.workload_thr_active:
                    # same attribute as tested above; the previous code set
                    # "workload_thr_alive" and so spawned a thread per socket
                    self.workload_thr_active = True
                    thr = threading.Thread(target=self.thr_workload)
                    thr.daemon = True
                    thr.start()

    def thr_workload(self):
        """announce workloads to MpSrv (the mp controller / loadbalancer)"""
        # avoid locking in extract_filedata by tracking difference here
        while True:
            time.sleep(0.2)
            with self.mutex:
                if self.httpsrv.num_clients() == 0:
                    # no clients rn, terminate thread
                    self.workload_thr_active = False
                    return

            self.q_yield.put(["workload", self.httpsrv.workload])


class MpSrv(object):
    """
    same api as HttpSrv except uses multiprocessing to dodge gil,
    a collection of MpWorkers are made (one per subprocess)
    and each MpWorker creates one actual HttpSrv
    """

    def __init__(self, args, log_func):
        self.log = log_func
        self.args = args

        self.disconnect_func = None
        self.procs = []

        self.mutex = threading.Lock()

        cores = args.j
        if cores is None:
            cores = mp.cpu_count()

        self.log("mpsrv", "booting {} subprocesses".format(cores))
        for n in range(cores):
            q_pend = mp.Queue(1)
            q_yield = mp.Queue(64)

            proc = mp.Process(target=MpWorker, args=(q_pend, q_yield, args, n))
            # workers loop forever; daemonic so they are terminated when
            # the parent exits instead of being joined indefinitely
            proc.daemon = True
            proc.q_pend = q_pend
            proc.q_yield = q_yield
            proc.nid = n
            proc.clients = {}
            proc.workload = 0

            thr = threading.Thread(target=self.collector, args=(proc,))
            thr.daemon = True
            thr.start()

            self.procs.append(proc)
            proc.start()

        if True:
            thr = threading.Thread(target=self.debug_load_balancer)
            thr.daemon = True
            thr.start()

    def num_clients(self):
        """total number of clients across all workers"""
        with self.mutex:
            return sum(len(x.clients) for x in self.procs)

    def collector(self, proc):
        """handle the messages sent by one worker (log / workload / dropclient)"""
        while True:
            msg = proc.q_yield.get()
            k = msg[0]

            if k == "log":
                self.log(*msg[1:])

            if k == "workload":
                with self.mutex:
                    proc.workload = msg[1]

            if k == "dropclient":
                addr = msg[1]

                with self.mutex:
                    # pop: an unknown address must not kill this thread
                    proc.clients.pop(addr, None)
                    if not proc.clients:
                        proc.workload = 0

                if self.disconnect_func:
                    self.disconnect_func(addr)

    def accept(self, sck, addr):
        """hand the socket to the worker with the lowest workload"""
        # pick and register under the lock, and before the socket is queued,
        # so a fast "dropclient" cannot arrive ahead of the registration
        with self.mutex:
            proc = min(self.procs, key=lambda x: x.workload)
            proc.clients[addr] = 50
            proc.workload += 50

        # can't put unpickled sockets <3.4
        buf = MemesIO()
        ForkingPickler(buf).dump(sck)
        proc.q_pend.put(["socket", buf.getvalue(), addr])

    def debug_load_balancer(self):
        """print client count and workload of every worker, ten times a second"""
        while True:
            msg = ""
            for proc in self.procs:
                msg += "{} \033[36m{}\033[0m   ".format(
                    len(proc.clients), proc.workload
                )

            print(msg)
            time.sleep(0.1)


# ---- copyparty/msgsvc.py ----


class MsgSvc(object):
    """accepts the same messages as a worker queue and forwards them to log_func"""

    def __init__(self, log_func):
        self.log_func = log_func
        print("hi")

    def put(self, msg):
        """dispatch one message; only ["log", src, msg] is understood"""
        if msg[0] == "log":
            return self.log_func(*msg[1:])

        raise Exception("bad msg type: " + str(msg))
# ---- copyparty/tcpsrv.py ----

import calendar
import socket
import threading
import time
from datetime import datetime, timedelta  # kept: this module is star-imported


class TcpSrv(object):
    """
    toplevel component starting everything else,
    tcplistener which forwards clients to httpsrv
    (through mpsrv if platform provides support)
    """

    def __init__(self, args):
        """bind args.i:args.p and serve forever; does not return"""
        # imported here so the module can be loaded without the package context
        from .httpsrv import HttpSrv
        from .mpsrv import MpSrv
        from .msgsvc import MsgSvc

        self.log_mutex = threading.Lock()
        self.msgsvc = MsgSvc(self.log)
        self.next_day = 0  # unix time at which the next date header is due

        bind_ip = args.i
        bind_port = args.p

        ip = "127.0.0.1"
        if bind_ip != ip:
            # connecting an udp socket sends nothing, but makes the OS
            # choose the local address it would use for outgoing traffic
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                s.connect(("10.255.255.255", 1))
                ip = s.getsockname()[0]
            except socket.error:
                # no route; keep advertising the loopback address
                pass
            finally:
                s.close()

        self.log("root", "available @ http://{0}:{1}/".format(ip, bind_port))

        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind((bind_ip, bind_port))
        srv.listen(100)

        self.log("root", "listening @ {0}:{1}".format(bind_ip, bind_port))

        if args.j == 0:
            self.log("root", "multiprocessing disabled")
            httpsrv = HttpSrv(args, self.log)
        else:
            httpsrv = MpSrv(args, self.log)

        while True:
            if httpsrv.num_clients() >= args.nc:
                # at the client limit; leave new connections in the backlog
                time.sleep(0.1)
                continue

            sck, addr = srv.accept()
            httpsrv.accept(sck, addr)

    def log(self, src, msg):
        """
        print one log line "HH:MM:SS src msg" (utc),
        preceded by a date header on the first message of each utc day
        """
        now = time.time()
        tm = time.gmtime(now)

        # the rollover check is inside the lock so that concurrent
        # callers cannot both print the date header
        with self.log_mutex:
            if now >= self.next_day:
                print("\033[36m{}\033[0m".format(time.strftime("%Y-%m-%d", tm)))

                # unix timestamp of next 00:00:00 utc
                self.next_day = calendar.timegm(tuple(tm[:3]) + (0, 0, 0)) + 86400

            ts = time.strftime("%H:%M:%S", tm)
            print("\033[36m{} \033[33m{:21} \033[0m{}".format(ts, src, msg))
# ---- copyparty/util.py ----


class Unrecv(object):
    """
    undo any number of socket recv ops
    """

    def __init__(self, s):
        self.s = s  # the wrapped socket
        self.buf = b""  # bytes which were handed back through unrecv

    def recv(self, nbytes):
        """
        return at most nbytes; pushed-back data is served first
        (and never mixed with fresh socket data in one call).
        a socket error is reported as b"", same as a disconnect
        """
        if self.buf:
            ret = self.buf[:nbytes]
            self.buf = self.buf[nbytes:]
            return ret

        try:
            return self.s.recv(nbytes)
        except EnvironmentError:
            # OSError on py3, parent of socket.error on py2;
            # anything else is a programming error and must surface
            return b""

    def unrecv(self, buf):
        """push buf back so the next recv returns it before older pushed-back data"""
        self.buf = buf + self.buf


# ---- docs/design.txt (carried over from SOURCE) ----
#
# need log interface
#   tcpsrv creates it
#   httpsrv must use interface
#
# msgsvc
#   simulates a multiprocessing queue
#   takes events from httpsrv
#     logging
#   mpsrv pops queue and forwards to this
#
# tcpsrv
#   tcp listener
#     pass tcp clients to worker
#   api to get status messages from workers
#
# mpsrv
#   uses multiprocessing to handle incoming clients
#
# httpsrv
#   takes client sockets, starts threads
#   takes argv acc/vol through init args
#   loads acc/vol from config file
#
# ---- docs/notes.sh (carried over from SOURCE; para() is truncated there) ----
#
# ##
# ## prep debug env (vscode embedded terminal)
#
# renice 20 -p $$
#
# ##
# ## testing multiple parallel uploads
# ## usage: para | tee log
#
# para() { for s in 1 2 3 4 5 6 7 8 12 16 24 32 48 64; do echo $s; for r in {1..5}; do for ((n=0;n&1 & done; wait; echo; done; done; }
#
# ##
# ## display average speed
# ## usage: avg logfile
#
# avg() { awk 'function pr(ncsz) {if (nsmp>0) {printf "%3s %s\n", csz, sum/nsmp} csz=$1;sum=0;nsmp=0} {sub(/\r$/,"")} /^[0-9]+$/ {pr($1);next} / MiB/ {sub(/ MiB.*/,"");sub(/.* /,"");sum+=$1;nsmp++} END {pr(0)}' "$1"; }