diff --git a/.vscode/launch.json b/.vscode/launch.json index d3704cf1..f31f411a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -14,6 +14,10 @@ "-nc", "4", "-nw", + "-a", + "ed:wark", + "-v", + "/home/ed/inc:inc:r:aed" ] }, { diff --git a/.vscode/settings.json b/.vscode/settings.json index 4e813845..53c2d963 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -34,4 +34,5 @@ ], "python.linting.pylintEnabled": true, "python.linting.enabled": true, + "python.pythonPath": "/usr/bin/python3", } \ No newline at end of file diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py index 0c61ff00..b3b6f892 100644 --- a/copyparty/authsrv.py +++ b/copyparty/authsrv.py @@ -116,6 +116,18 @@ class VFS(object): return [absreal, virt_vis] + def user_tree(self, uname, readable=False, writable=False): + ret = [] + opt1 = readable and uname in self.uread + opt2 = writable and uname in self.uwrite + if opt1 or opt2: + ret.append(self.vpath) + + for _, vn in sorted(self.nodes.items()): + ret.extend(vn.user_tree(uname, readable, writable)) + + return ret + class AuthSrv(object): """verifies users against given paths""" diff --git a/copyparty/httpcli.py b/copyparty/httpcli.py index 97cd0199..e97a38c1 100644 --- a/copyparty/httpcli.py +++ b/copyparty/httpcli.py @@ -5,19 +5,24 @@ from __future__ import print_function import time import hashlib import mimetypes +import jinja2 from .__init__ import * from .util import * -if not PY2: +if PY2: + from cStringIO import StringIO as BytesIO +else: unicode = str + from io import BytesIO as BytesIO class HttpCli(object): - def __init__(self, sck, addr, args, log_func): + def __init__(self, sck, addr, args, auth, log_func): self.s = sck self.addr = addr self.args = args + self.auth = auth self.sr = Unrecv(sck) self.bufsz = 1024 * 32 @@ -27,13 +32,21 @@ class HttpCli(object): self.log_func = log_func self.log_src = "{} \033[36m{}".format(addr[0], addr[1]).ljust(26) + with open(self.respath("splash.html"), "rb") as f: + self.tpl_mounts = jinja2.Template(f.read().decode("utf-8")) + + def respath(self, res_name): + return os.path.join(E.mod, "web", res_name) + def log(self, msg): self.log_func(self.log_src, msg) def run(self): while self.ok: - headerlines = self.read_header() - if not self.ok: + try: + headerlines = read_header(self.sr) + except: + self.ok = False return self.headers = {} @@ -48,7 +61,27 @@ class HttpCli(object): k, v = header_line.split(":", 1) self.headers[k.lower()] = v.strip() - # self.bufsz = int(self.req.split('/')[-1]) * 1024 + self.uname = "*" + if "cookie" in self.headers: + cookies = self.headers["cookie"].split(";") + for k, v in [x.split("=", 1) for x in cookies]: + if k != "cppwd": + continue + + v = unescape_cookie(v) + if not v in self.auth.iuser: + msg = u'bad_cpwd "{}"'.format(v) + nuke = u"Set-Cookie: cppwd=x; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT" + self.loud_reply(msg, headers=[nuke]) + return + + self.uname = self.auth.iuser[v] + + if self.uname: + self.rvol = self.auth.vfs.user_tree(self.uname, readable=True) + self.wvol = self.auth.vfs.user_tree(self.uname, writable=True) + print(self.rvol) + print(self.wvol) if mode == "GET": self.handle_get() @@ -62,132 +95,96 @@ class HttpCli(object): self.ok = False self.s.close() - def read_header(self): - ret = b"" - while True: - if ret.endswith(b"\r\n\r\n"): - break - elif ret.endswith(b"\r\n\r"): - n = 1 - elif ret.endswith(b"\r\n"): - n = 2 - elif ret.endswith(b"\r"): - n = 3 - else: - n = 4 - - buf = self.sr.recv(n) - if not buf: - self.panic("headers") - break - - ret += buf - - return ret[:-4].decode("utf-8", "replace").split("\r\n") - - def reply(self, body, status="200 OK", mime="text/html"): - header = "HTTP/1.1 {}\r\nConnection: Keep-Alive\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n".format( - status, mime, len(body) - ).encode( - "utf-8" - ) + def reply(self, body, status="200 OK", mime="text/html", headers=[]): + # TODO something to reply with user-supplied values safely + response = [ + u"HTTP/1.1 " + status, + u"Connection: Keep-Alive", + u"Content-Type: " + mime, + u"Content-Length: " + str(len(body)), + ] + response.extend(headers) + response_str = u"\r\n".join(response).encode("utf-8") if self.ok: - self.s.send(header + body) + self.s.send(response_str + b"\r\n\r\n" + body) return body - def loud_reply(self, body, **kwargs): + def loud_reply(self, body, *args, **kwargs): self.log(body.rstrip()) - self.reply(b"
" + body.encode("utf-8"), **kwargs)
-
-    def send_file(self, path):
-        sz = os.path.getsize(path)
-        mime = mimetypes.guess_type(path)[0]
-        header = "HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n".format(
-            mime, sz
-        ).encode(
-            "utf-8"
-        )
-
-        if self.ok:
-            self.s.send(header)
-
-        with open(path, "rb") as f:
-            while self.ok:
-                buf = f.read(4096)
-                if not buf:
-                    break
-
-                self.s.send(buf)
+        self.reply(b"
" + body.encode("utf-8"), *list(args), **kwargs)
 
     def handle_get(self):
         self.log("")
         self.log("GET  {0} {1}".format(self.addr[0], self.req))
 
-        static_path = os.path.join(E.mod, "web", self.req.split("?")[0][1:])
+        if self.req.startswith("/.cpr/"):
+            static_path = os.path.join(E.mod, "web", self.req.split("?")[0][6:])
 
-        if os.path.isfile(static_path):
-            return self.send_file(static_path)
+            if os.path.isfile(static_path):
+                return self.tx_file(static_path)
 
         if self.req == "/":
-            return self.send_file(os.path.join(E.mod, "web/splash.html"))
+            return self.tx_mounts()
 
-        return self.loud_reply("404 not found", status="404 Not Found")
+        return self.loud_reply("404 not found", "404 Not Found")
 
     def handle_post(self):
         self.log("")
         self.log("POST {0} {1}".format(self.addr[0], self.req))
 
-        nullwrite = self.args.nw
-
         try:
             if self.headers["expect"].lower() == "100-continue":
                 self.s.send(b"HTTP/1.1 100 Continue\r\n\r\n")
         except:
             pass
 
-        form_segm = self.read_header()
-        if not self.ok:
+        self.parser = MultipartParser(self.log, self.sr, self.headers)
+        self.parser.parse()
+
+        act = self.parser.require("act", 64)
+
+        if act == u"bput":
+            self.handle_plain_upload()
             return
 
-        self.boundary = b"\r\n" + form_segm[0].encode("utf-8")
-        for ln in form_segm[1:]:
-            self.log(ln)
+        if act == u"login":
+            self.handle_login()
+            return
 
-        fn = os.devnull
-        fn0 = "inc.{0:.6f}".format(time.time())
+        raise Pebkac('invalid action "{}"'.format(act))
+
+    def handle_login(self):
+        pwd = self.parser.require("cppwd", 64)
+        if not pwd in self.auth.iuser:
+            h = [u"Set-Cookie: cppwd=x; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT"]
+            self.loud_reply(u'bad_ppwd "{}"'.format(pwd), headers=h)
+        else:
+            h = ["Set-Cookie: cppwd={}; Path=/".format(pwd)]
+            self.loud_reply(u"login_ok", headers=h)
+
+    def handle_plain_upload(self):
+        nullwrite = self.args.nw
 
         files = []
         t0 = time.time()
-        for nfile in range(99):
+        for nfile, (p_field, p_file, p_data) in enumerate(self.parser.gen):
+            fn = os.devnull
             if not nullwrite:
-                fn = "{0}.{1}".format(fn0, nfile)
+                fn = sanitize_fn(p_file)
+                # TODO broker which avoid this race
+                # and provides a new filename if taken
+                if os.path.exists(fn):
+                    fn += ".{:.6f}".format(time.time())
 
             with open(fn, "wb") as f:
                 self.log("writing to {0}".format(fn))
-                sz, sha512 = self.handle_multipart(f)
+                sz, sha512 = hashcopy(self, p_data, f)
                 if sz == 0:
                     break
 
                 files.append([sz, sha512])
 
-            buf = self.sr.recv(2)
-
-            if buf == b"--":
-                # end of multipart
-                break
-
-            if buf != b"\r\n":
-                return self.loud_reply(u"protocol error")
-
-            header = self.read_header()
-            if not self.ok:
-                break
-
-            form_segm += header
-            for ln in header:
-                self.log(ln)
-
         td = time.time() - t0
         sz_total = sum(x[0] for x in files)
         spd = (sz_total / td) / (1024 * 1024)
@@ -206,14 +203,15 @@ class HttpCli(object):
         self.loud_reply(msg)
 
         if not nullwrite:
-            with open(fn0 + ".txt", "wb") as f:
+            # TODO this is bad
+            log_fn = "up.{:.6f}.txt".format(t0)
+            with open(log_fn, "wb") as f:
                 f.write(
                     (
                         u"\n".join(
                             unicode(x)
                             for x in [
                                 u":".join(unicode(x) for x in self.addr),
-                                u"\n".join(form_segm),
                                 msg.rstrip(),
                             ]
                         )
@@ -221,77 +219,26 @@ class HttpCli(object):
                     ).encode("utf-8")
                 )
 
-        try:
-            # TODO: check if actually part of multipart footer
-            buf = self.sr.recv(2)
-            if buf != b"\r\n":
-                raise Exception("oh")
-        except:
-            self.log("client is done")
-            self.s.close()
+    def tx_file(self, path):
+        sz = os.path.getsize(path)
+        mime = mimetypes.guess_type(path)[0]
+        header = "HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n".format(
+            mime, sz
+        ).encode(
+            "utf-8"
+        )
 
-    def handle_multipart(self, ofd):
-        tlen = 0
-        hashobj = hashlib.sha512()
-        for buf in self.extract_filedata():
-            tlen += len(buf)
-            hashobj.update(buf)
-            ofd.write(buf)
+        if self.ok:
+            self.s.send(header)
 
-        return tlen, hashobj.hexdigest()
-
-    def extract_filedata(self):
-        u32_lim = int((2 ** 31) * 0.9)
-        blen = len(self.boundary)
-        bufsz = self.bufsz
-        while True:
-            if self.workload > u32_lim:
-                # reset to prevent overflow
-                self.workload = 100
-
-            buf = self.sr.recv(bufsz)
-            self.workload += 1
-            if not buf:
-                # abort: client disconnected
-                self.panic("outer")
-                return
-
-            while True:
-                ofs = buf.find(self.boundary)
-                if ofs != -1:
-                    self.sr.unrecv(buf[ofs + blen :])
-                    yield buf[:ofs]
-                    return
-
-                d = len(buf) - blen
-                if d > 0:
-                    # buffer growing large; yield everything except
-                    # the part at the end (maybe start of boundary)
-                    yield buf[:d]
-                    buf = buf[d:]
-
-                # look for boundary near the end of the buffer
-                for n in range(1, len(buf) + 1):
-                    if not buf[-n:] in self.boundary:
-                        n -= 1
-                        break
-
-                if n == 0 or not self.boundary.startswith(buf[-n:]):
-                    # no boundary contents near the buffer edge
+        with open(path, "rb") as f:
+            while self.ok:
+                buf = f.read(4096)
+                if not buf:
                     break
 
-                if blen == n:
-                    # EOF: found boundary
-                    yield buf[:-n]
-                    return
+                self.s.send(buf)
 
-                buf2 = self.sr.recv(bufsz)
-                self.workload += 1
-                if not buf2:
-                    # abort: client disconnected
-                    self.panic("inner")
-                    return
-
-                buf += buf2
-
-            yield buf
+    def tx_mounts(self):
+        html = self.tpl_mounts.render(this=self)
+        self.reply(html.encode("utf-8"))
diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py
index 7196d26c..7c489543 100644
--- a/copyparty/httpsrv.py
+++ b/copyparty/httpsrv.py
@@ -42,7 +42,9 @@ class HttpSrv(object):
     def thr_client(self, sck, addr, log):
         """thread managing one tcp client"""
         try:
-            cli = HttpCli(sck, addr, self.args, log)
+            # TODO HttpConn between HttpSrv and HttpCli
+            # to ensure no state is kept between http requests
+            cli = HttpCli(sck, addr, self.args, self.auth, log)
             with self.mutex:
                 self.clients[cli] = 0
                 self.workload += 50
diff --git a/copyparty/web/splash.css b/copyparty/web/splash.css
index dd48bc2f..ec30261a 100644
--- a/copyparty/web/splash.css
+++ b/copyparty/web/splash.css
@@ -7,9 +7,21 @@ html, body, #wrap {
 	max-width: 40em;
 	margin: 2em auto;
 	padding: 0 1em 3em 1em;
+	line-height: 1.3em;
 }
 h1 {
 	border-bottom: 1px solid #ccc;
 	margin: 2em 0 .4em 0;
 	padding: 0 0 .2em 0;
+}
+li {
+	margin: 1em 0;
+}
+a {
+	color: #047;
+	background: #eee;
+	background: linear-gradient(to bottom, #eee, #ddd 49%, #ccc 50%, #eee);
+	border-bottom: 1px solid #aaa;
+	border-radius: .2em;
+	padding: .2em .5em;
 }
\ No newline at end of file
diff --git a/copyparty/web/splash.html b/copyparty/web/splash.html
index 8c247f93..58f51176 100644
--- a/copyparty/web/splash.html
+++ b/copyparty/web/splash.html
@@ -6,14 +6,44 @@
     copyparty
     
     
-    
+    
 
 
 
     
-

hello world

+

hello {{ this.uname }}

+ +

you can browse these:

+
    + {% for mp in this.rvol %} +
  • /{{ mp }}
  • + {% endfor %} +
+ +

you can upload to:

+
    + {% for mp in this.wvol %} +
  • /{{ mp }}
  • + {% endfor %} +
+ +

login for more:

+
    +
    + + + +
    +
+ +

[debug] fallback upload

+
+ +
+
+
- + \ No newline at end of file diff --git a/copyparty/web/splash.js b/copyparty/web/splash.js index 0f4dddb0..e69de29b 100644 --- a/copyparty/web/splash.js +++ b/copyparty/web/splash.js @@ -1 +0,0 @@ -document.getElementsByTagName('h1')[0].insertAdjacentHTML('afterend', '')