diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py index 8ede0610..1468df26 100644 --- a/copyparty/authsrv.py +++ b/copyparty/authsrv.py @@ -372,6 +372,7 @@ class VFS(object): self.shr_src: Optional[tuple[VFS, str]] = None # source vfs+rem of a share self.shr_files: set[str] = set() # filenames to include from shr_src self.shr_owner: str = "" # uname + self.shr_all_aps: list[tuple[str, list[VFS]]] = [] self.aread: dict[str, list[str]] = {} self.awrite: dict[str, list[str]] = {} self.amove: dict[str, list[str]] = {} @@ -391,7 +392,7 @@ class VFS(object): self.dbpath = self.histpath self.all_vols = {vpath: self} # flattened recursive self.all_nodes = {vpath: self} # also jumpvols/shares - self.all_aps = [(rp, self)] + self.all_aps = [(rp, [self])] self.all_vps = [(vp, self)] else: self.histpath = self.dbpath = "" @@ -415,7 +416,7 @@ class VFS(object): self, vols: dict[str, "VFS"], nodes: dict[str, "VFS"], - aps: list[tuple[str, "VFS"]], + aps: list[tuple[str, list["VFS"]]], vps: list[tuple[str, "VFS"]], ) -> None: nodes[self.vpath] = self @@ -424,7 +425,11 @@ class VFS(object): rp = self.realpath rp += "" if rp.endswith(os.sep) else os.sep vp = self.vpath + ("/" if self.vpath else "") - aps.append((rp, self)) + hit = next((x[1] for x in aps if x[0] == rp), None) + if hit: + hit.append(self) + else: + aps.append((rp, [self])) vps.append((vp, self)) for v in self.nodes.values(): @@ -848,9 +853,11 @@ class VFS(object): return None if "xvol" in self.flags: - for vap, vn in self.root.all_aps: + all_aps = self.shr_all_aps or self.root.all_aps + + for vap, vns in all_aps: if aps.startswith(vap): - return vn + return self if self in vns else vns[0] if self.log: self.log("vfs", "xvol: %r" % (ap,), 3) @@ -2554,6 +2561,28 @@ class AuthSrv(object): shn.shr_src = (s_vfs, s_rem) shn.realpath = s_vfs.canonical(s_rem) + # root.all_aps doesn't include any shares, so make a copy where the + # share appears in all abspaths it can provide (for example for chk_ap) + ap = shn.realpath + if not ap.endswith(os.sep): + ap += os.sep + shn.shr_all_aps = [(x, y[:]) for x, y in vfs.all_aps] + exact = False + for ap2, vns in shn.shr_all_aps: + if ap == ap2: + exact = True + if ap2.startswith(ap): + try: + vp2 = vjoin(s_rem, ap2[len(ap) :]) + vn2, _ = s_vfs.get(vp2, "*", False, False) + if vn2 == s_vfs or vn2.dbv == s_vfs: + vns.append(shn) + except: + pass + if not exact: + shn.shr_all_aps.append((ap, [shn])) + shn.shr_all_aps.sort(key=lambda x: len(x[0]), reverse=True) + if self.args.shr_v: t = "mapped %s share [%s] by [%s] => [%s] => [%s]" self.log(t % (s_pr, s_k, s_un, s_vp, shn.realpath)) diff --git a/copyparty/th_srv.py b/copyparty/th_srv.py index 59c440e9..c6b4a94f 100644 --- a/copyparty/th_srv.py +++ b/copyparty/th_srv.py @@ -284,7 +284,7 @@ class ThumbSrv(object): vn = next((x for x in allvols if x.realpath == ptop), None) if not vn: self.log("ptop %r not in %s" % (ptop, allvols), 3) - vn = self.asrv.vfs.all_aps[0][1] + vn = self.asrv.vfs.all_aps[0][1][0] self.q.put((abspath, tpath, fmt, vn)) self.log("conv %r :%s \033[0m%r" % (tpath, fmt, abspath), 6) diff --git a/copyparty/util.py b/copyparty/util.py index 6d9c0210..f877ea9d 100644 --- a/copyparty/util.py +++ b/copyparty/util.py @@ -2400,11 +2400,11 @@ def pathmod( # try to map abspath to vpath np = np.replace("/", os.sep) - for vn_ap, vn in vfs.all_aps: + for vn_ap, vns in vfs.all_aps: if not np.startswith(vn_ap): continue zs = np[len(vn_ap) :].replace(os.sep, "/") - nvp = vjoin(vn.vpath, zs) + nvp = vjoin(vns[0].vpath, zs) break if nvp == "\n": diff --git a/tests/test_shr.py b/tests/test_shr.py new file mode 100644 index 00000000..7ecfb388 --- /dev/null +++ b/tests/test_shr.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 +# coding: utf-8 +from __future__ import print_function, unicode_literals + +import json +import os +import shutil +import sqlite3 +import tempfile +import unittest + +from copyparty.__init__ import ANYWIN +from copyparty.authsrv import AuthSrv +from copyparty.httpcli import HttpCli +from copyparty.util import absreal +from tests import util as tu +from tests.util import Cfg + + +class TestShr(unittest.TestCase): + def log(self, src, msg, c=0): + m = "%s" % (msg,) + if ( + "warning: filesystem-path does not exist:" in m + or "you are sharing a system directory:" in m + or "symlink-based deduplication is enabled" in m + or m.startswith("hint: argument") + ): + return + + print(("[%s] %s" % (src, msg)).encode("ascii", "replace").decode("ascii")) + + def assertLD(self, url, auth, els, edl): + ls = self.ls(url, auth) + self.assertEqual(ls[0], len(els) == 2) + if not ls[0]: + return + a = [list(sorted(els[0])), list(sorted(els[1]))] + b = [list(sorted(ls[1])), list(sorted(ls[2]))] + self.assertEqual(a, b) + + if edl is None: + edl = els[1] + can_dl = [] + for fn in b[1]: + if fn == "a.db": + continue + furl = url + "/" + fn + if auth: + furl += "?pw=p1" + h, zb = self.curl(furl, True) + if h.startswith("HTTP/1.1 200 "): + can_dl.append(fn) + self.assertEqual(edl, can_dl) + + def setUp(self): + self.td = tu.get_ramdisk() + td = os.path.join(self.td, "vfs") + os.mkdir(td) + os.chdir(td) + os.mkdir("d1") + os.mkdir("d2") + os.mkdir("d2/d3") + for zs in ("d1/f1", "d2/f2", "d2/d3/f3"): + with open(zs, "wb") as f: + f.write(zs.encode("utf-8")) + for dst in ("d1", "d2", "d2/d3"): + src, fn = zs.rsplit("/", 1) + os.symlink(absreal(zs), dst + "/l" + fn[-1:]) + + db = sqlite3.connect("a.db") + with db: + zs = r"create table sh (k text, pw text, vp text, pr text, st int, un text, t0 int, t1 int)" + db.execute(zs) + db.close() + + def tearDown(self): + os.chdir(tempfile.gettempdir()) + shutil.rmtree(self.td) + + def cinit(self): + self.asrv = AuthSrv(self.args, self.log) + self.conn = tu.VHttpConn(self.args, self.asrv, self.log, b"", True) + + def test1(self): + self.args = Cfg( + a=["u1:p1"], + v=["::A,u1", "d1:v1:A,u1", "d2/d3:d2/d3:A,u1"], + shr="/shr/", + shr1="shr/", + shr_db="a.db", + shr_v=False, + ) + self.cinit() + + self.assertLD("", True, [["d1", "d2", "v1"], ["a.db"]], []) + self.assertLD("d1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("v1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("d2", True, [["d3"], ["f2", "l1", "l2", "l3"]], None) + self.assertLD("d2/d3", True, [[], ["f3", "l1", "l2", "l3"]], None) + self.assertLD("d3", True, [], []) + + jt = { + "k": "r", + "vp": ["/"], + "pw": "", + "exp": "99", + "perms": ["read"], + } + print(self.post_json("?pw=p1&share", jt)[1]) + jt = { + "k": "d2", + "vp": ["/d2/"], + "pw": "", + "exp": "99", + "perms": ["read"], + } + print(self.post_json("?pw=p1&share", jt)[1]) + self.conn.shutdown() + self.cinit() + + self.assertLD("", True, [["d1", "d2", "v1"], ["a.db"]], []) + self.assertLD("d1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("v1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("d2", True, [["d3"], ["f2", "l1", "l2", "l3"]], None) + self.assertLD("d2/d3", True, [[], ["f3", "l1", "l2", "l3"]], None) + self.assertLD("d3", True, [], []) + + self.assertLD("shr/d2", False, [[], ["f2", "l1", "l2", "l3"]], None) + self.assertLD("shr/d2/d3", False, [], None) + + self.assertLD("shr/r", False, [["d1"], ["a.db"]], []) + self.assertLD("shr/r/d1", False, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("shr/r/d2", False, [], None) # unfortunate + self.assertLD("shr/r/d2/d3", False, [], None) + + self.conn.shutdown() + + def test2(self): + self.args = Cfg( + a=["u1:p1"], + v=["::A,u1", "d1:v1:A,u1", "d2/d3:d2/d3:A,u1"], + shr="/shr/", + shr1="shr/", + shr_db="a.db", + shr_v=False, + xvol=True, + ) + self.cinit() + + self.assertLD("", True, [["d1", "d2", "v1"], ["a.db"]], []) + self.assertLD("d1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("v1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("d2", True, [["d3"], ["f2", "l1", "l2", "l3"]], None) + self.assertLD("d2/d3", True, [[], ["f3", "l1", "l2", "l3"]], None) + self.assertLD("d3", True, [], []) + + jt = { + "k": "r", + "vp": ["/"], + "pw": "", + "exp": "99", + "perms": ["read"], + } + print(self.post_json("?pw=p1&share", jt)[1]) + jt = { + "k": "d2", + "vp": ["/d2/"], + "pw": "", + "exp": "99", + "perms": ["read"], + } + print(self.post_json("?pw=p1&share", jt)[1]) + self.conn.shutdown() + self.cinit() + + self.assertLD("", True, [["d1", "d2", "v1"], ["a.db"]], []) + self.assertLD("d1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("v1", True, [[], ["f1", "l1", "l2", "l3"]], None) + self.assertLD("d2", True, [["d3"], ["f2", "l1", "l2", "l3"]], None) + self.assertLD("d2/d3", True, [[], ["f3", "l1", "l2", "l3"]], None) + self.assertLD("d3", True, [], []) + + self.assertLD("shr/d2", False, [[], ["f2", "l1", "l2", "l3"]], ["f2", "l2"]) + self.assertLD("shr/d2/d3", False, [], []) + + self.assertLD("shr/r", False, [["d1"], ["a.db"]], []) + self.assertLD( + "shr/r/d1", False, [[], ["f1", "l1", "l2", "l3"]], ["f1", "l1", "l2"] + ) + self.assertLD("shr/r/d2", False, [], []) # unfortunate + self.assertLD("shr/r/d2/d3", False, [], []) + + self.conn.shutdown() + + def ls(self, url: str, auth: bool): + zs = url + "?ls" + ("&pw=p1" if auth else "") + h, b = self.curl(zs) + if not h.startswith("HTTP/1.1 200 "): + return (False, [], []) + jo = json.loads(b) + return ( + True, + [x["href"].rstrip("/") for x in jo.get("dirs") or {}], + [x["href"] for x in jo.get("files") or {}], + ) + + def curl(self, url: str, binary=False): + h = "GET /%s HTTP/1.1\r\nConnection: close\r\n\r\n" + HttpCli(self.conn.setbuf((h % (url,)).encode("utf-8"))).run() + if binary: + h, b = self.conn.s._reply.split(b"\r\n\r\n", 1) + return [h.decode("utf-8"), b] + + return self.conn.s._reply.decode("utf-8").split("\r\n\r\n", 1) + + def post_json(self, url: str, data): + buf = json.dumps(data).encode("utf-8") + msg = [ + "POST /%s HTTP/1.1" % (url,), + "Connection: close", + "Content-Type: application/json", + "Content-Length: %d" % (len(buf),), + "\r\n", + ] + buf = "\r\n".join(msg).encode("utf-8") + buf + print("PUT -->", buf) + HttpCli(self.conn.setbuf(buf)).run() + return self.conn.s._reply.decode("utf-8").split("\r\n\r\n", 1) diff --git a/tests/util.py b/tests/util.py index 427bdf3d..12277adf 100644 --- a/tests/util.py +++ b/tests/util.py @@ -260,6 +260,9 @@ class VHub(object): self.is_dut = True self.up2k = Up2k(self) + def reload(self, a, b): + pass + class VBrokerThr(BrokerThr): def __init__(self, hub):