multipart-parser needs exact reads

This commit is contained in:
ed 2021-09-08 21:07:34 +00:00
parent 5e62532295
commit f5b1a2065e

View file

@ -169,7 +169,7 @@ class Cooldown(object):
return ret
class Unrecv(object):
class _Unrecv(object):
"""
undo any number of socket recv ops
"""
@ -189,10 +189,68 @@ class Unrecv(object):
except:
return b""
def recv_ex(self, nbytes):
"""read an exact number of bytes"""
ret = self.recv(nbytes)
while len(ret) < nbytes:
buf = self.recv(nbytes - len(ret))
if not buf:
break
ret += buf
return ret
def unrecv(self, buf):
self.buf = buf + self.buf
class _LUnrecv(object):
"""
with expensive debug logging
"""
def __init__(self, s):
self.s = s
self.buf = b""
def recv(self, nbytes):
if self.buf:
ret = self.buf[:nbytes]
self.buf = self.buf[nbytes:]
m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n"
print(m.format(ret, self.buf), end="")
return ret
try:
ret = self.s.recv(nbytes)
m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m\n"
print(m.format(ret), end="")
return ret
except:
return b""
def recv_ex(self, nbytes):
"""read an exact number of bytes"""
ret = self.recv(nbytes)
while len(ret) < nbytes:
buf = self.recv(nbytes - len(ret))
if not buf:
break
ret += buf
return ret
def unrecv(self, buf):
self.buf = buf + self.buf
m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n"
print(m.format(buf, self.buf), end="")
Unrecv = _Unrecv
class ProgressPrinter(threading.Thread):
"""
periodically print progress info without linefeeds
@ -587,19 +645,21 @@ class MultipartParser(object):
yields [fieldname, unsanitized_filename, fieldvalue]
where fieldvalue yields chunks of data
"""
while True:
run = True
while run:
fieldname, filename = self._read_header()
yield [fieldname, filename, self._read_data()]
tail = self.sr.recv(2)
tail = self.sr.recv_ex(2)
if tail == b"--":
# EOF indicated by this immediately after final boundary
self.sr.recv(2)
return
tail = self.sr.recv_ex(2)
run = False
if tail != b"\r\n":
raise Pebkac(400, "protocol error after field value")
m = "protocol error after field value: want b'\\r\\n', got {!r}"
raise Pebkac(400, m.format(tail))
def _read_value(self, iterator, max_len):
ret = b""
@ -985,8 +1045,12 @@ def read_socket_chunked(sr, log=None):
raise Pebkac(400, err)
if chunklen == 0:
sr.recv(2) # \r\n after final chunk
return
x = sr.recv_ex(2)
if x == b"\r\n":
return
m = "protocol error after final chunk: want b'\\r\\n', got {!r}"
raise Pebkac(400, m.format(x))
if log:
log("receiving {} byte chunk".format(chunklen))
@ -994,7 +1058,10 @@ def read_socket_chunked(sr, log=None):
for chunk in read_socket(sr, chunklen):
yield chunk
sr.recv(2) # \r\n after each chunk too
x = sr.recv_ex(2)
if x != b"\r\n":
m = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
raise Pebkac(400, m.format(x))
def yieldfile(fn):