From f5b1a2065e421da46fb58d374c18e0accd19afb6 Mon Sep 17 00:00:00 2001 From: ed Date: Wed, 8 Sep 2021 21:07:34 +0000 Subject: [PATCH] multipart-parser needs exact reads --- copyparty/util.py | 85 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 76 insertions(+), 9 deletions(-) diff --git a/copyparty/util.py b/copyparty/util.py index 1990d816..ea2ad0cb 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -169,7 +169,7 @@ class Cooldown(object): return ret -class Unrecv(object): +class _Unrecv(object): """ undo any number of socket recv ops """ @@ -189,10 +189,68 @@ class Unrecv(object): except: return b"" + def recv_ex(self, nbytes): + """read an exact number of bytes""" + ret = self.recv(nbytes) + while len(ret) < nbytes: + buf = self.recv(nbytes - len(ret)) + if not buf: + break + + ret += buf + + return ret + def unrecv(self, buf): self.buf = buf + self.buf +class _LUnrecv(object): + """ + with expensive debug logging + """ + + def __init__(self, s): + self.s = s + self.buf = b"" + + def recv(self, nbytes): + if self.buf: + ret = self.buf[:nbytes] + self.buf = self.buf[nbytes:] + m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n" + print(m.format(ret, self.buf), end="") + return ret + + try: + ret = self.s.recv(nbytes) + m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m\n" + print(m.format(ret), end="") + return ret + except: + return b"" + + def recv_ex(self, nbytes): + """read an exact number of bytes""" + ret = self.recv(nbytes) + while len(ret) < nbytes: + buf = self.recv(nbytes - len(ret)) + if not buf: + break + + ret += buf + + return ret + + def unrecv(self, buf): + self.buf = buf + self.buf + m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n" + print(m.format(buf, self.buf), end="") + + +Unrecv = _Unrecv + + class ProgressPrinter(threading.Thread): """ periodically print progress info without linefeeds @@ -587,19 +645,21 @@ class MultipartParser(object): yields [fieldname, unsanitized_filename, fieldvalue] where fieldvalue yields chunks of data """ - while True: + run = True + while run: fieldname, filename = self._read_header() yield [fieldname, filename, self._read_data()] - tail = self.sr.recv(2) + tail = self.sr.recv_ex(2) if tail == b"--": # EOF indicated by this immediately after final boundary - self.sr.recv(2) - return + tail = self.sr.recv_ex(2) + run = False if tail != b"\r\n": - raise Pebkac(400, "protocol error after field value") + m = "protocol error after field value: want b'\\r\\n', got {!r}" + raise Pebkac(400, m.format(tail)) def _read_value(self, iterator, max_len): ret = b"" @@ -985,8 +1045,12 @@ def read_socket_chunked(sr, log=None): raise Pebkac(400, err) if chunklen == 0: - sr.recv(2) # \r\n after final chunk - return + x = sr.recv_ex(2) + if x == b"\r\n": + return + + m = "protocol error after final chunk: want b'\\r\\n', got {!r}" + raise Pebkac(400, m.format(x)) if log: log("receiving {} byte chunk".format(chunklen)) @@ -994,7 +1058,10 @@ def read_socket_chunked(sr, log=None): for chunk in read_socket(sr, chunklen): yield chunk - sr.recv(2) # \r\n after each chunk too + x = sr.recv_ex(2) + if x != b"\r\n": + m = "protocol error in chunk separator: want b'\\r\\n', got {!r}" + raise Pebkac(400, m.format(x)) def yieldfile(fn):