fix write-only volumes + add regression test

This commit is contained in:
ed 2021-04-24 02:48:41 +02:00
parent 127ec10c0d
commit ed7727f7cb
6 changed files with 396 additions and 118 deletions

View file

@ -182,10 +182,8 @@ class HttpCli(object):
self.out_headers.update(headers) self.out_headers.update(headers)
# default to utf8 html if no content-type is set # default to utf8 html if no content-type is set
try: if not mime:
mime = mime or self.out_headers["Content-Type"] mime = self.out_headers.get("Content-Type", "text/html; charset=UTF-8")
except KeyError:
mime = "text/html; charset=UTF-8"
self.out_headers["Content-Type"] = mime self.out_headers["Content-Type"] = mime
@ -536,7 +534,7 @@ class HttpCli(object):
self.log("qj: " + repr(vbody)) self.log("qj: " + repr(vbody))
hits = idx.fsearch(vols, body) hits = idx.fsearch(vols, body)
msg = repr(hits) msg = repr(hits)
taglist = [] taglist = {}
else: else:
# search by query params # search by query params
self.log("qj: " + repr(body)) self.log("qj: " + repr(body))
@ -1319,6 +1317,74 @@ class HttpCli(object):
# print(abspath) # print(abspath)
raise Pebkac(404) raise Pebkac(404)
srv_info = []
try:
if not self.args.nih:
srv_info.append(unicode(socket.gethostname()).split(".")[0])
except:
self.log("#wow #whoa")
try:
# some fuses misbehave
if not self.args.nid:
if WINDOWS:
bfree = ctypes.c_ulonglong(0)
ctypes.windll.kernel32.GetDiskFreeSpaceExW(
ctypes.c_wchar_p(abspath), None, None, ctypes.pointer(bfree)
)
srv_info.append(humansize(bfree.value) + " free")
else:
sv = os.statvfs(abspath)
free = humansize(sv.f_frsize * sv.f_bfree, True)
total = humansize(sv.f_frsize * sv.f_blocks, True)
srv_info.append(free + " free")
srv_info.append(total)
except:
pass
srv_info = "</span> /// <span>".join(srv_info)
perms = []
if self.readable:
perms.append("read")
if self.writable:
perms.append("write")
url_suf = self.urlq()
is_ls = "ls" in self.uparam
ts = "" # "?{}".format(time.time())
tpl = "browser"
if "b" in self.uparam:
tpl = "browser2"
j2a = {
"vdir": quotep(self.vpath),
"vpnodes": vpnodes,
"files": [],
"ts": ts,
"perms": json.dumps(perms),
"taglist": [],
"tag_order": [],
"have_up2k_idx": ("e2d" in vn.flags),
"have_tags_idx": ("e2t" in vn.flags),
"have_zip": (not self.args.no_zip),
"have_b_u": (self.writable and self.uparam.get("b") == "u"),
"url_suf": url_suf,
"logues": ["", ""],
"title": html_escape(self.vpath, crlf=True),
"srv_info": srv_info,
}
if not self.readable:
if is_ls:
raise Pebkac(403)
html = self.j2(tpl, **j2a)
self.reply(html.encode("utf-8", "replace"))
return True
if not os.path.isdir(fsenc(abspath)): if not os.path.isdir(fsenc(abspath)):
if abspath.endswith(".md") and "raw" not in self.uparam: if abspath.endswith(".md") and "raw" not in self.uparam:
return self.tx_md(abspath) return self.tx_md(abspath)
@ -1362,15 +1428,11 @@ class HttpCli(object):
if rem == ".hist": if rem == ".hist":
hidden = ["up2k."] hidden = ["up2k."]
is_ls = "ls" in self.uparam
icur = None icur = None
if "e2t" in vn.flags: if "e2t" in vn.flags:
idx = self.conn.get_u2idx() idx = self.conn.get_u2idx()
icur = idx.get_cur(vn.realpath) icur = idx.get_cur(vn.realpath)
url_suf = self.urlq()
dirs = [] dirs = []
files = [] files = []
for fn in vfs_ls: for fn in vfs_ls:
@ -1461,42 +1523,6 @@ class HttpCli(object):
for f in dirs: for f in dirs:
f["tags"] = {} f["tags"] = {}
srv_info = []
try:
if not self.args.nih:
srv_info.append(unicode(socket.gethostname()).split(".")[0])
except:
self.log("#wow #whoa")
pass
try:
# some fuses misbehave
if not self.args.nid:
if WINDOWS:
bfree = ctypes.c_ulonglong(0)
ctypes.windll.kernel32.GetDiskFreeSpaceExW(
ctypes.c_wchar_p(abspath), None, None, ctypes.pointer(bfree)
)
srv_info.append(humansize(bfree.value) + " free")
else:
sv = os.statvfs(abspath)
free = humansize(sv.f_frsize * sv.f_bfree, True)
total = humansize(sv.f_frsize * sv.f_blocks, True)
srv_info.append(free + " free")
srv_info.append(total)
except:
pass
srv_info = "</span> /// <span>".join(srv_info)
perms = []
if self.readable:
perms.append("read")
if self.writable:
perms.append("write")
logues = ["", ""] logues = ["", ""]
for n, fn in enumerate([".prologue.html", ".epilogue.html"]): for n, fn in enumerate([".prologue.html", ".epilogue.html"]):
fn = os.path.join(abspath, fn) fn = os.path.join(abspath, fn)
@ -1518,34 +1544,12 @@ class HttpCli(object):
self.reply(ret.encode("utf-8", "replace"), mime="application/json") self.reply(ret.encode("utf-8", "replace"), mime="application/json")
return True return True
ts = "" j2a["files"] = dirs + files
# ts = "?{}".format(time.time()) j2a["logues"] = logues
j2a["taglist"] = taglist
if "mte" in vn.flags:
j2a["tag_order"] = json.dumps(vn.flags["mte"].split(","))
dirs.extend(files) html = self.j2(tpl, **j2a)
tpl = "browser"
if "b" in self.uparam:
tpl = "browser2"
html = self.j2(
tpl,
vdir=quotep(self.vpath),
vpnodes=vpnodes,
files=dirs,
ts=ts,
perms=json.dumps(perms),
taglist=taglist,
tag_order=json.dumps(
vn.flags["mte"].split(",") if "mte" in vn.flags else []
),
have_up2k_idx=("e2d" in vn.flags),
have_tags_idx=("e2t" in vn.flags),
have_zip=(not self.args.no_zip),
have_b_u=(self.writable and self.uparam.get("b") == "u"),
url_suf=url_suf,
logues=logues,
title=html_escape(self.vpath, crlf=True),
srv_info=srv_info,
)
self.reply(html.encode("utf-8", "replace")) self.reply(html.encode("utf-8", "replace"))
return True return True

0
tests/__init__.py Normal file
View file

33
tests/run.py Executable file
View file

@ -0,0 +1,33 @@
#!/usr/bin/env python3
import sys
import runpy
host = sys.argv[1]
sys.argv = sys.argv[:1] + sys.argv[2:]
sys.path.insert(0, ".")
def rp():
runpy.run_module("unittest", run_name="__main__")
if host == "vmprof":
rp()
elif host == "cprofile":
import cProfile
import pstats
log_fn = "cprofile.log"
cProfile.run("rp()", log_fn)
p = pstats.Stats(log_fn)
p.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(64)
"""
python3.9 tests/run.py cprofile -v tests/test_httpcli.py
python3.9 -m pip install --user vmprof
python3.9 -m vmprof --lines -o vmprof.log tests/run.py vmprof -v tests/test_httpcli.py
"""

192
tests/test_httpcli.py Normal file
View file

@ -0,0 +1,192 @@
#!/usr/bin/env python
# coding: utf-8
from __future__ import print_function, unicode_literals
import io
import os
import time
import shutil
import pprint
import tarfile
import unittest
from argparse import Namespace
from copyparty.authsrv import AuthSrv
from copyparty.httpcli import HttpCli
from tests import util as tu
def hdr(query):
h = "GET /{} HTTP/1.1\r\nCookie: cppwd=o\r\nConnection: close\r\n\r\n"
return h.format(query).encode("utf-8")
class Cfg(Namespace):
def __init__(self, a=[], v=[], c=None):
super(Cfg, self).__init__(
a=a,
v=v,
c=c,
ed=False,
no_zip=False,
no_scandir=False,
no_sendfile=True,
nih=True,
mtp=[],
mte="a",
**{k: False for k in "e2d e2ds e2dsa e2t e2ts e2tsr".split()}
)
class TestHttpCli(unittest.TestCase):
def test(self):
td = os.path.join(tu.get_ramdisk(), "vfs")
try:
shutil.rmtree(td)
except OSError:
pass
os.mkdir(td)
os.chdir(td)
self.dtypes = ["ra", "ro", "rx", "wa", "wo", "wx", "aa", "ao", "ax"]
self.can_read = ["ra", "ro", "aa", "ao"]
self.can_write = ["wa", "wo", "aa", "ao"]
self.fn = "g{:x}g".format(int(time.time() * 3))
allfiles = []
allvols = []
for top in self.dtypes:
allvols.append(top)
allfiles.append("/".join([top, self.fn]))
for s1 in self.dtypes:
p = "/".join([top, s1])
allvols.append(p)
allfiles.append(p + "/" + self.fn)
allfiles.append(p + "/n/" + self.fn)
for s2 in self.dtypes:
p = "/".join([top, s1, "n", s2])
os.makedirs(p)
allvols.append(p)
allfiles.append(p + "/" + self.fn)
for fp in allfiles:
with open(fp, "w") as f:
f.write("ok {}\n".format(fp))
for top in self.dtypes:
vcfg = []
for vol in allvols:
if not vol.startswith(top):
continue
mode = vol[-2]
usr = vol[-1]
if usr == "a":
usr = ""
if "/" not in vol:
vol += "/"
top, sub = vol.split("/", 1)
vcfg.append("{0}/{1}:{1}:{2}{3}".format(top, sub, mode, usr))
pprint.pprint(vcfg)
self.args = Cfg(v=vcfg, a=["o:o", "x:x"])
self.auth = AuthSrv(self.args, self.log)
vfiles = [x for x in allfiles if x.startswith(top)]
for fp in vfiles:
rok, wok = self.can_rw(fp)
furl = fp.split("/", 1)[1]
durl = furl.rsplit("/", 1)[0] if "/" in furl else ""
# file download
ret = self.curl(furl)
res = "ok " + fp in ret
print("[{}] {} {} = {}".format(fp, rok, wok, res))
if rok != res:
print("\033[33m{}\n# {}\033[0m".format(ret, furl))
self.fail()
# file browser: html
ret = self.curl(durl)
res = "'{}'".format(self.fn) in ret
print(res)
if rok != res:
print("\033[33m{}\n# {}\033[0m".format(ret, durl))
self.fail()
# file browser: json
url = durl + "?ls"
ret = self.curl(url)
res = '"{}"'.format(self.fn) in ret
print(res)
if rok != res:
print("\033[33m{}\n# {}\033[0m".format(ret, url))
self.fail()
# tar
url = durl + "?tar"
h, b = self.curl(url, True)
# with open(os.path.join(td, "tar"), "wb") as f:
# f.write(b)
try:
tar = tarfile.open(fileobj=io.BytesIO(b)).getnames()
except:
tar = []
tar = ["/".join([y for y in [top, durl, x] if y]) for x in tar]
tar = [[x] + self.can_rw(x) for x in tar]
tar_ok = [x[0] for x in tar if x[1]]
tar_ng = [x[0] for x in tar if not x[1]]
self.assertEqual([], tar_ng)
if durl.split("/")[-1] not in self.can_read:
continue
ref = [x for x in vfiles if self.in_dive(top + "/" + durl, x)]
for f in ref:
print("{}: {}".format("ok" if f in tar_ok else "NG", f))
ref.sort()
tar_ok.sort()
self.assertEqual(ref, tar_ok)
def can_rw(self, fp):
# lowest non-neutral folder declares permissions
expect = fp.split("/")[:-1]
for x in reversed(expect):
if x != "n":
expect = x
break
return [expect in self.can_read, expect in self.can_write]
def in_dive(self, top, fp):
# archiver bails at first inaccessible subvolume
top = top.strip("/").split("/")
fp = fp.split("/")
for f1, f2 in zip(top, fp):
if f1 != f2:
return False
for f in fp[len(top) :]:
if f == self.fn:
return True
if f not in self.can_read and f != "n":
return False
return True
def curl(self, url, binary=False):
conn = tu.VHttpConn(self.args, self.auth, self.log, hdr(url))
HttpCli(conn).run()
if binary:
h, b = conn.s._reply.split(b"\r\n\r\n", 1)
return h, b
return conn.s._reply.decode("utf-8")
def log(self, src, msg, c=0):
# print(repr(msg))
pass

View file

@ -3,18 +3,18 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os import os
import time
import json import json
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
import subprocess as sp # nosec
from textwrap import dedent from textwrap import dedent
from argparse import Namespace from argparse import Namespace
from copyparty.authsrv import AuthSrv from copyparty.authsrv import AuthSrv
from copyparty import util from copyparty import util
from tests import util as tu
class Cfg(Namespace): class Cfg(Namespace):
def __init__(self, a=[], v=[], c=None): def __init__(self, a=[], v=[], c=None):
@ -51,52 +51,11 @@ class TestVFS(unittest.TestCase):
real = [x[0] for x in real] real = [x[0] for x in real]
return fsdir, real, virt return fsdir, real, virt
def runcmd(self, *argv):
p = sp.Popen(argv, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = p.communicate()
stdout = stdout.decode("utf-8")
stderr = stderr.decode("utf-8")
return [p.returncode, stdout, stderr]
def chkcmd(self, *argv):
ok, sout, serr = self.runcmd(*argv)
if ok != 0:
raise Exception(serr)
return sout, serr
def get_ramdisk(self):
for vol in ["/dev/shm", "/Volumes/cptd"]: # nosec (singleton test)
if os.path.exists(vol):
return vol
if os.path.exists("/Volumes"):
devname, _ = self.chkcmd("hdiutil", "attach", "-nomount", "ram://8192")
devname = devname.strip()
print("devname: [{}]".format(devname))
for _ in range(10):
try:
_, _ = self.chkcmd(
"diskutil", "eraseVolume", "HFS+", "cptd", devname
)
return "/Volumes/cptd"
except Exception as ex:
print(repr(ex))
time.sleep(0.25)
raise Exception("ramdisk creation failed")
ret = os.path.join(tempfile.gettempdir(), "copyparty-test")
try:
os.mkdir(ret)
finally:
return ret
def log(self, src, msg, c=0): def log(self, src, msg, c=0):
pass pass
def test(self): def test(self):
td = os.path.join(self.get_ramdisk(), "vfs") td = os.path.join(tu.get_ramdisk(), "vfs")
try: try:
shutil.rmtree(td) shutil.rmtree(td)
except OSError: except OSError:
@ -268,7 +227,7 @@ class TestVFS(unittest.TestCase):
self.assertEqual(list(v1), list(v2)) self.assertEqual(list(v1), list(v2))
# config file parser # config file parser
cfg_path = os.path.join(self.get_ramdisk(), "test.cfg") cfg_path = os.path.join(tu.get_ramdisk(), "test.cfg")
with open(cfg_path, "wb") as f: with open(cfg_path, "wb") as f:
f.write( f.write(
dedent( dedent(

90
tests/util.py Normal file
View file

@ -0,0 +1,90 @@
import os
import time
import jinja2
import tempfile
import subprocess as sp
from copyparty.util import Unrecv
J2_ENV = jinja2.Environment(loader=jinja2.BaseLoader)
J2_FILES = J2_ENV.from_string("{{ files|join('\n') }}")
def runcmd(*argv):
p = sp.Popen(argv, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = p.communicate()
stdout = stdout.decode("utf-8")
stderr = stderr.decode("utf-8")
return [p.returncode, stdout, stderr]
def chkcmd(*argv):
ok, sout, serr = runcmd(*argv)
if ok != 0:
raise Exception(serr)
return sout, serr
def get_ramdisk():
for vol in ["/dev/shm", "/Volumes/cptd"]: # nosec (singleton test)
if os.path.exists(vol):
return vol
if os.path.exists("/Volumes"):
devname, _ = chkcmd("hdiutil", "attach", "-nomount", "ram://32768")
devname = devname.strip()
print("devname: [{}]".format(devname))
for _ in range(10):
try:
_, _ = chkcmd("diskutil", "eraseVolume", "HFS+", "cptd", devname)
return "/Volumes/cptd"
except Exception as ex:
print(repr(ex))
time.sleep(0.25)
raise Exception("ramdisk creation failed")
ret = os.path.join(tempfile.gettempdir(), "copyparty-test")
try:
os.mkdir(ret)
finally:
return ret
class VSock(object):
def __init__(self, buf):
self._query = buf
self._reply = b""
self.sendall = self.send
def recv(self, sz):
ret = self._query[:sz]
self._query = self._query[sz:]
return ret
def send(self, buf):
self._reply += buf
return len(buf)
class VHttpSrv(object):
def __init__(self):
aliases = ["splash", "browser", "browser2", "msg", "md", "mde"]
self.j2 = {x: J2_FILES for x in aliases}
class VHttpConn(object):
def __init__(self, args, auth, log, buf):
self.s = VSock(buf)
self.sr = Unrecv(self.s)
self.addr = ("127.0.0.1", "42069")
self.args = args
self.auth = auth
self.log_func = log
self.log_src = "a"
self.hsrv = VHttpSrv()
self.nbyte = 0
self.t0 = time.time()