detect recursive symlinks

This commit is contained in:
ed 2021-06-07 20:09:18 +02:00
parent ae90a7b7b6
commit 89e48cff24
6 changed files with 35 additions and 18 deletions

View file

@ -238,6 +238,7 @@ def run_argparse(argv, formatter):
--ls '**' # list all files which are possible to read --ls '**' # list all files which are possible to read
--ls '**,*,ln' # check for dangerous symlinks --ls '**,*,ln' # check for dangerous symlinks
--ls '**,*,ln,p,r' # check, then start normally if safe --ls '**,*,ln,p,r' # check, then start normally if safe
\033[0m
""" """
), ),
) )

View file

@ -7,7 +7,7 @@ import sys
import stat import stat
import threading import threading
from .__init__ import PY2, WINDOWS from .__init__ import WINDOWS
from .util import IMPLICATIONS, undot, Pebkac, fsdec, fsenc, statdir, nuprint from .util import IMPLICATIONS, undot, Pebkac, fsdec, fsenc, statdir, nuprint
@ -96,6 +96,7 @@ class VFS(object):
] ]
def get(self, vpath, uname, will_read, will_write): def get(self, vpath, uname, will_read, will_write):
# type: (str, str, bool, bool) -> tuple[VFS, str]
"""returns [vfsnode,fs_remainder] if user has the requested permissions""" """returns [vfsnode,fs_remainder] if user has the requested permissions"""
vn, rem = self._find(vpath) vn, rem = self._find(vpath)
@ -136,6 +137,7 @@ class VFS(object):
return os.path.realpath(rp) return os.path.realpath(rp)
def ls(self, rem, uname, scandir, incl_wo=False, lstat=False): def ls(self, rem, uname, scandir, incl_wo=False, lstat=False):
# type: (str, str, bool, bool, bool) -> tuple[str, str, dict[str, VFS]]
"""return user-readable [fsdir,real,virt] items at vpath""" """return user-readable [fsdir,real,virt] items at vpath"""
virt_vis = {} # nodes readable by user virt_vis = {} # nodes readable by user
abspath = self.canonical(rem) abspath = self.canonical(rem)
@ -156,13 +158,21 @@ class VFS(object):
return [abspath, real, virt_vis] return [abspath, real, virt_vis]
def walk(self, rel, rem, uname, dots, scandir, lstat=False): def walk(self, rel, rem, seen, uname, dots, scandir, lstat):
""" """
recursively yields from ./rem; recursively yields from ./rem;
rel is a unix-style user-defined vpath (not vfs-related) rel is a unix-style user-defined vpath (not vfs-related)
""" """
fsroot, vfs_ls, vfs_virt = self.ls(rem, uname, scandir, False, lstat) fsroot, vfs_ls, vfs_virt = self.ls(
rem, uname, scandir, incl_wo=False, lstat=lstat
)
if seen and not fsroot.startswith(seen[-1]) and fsroot in seen:
print("bailing from symlink loop,\n {}\n {}".format(seen[-1], fsroot))
return
seen = seen[:] + [fsroot]
rfiles = [x for x in vfs_ls if not stat.S_ISDIR(x[1].st_mode)] rfiles = [x for x in vfs_ls if not stat.S_ISDIR(x[1].st_mode)]
rdirs = [x for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)] rdirs = [x for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)]
@ -177,7 +187,7 @@ class VFS(object):
wrel = (rel + "/" + rdir).lstrip("/") wrel = (rel + "/" + rdir).lstrip("/")
wrem = (rem + "/" + rdir).lstrip("/") wrem = (rem + "/" + rdir).lstrip("/")
for x in self.walk(wrel, wrem, uname, scandir, lstat): for x in self.walk(wrel, wrem, seen, uname, dots, scandir, lstat):
yield x yield x
for n, vfs in sorted(vfs_virt.items()): for n, vfs in sorted(vfs_virt.items()):
@ -185,14 +195,16 @@ class VFS(object):
continue continue
wrel = (rel + "/" + n).lstrip("/") wrel = (rel + "/" + n).lstrip("/")
for x in vfs.walk(wrel, "", uname, scandir, lstat): for x in vfs.walk(wrel, "", seen, uname, dots, scandir, lstat):
yield x yield x
def zipgen(self, vrem, flt, uname, dots, scandir): def zipgen(self, vrem, flt, uname, dots, scandir):
if flt: if flt:
flt = {k: True for k in flt} flt = {k: True for k in flt}
for vpath, apath, files, rd, vd in self.walk("", vrem, uname, dots, scandir): for vpath, apath, files, rd, vd in self.walk(
"", vrem, [], uname, dots, scandir, False
):
if flt: if flt:
files = [x for x in files if x[0] in flt] files = [x for x in files if x[0] in flt]
@ -616,13 +628,13 @@ class AuthSrv(object):
continue continue
atop = vn.realpath atop = vn.realpath
g = vn.walk("", "", u, True, not self.args.no_scandir, lstat=False) g = vn.walk("", "", [], u, True, not self.args.no_scandir, False)
for vpath, apath, files, _, _ in g: for vpath, apath, files, _, _ in g:
fnames = [n[0] for n in files] fnames = [n[0] for n in files]
vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames
vpaths = [vtop + x for x in vpaths] vpaths = [vtop + x for x in vpaths]
apaths = [os.path.join(apath, n) for n in fnames] apaths = [os.path.join(apath, n) for n in fnames]
files = list(zip(vpaths, apaths)) files = [[vpath + "/", apath + os.sep]] + list(zip(vpaths, apaths))
if flag_ln: if flag_ln:
files = [x for x in files if not x[1].startswith(atop + os.sep)] files = [x for x in files if not x[1].startswith(atop + os.sep)]

View file

@ -16,6 +16,7 @@ import calendar
from .__init__ import E, PY2, WINDOWS, ANYWIN from .__init__ import E, PY2, WINDOWS, ANYWIN
from .util import * # noqa # pylint: disable=unused-wildcard-import from .util import * # noqa # pylint: disable=unused-wildcard-import
from .authsrv import AuthSrv
from .szip import StreamZip from .szip import StreamZip
from .star import StreamTar from .star import StreamTar
@ -35,12 +36,12 @@ class HttpCli(object):
def __init__(self, conn): def __init__(self, conn):
self.t0 = time.time() self.t0 = time.time()
self.conn = conn self.conn = conn
self.s = conn.s self.s = conn.s # type: socket
self.sr = conn.sr self.sr = conn.sr # type: Unrecv
self.ip = conn.addr[0] self.ip = conn.addr[0]
self.addr = conn.addr self.addr = conn.addr # type: tuple[str, int]
self.args = conn.args self.args = conn.args
self.auth = conn.auth self.auth = conn.auth # type: AuthSrv
self.ico = conn.ico self.ico = conn.ico
self.thumbcli = conn.thumbcli self.thumbcli = conn.thumbcli
self.log_func = conn.log_func self.log_func = conn.log_func
@ -1418,7 +1419,7 @@ class HttpCli(object):
try: try:
vn, rem = self.auth.vfs.get(top, self.uname, True, False) vn, rem = self.auth.vfs.get(top, self.uname, True, False)
fsroot, vfs_ls, vfs_virt = vn.ls( fsroot, vfs_ls, vfs_virt = vn.ls(
rem, self.uname, not self.args.no_scandir, True rem, self.uname, not self.args.no_scandir, incl_wo=True
) )
except: except:
vfs_ls = [] vfs_ls = []
@ -1585,7 +1586,7 @@ class HttpCli(object):
return self.tx_zip(k, v, vn, rem, [], self.args.ed) return self.tx_zip(k, v, vn, rem, [], self.args.ed)
fsroot, vfs_ls, vfs_virt = vn.ls( fsroot, vfs_ls, vfs_virt = vn.ls(
rem, self.uname, not self.args.no_scandir, True rem, self.uname, not self.args.no_scandir, incl_wo=True
) )
stats = {k: v for k, v in vfs_ls} stats = {k: v for k, v in vfs_ls}
vfs_ls = [x[0] for x in vfs_ls] vfs_ls = [x[0] for x in vfs_ls]

View file

@ -25,8 +25,8 @@ except ImportError:
sys.exit(1) sys.exit(1)
from .__init__ import E, MACOS from .__init__ import E, MACOS
from .httpconn import HttpConn
from .authsrv import AuthSrv from .authsrv import AuthSrv
from .httpconn import HttpConn
class HttpSrv(object): class HttpSrv(object):

View file

@ -95,9 +95,11 @@ class SvcHub(object):
break break
if n == 3: if n == 3:
print("waiting for thumbsrv...") print("waiting for thumbsrv (10sec)...")
print("nailed it") print("nailed it", end="")
finally:
print("\033[0m")
def _log_disabled(self, src, msg, c=0): def _log_disabled(self, src, msg, c=0):
pass pass

View file

@ -11,7 +11,7 @@ from textwrap import dedent
from argparse import Namespace from argparse import Namespace
from tests import util as tu from tests import util as tu
from copyparty.authsrv import AuthSrv from copyparty.authsrv import AuthSrv, VFS
from copyparty import util from copyparty import util
@ -47,6 +47,7 @@ class TestVFS(unittest.TestCase):
self.assertEqual(util.undot(query), response) self.assertEqual(util.undot(query), response)
def ls(self, vfs, vpath, uname): def ls(self, vfs, vpath, uname):
# type: (VFS, str, str) -> tuple[str, str, str]
"""helper for resolving and listing a folder""" """helper for resolving and listing a folder"""
vn, rem = vfs.get(vpath, uname, True, False) vn, rem = vfs.get(vpath, uname, True, False)
r1 = vn.ls(rem, uname, False) r1 = vn.ls(rem, uname, False)