multipart-parser needs exact reads

This commit is contained in:
ed 2021-09-08 21:07:34 +00:00
parent 5e62532295
commit f5b1a2065e

View file

@ -169,7 +169,7 @@ class Cooldown(object):
return ret return ret
class Unrecv(object): class _Unrecv(object):
""" """
undo any number of socket recv ops undo any number of socket recv ops
""" """
@ -189,10 +189,68 @@ class Unrecv(object):
except: except:
return b"" return b""
def recv_ex(self, nbytes):
"""read an exact number of bytes"""
ret = self.recv(nbytes)
while len(ret) < nbytes:
buf = self.recv(nbytes - len(ret))
if not buf:
break
ret += buf
return ret
def unrecv(self, buf): def unrecv(self, buf):
self.buf = buf + self.buf self.buf = buf + self.buf
class _LUnrecv(object):
"""
with expensive debug logging
"""
def __init__(self, s):
self.s = s
self.buf = b""
def recv(self, nbytes):
if self.buf:
ret = self.buf[:nbytes]
self.buf = self.buf[nbytes:]
m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n"
print(m.format(ret, self.buf), end="")
return ret
try:
ret = self.s.recv(nbytes)
m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m\n"
print(m.format(ret), end="")
return ret
except:
return b""
def recv_ex(self, nbytes):
"""read an exact number of bytes"""
ret = self.recv(nbytes)
while len(ret) < nbytes:
buf = self.recv(nbytes - len(ret))
if not buf:
break
ret += buf
return ret
def unrecv(self, buf):
self.buf = buf + self.buf
m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n"
print(m.format(buf, self.buf), end="")
Unrecv = _Unrecv
class ProgressPrinter(threading.Thread): class ProgressPrinter(threading.Thread):
""" """
periodically print progress info without linefeeds periodically print progress info without linefeeds
@ -587,19 +645,21 @@ class MultipartParser(object):
yields [fieldname, unsanitized_filename, fieldvalue] yields [fieldname, unsanitized_filename, fieldvalue]
where fieldvalue yields chunks of data where fieldvalue yields chunks of data
""" """
while True: run = True
while run:
fieldname, filename = self._read_header() fieldname, filename = self._read_header()
yield [fieldname, filename, self._read_data()] yield [fieldname, filename, self._read_data()]
tail = self.sr.recv(2) tail = self.sr.recv_ex(2)
if tail == b"--": if tail == b"--":
# EOF indicated by this immediately after final boundary # EOF indicated by this immediately after final boundary
self.sr.recv(2) tail = self.sr.recv_ex(2)
return run = False
if tail != b"\r\n": if tail != b"\r\n":
raise Pebkac(400, "protocol error after field value") m = "protocol error after field value: want b'\\r\\n', got {!r}"
raise Pebkac(400, m.format(tail))
def _read_value(self, iterator, max_len): def _read_value(self, iterator, max_len):
ret = b"" ret = b""
@ -985,8 +1045,12 @@ def read_socket_chunked(sr, log=None):
raise Pebkac(400, err) raise Pebkac(400, err)
if chunklen == 0: if chunklen == 0:
sr.recv(2) # \r\n after final chunk x = sr.recv_ex(2)
return if x == b"\r\n":
return
m = "protocol error after final chunk: want b'\\r\\n', got {!r}"
raise Pebkac(400, m.format(x))
if log: if log:
log("receiving {} byte chunk".format(chunklen)) log("receiving {} byte chunk".format(chunklen))
@ -994,7 +1058,10 @@ def read_socket_chunked(sr, log=None):
for chunk in read_socket(sr, chunklen): for chunk in read_socket(sr, chunklen):
yield chunk yield chunk
sr.recv(2) # \r\n after each chunk too x = sr.recv_ex(2)
if x != b"\r\n":
m = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
raise Pebkac(400, m.format(x))
def yieldfile(fn): def yieldfile(fn):