mirror of
https://github.com/9001/copyparty.git
synced 2025-08-17 09:02:15 -06:00
multipart-parser needs exact reads
This commit is contained in:
parent
5e62532295
commit
f5b1a2065e
|
@ -169,7 +169,7 @@ class Cooldown(object):
|
|||
return ret
|
||||
|
||||
|
||||
class Unrecv(object):
|
||||
class _Unrecv(object):
|
||||
"""
|
||||
undo any number of socket recv ops
|
||||
"""
|
||||
|
@ -189,10 +189,68 @@ class Unrecv(object):
|
|||
except:
|
||||
return b""
|
||||
|
||||
def recv_ex(self, nbytes):
|
||||
"""read an exact number of bytes"""
|
||||
ret = self.recv(nbytes)
|
||||
while len(ret) < nbytes:
|
||||
buf = self.recv(nbytes - len(ret))
|
||||
if not buf:
|
||||
break
|
||||
|
||||
ret += buf
|
||||
|
||||
return ret
|
||||
|
||||
def unrecv(self, buf):
|
||||
self.buf = buf + self.buf
|
||||
|
||||
|
||||
class _LUnrecv(object):
|
||||
"""
|
||||
with expensive debug logging
|
||||
"""
|
||||
|
||||
def __init__(self, s):
|
||||
self.s = s
|
||||
self.buf = b""
|
||||
|
||||
def recv(self, nbytes):
|
||||
if self.buf:
|
||||
ret = self.buf[:nbytes]
|
||||
self.buf = self.buf[nbytes:]
|
||||
m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n"
|
||||
print(m.format(ret, self.buf), end="")
|
||||
return ret
|
||||
|
||||
try:
|
||||
ret = self.s.recv(nbytes)
|
||||
m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m\n"
|
||||
print(m.format(ret), end="")
|
||||
return ret
|
||||
except:
|
||||
return b""
|
||||
|
||||
def recv_ex(self, nbytes):
|
||||
"""read an exact number of bytes"""
|
||||
ret = self.recv(nbytes)
|
||||
while len(ret) < nbytes:
|
||||
buf = self.recv(nbytes - len(ret))
|
||||
if not buf:
|
||||
break
|
||||
|
||||
ret += buf
|
||||
|
||||
return ret
|
||||
|
||||
def unrecv(self, buf):
|
||||
self.buf = buf + self.buf
|
||||
m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m\n"
|
||||
print(m.format(buf, self.buf), end="")
|
||||
|
||||
|
||||
Unrecv = _Unrecv
|
||||
|
||||
|
||||
class ProgressPrinter(threading.Thread):
|
||||
"""
|
||||
periodically print progress info without linefeeds
|
||||
|
@ -587,19 +645,21 @@ class MultipartParser(object):
|
|||
yields [fieldname, unsanitized_filename, fieldvalue]
|
||||
where fieldvalue yields chunks of data
|
||||
"""
|
||||
while True:
|
||||
run = True
|
||||
while run:
|
||||
fieldname, filename = self._read_header()
|
||||
yield [fieldname, filename, self._read_data()]
|
||||
|
||||
tail = self.sr.recv(2)
|
||||
tail = self.sr.recv_ex(2)
|
||||
|
||||
if tail == b"--":
|
||||
# EOF indicated by this immediately after final boundary
|
||||
self.sr.recv(2)
|
||||
return
|
||||
tail = self.sr.recv_ex(2)
|
||||
run = False
|
||||
|
||||
if tail != b"\r\n":
|
||||
raise Pebkac(400, "protocol error after field value")
|
||||
m = "protocol error after field value: want b'\\r\\n', got {!r}"
|
||||
raise Pebkac(400, m.format(tail))
|
||||
|
||||
def _read_value(self, iterator, max_len):
|
||||
ret = b""
|
||||
|
@ -985,8 +1045,12 @@ def read_socket_chunked(sr, log=None):
|
|||
raise Pebkac(400, err)
|
||||
|
||||
if chunklen == 0:
|
||||
sr.recv(2) # \r\n after final chunk
|
||||
return
|
||||
x = sr.recv_ex(2)
|
||||
if x == b"\r\n":
|
||||
return
|
||||
|
||||
m = "protocol error after final chunk: want b'\\r\\n', got {!r}"
|
||||
raise Pebkac(400, m.format(x))
|
||||
|
||||
if log:
|
||||
log("receiving {} byte chunk".format(chunklen))
|
||||
|
@ -994,7 +1058,10 @@ def read_socket_chunked(sr, log=None):
|
|||
for chunk in read_socket(sr, chunklen):
|
||||
yield chunk
|
||||
|
||||
sr.recv(2) # \r\n after each chunk too
|
||||
x = sr.recv_ex(2)
|
||||
if x != b"\r\n":
|
||||
m = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
|
||||
raise Pebkac(400, m.format(x))
|
||||
|
||||
|
||||
def yieldfile(fn):
|
||||
|
|
Loading…
Reference in a new issue