diff --git a/copyparty/__main__.py b/copyparty/__main__.py index 7e5c42c0..d4f765b2 100755 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -954,6 +954,7 @@ def add_auth(ap): ap2 = ap.add_argument_group('IdP / identity provider / user authentication options') ap2.add_argument("--idp-h-usr", metavar="HN", type=u, default="", help="bypass the copyparty authentication checks and assume the request-header \033[33mHN\033[0m contains the username of the requesting user (for use with authentik/oauth/...)\n\033[1;31mWARNING:\033[0m if you enable this, make sure clients are unable to specify this header themselves; must be washed away and replaced by a reverse-proxy") ap2.add_argument("--idp-h-grp", metavar="HN", type=u, default="", help="assume the request-header \033[33mHN\033[0m contains the groupname of the requesting user; can be referenced in config files for group-based access control") + ap2.add_argument("--idp-h-sep", metavar="RE", type=u, default="|:;+,", help="if there are multiple groups in \033[33m--idp-h-grp\033[0m, they are separated by one of the characters in \033[33mRE\033[0m") def add_zeroconf(ap): diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py index 2587aa85..f57049e1 100644 --- a/copyparty/authsrv.py +++ b/copyparty/authsrv.py @@ -795,7 +795,8 @@ class AuthSrv(object): self.idp_vols: dict[str, str] = {} # vpath->abspath # all users/groups observed since last restart - self.idp_accs: dict[str, str] = {} # username->groupname + self.idp_accs: dict[str, list[str]] = {} # username->groupnames + self.idp_usr_gh: dict[str, str] = {} # username->group-header-value (cache) self.mutex = threading.Lock() self.reload() @@ -820,17 +821,21 @@ class AuthSrv(object): if uname in self.acct: return False - if self.idp_accs.get(uname) == gname: + if self.idp_usr_gh.get(uname) == gname: return False + gnames = [x.strip() for x in self.args.idp_h_sep.split(gname)] + gnames.sort() + with self.mutex: - if self.idp_accs.get(uname) == gname: + self.idp_usr_gh[uname] = gname + if self.idp_accs.get(uname) == gnames: return False - self.idp_accs[uname] = gname + self.idp_accs[uname] = gnames t = "reinitializing due to new user from IdP: [%s:%s]" - self.log(t % (uname, gname), 3) + self.log(t % (uname, gnames), 3) if not broker: # only true for tests @@ -847,15 +852,19 @@ class AuthSrv(object): mount: dict[str, str], daxs: dict[str, AXS], mflags: dict[str, dict[str, Any]], - un_gn: dict[str, str], + un_gns: dict[str, list[str]], ) -> list[tuple[str, str, str, str]]: ret: list[tuple[str, str, str, str]] = [] visited = set() src0 = src # abspath dst0 = dst # vpath - # +('','') to ensure volume creation if there's no users - for un, gn in list(un_gn.items()) + [("", "")]: + un_gn = [(un, gn) for un, gns in un_gns.items() for gn in gns] + if not un_gn: + # ensure volume creation if there's no users + un_gn = [("", "")] + + for un, gn in un_gn: # if ap/vp has a user/group placeholder, make sure to keep # track so the same user/gruop is mapped when setting perms; # otherwise clear un/gn to indicate it's a regular volume @@ -952,17 +961,21 @@ class AuthSrv(object): self, acct: dict[str, str], grps: dict[str, list[str]], - ) -> dict[str, str]: + ) -> dict[str, list[str]]: """ generate list of all confirmed pairs of username/groupname seen since last restart; in case of conflicting group memberships then it is selected as follows: * any non-zero value from IdP group header * otherwise take --grps / [groups] """ - ret = self.idp_accs.copy() - ret.update({zs: "" for zs in acct if zs not in ret}) + ret = {un:gns[:] for un, gns in self.idp_accs.items()} + ret.update({zs: [""] for zs in acct if zs not in ret}) for gn, uns in grps.items(): - ret.update({un: gn for un in uns if not ret.get(un)}) + for un in uns: + try: + ret[un].append(gn) + except: + ret[un] = [gn] return ret @@ -1176,7 +1189,7 @@ class AuthSrv(object): lvl: str, uname: str, vols: list[tuple[str, str, str, str]], - un_gn: dict[str, str], + un_gns: dict[str, list[str]], axs: dict[str, AXS], flags: dict[str, dict[str, Any]], ) -> None: @@ -1212,8 +1225,8 @@ class AuthSrv(object): for un in uname.replace(",", " ").strip().split(): if un.startswith("@"): grp = un[1:] - uns = [x[0] for x in un_gn.items() if x[1] == grp] - if not uns and grp != "${g}": + uns = [x[0] for x in un_gns.items() if grp in x[1]] + if not uns and grp != "${g}" and not self.args.idp_h_grp: t = "group [%s] must be defined with --grp argument (or in a [groups] config section)" raise CfgEx(t % (grp,)) @@ -1222,13 +1235,14 @@ class AuthSrv(object): unames.append(un) # unames may still contain ${u} and ${g} so now expand those; - # need ("*","") to match "*" in unames - un_gn = un_gn.copy() - un_gn["*"] = un_gn.get("*", "") + un_gn = [(un, gn) for un, gns in un_gns.items() for gn in gns] + if "*" not in un_gns: + # need ("*","") to match "*" in unames + un_gn.append(("*", "")) for _, dst, vu, vg in vols: unames2 = set() - for un, gn in un_gn.items(): + for un, gn in un_gn: # if vu/vg (volume user/group) is non-null, # then each non-null value corresponds to # ${u}/${g}; consider this a filter to diff --git a/copyparty/svchub.py b/copyparty/svchub.py index 88d0e56b..fabb92b7 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -460,6 +460,18 @@ class SvcHub(object): if ptn: setattr(self.args, k, re.compile(ptn)) + for k in ["idp_h_sep"]: + ptn = getattr(self.args, k) + if "]" in ptn: + ptn = "]" + ptn.replace("]", "") + if "[" in ptn: + ptn = ptn.replace("[", "") + "[" + if "-" in ptn: + ptn = ptn.replace("-", "") + "-" + + ptn = ptn.replace("\\", "\\\\").replace("^", "\\^") + setattr(self.args, k, re.compile("[%s]" % (ptn,))) + try: zf1, zf2 = self.args.rm_retry.split("/") self.args.rm_re_t = float(zf1) diff --git a/tests/res/idp/5.conf b/tests/res/idp/5.conf new file mode 100644 index 00000000..135f3ea4 --- /dev/null +++ b/tests/res/idp/5.conf @@ -0,0 +1,21 @@ +# -*- mode: yaml -*- +# vim: ft=yaml: + +[global] + idp-h-usr: x-idp-user + idp-h-grp: x-idp-group + +[/ga] + /ga + accs: + r: @ga + +[/gb] + /gb + accs: + r: @gb + +[/g] + /g + accs: + r: @ga, @gb diff --git a/tests/test_idp.py b/tests/test_idp.py index 1e746bdc..79f2669d 100644 --- a/tests/test_idp.py +++ b/tests/test_idp.py @@ -139,3 +139,33 @@ class TestVFS(unittest.TestCase): self.assertEqual(self.nav(au, "vu/iua").realpath, "/u-iua") self.assertEqual(self.nav(au, "vg/iga1").realpath, "/g1-iga") self.assertEqual(self.nav(au, "vg/iga2").realpath, "/g2-iga") + + def test_5(self): + """ + one IdP user in multiple groups + """ + _, cfgdir, xcfg = self.prep() + au = AuthSrv(Cfg(c=[cfgdir + "/5.conf"], **xcfg), self.log) + + self.assertEqual(au.vfs.vpath, "") + self.assertEqual(au.vfs.realpath, "") + self.assertNodes(au.vfs, ["g", "ga", "gb"]) + self.assertAxs(au.vfs.axs, []) + + au.idp_checkin(None, "iua", "ga") + self.assertNodes(au.vfs, ["g", "ga", "gb"]) + self.assertAxsAt(au, "g", [["iua"]]) + self.assertAxsAt(au, "ga", [["iua"]]) + self.assertAxsAt(au, "gb", []) + + au.idp_checkin(None, "iua", "gb") + self.assertNodes(au.vfs, ["g", "ga", "gb"]) + self.assertAxsAt(au, "g", [["iua"]]) + self.assertAxsAt(au, "ga", []) + self.assertAxsAt(au, "gb", [["iua"]]) + + au.idp_checkin(None, "iua", "ga|gb") + self.assertNodes(au.vfs, ["g", "ga", "gb"]) + self.assertAxsAt(au, "g", [["iua"]]) + self.assertAxsAt(au, "ga", [["iua"]]) + self.assertAxsAt(au, "gb", [["iua"]]) diff --git a/tests/util.py b/tests/util.py index 159675fa..c1ac464e 100644 --- a/tests/util.py +++ b/tests/util.py @@ -146,6 +146,7 @@ class Cfg(Namespace): E=E, dbd="wal", fk_salt="a" * 16, + idp_h_sep=re.compile("[|:;+,]"), lang="eng", log_badpwd=1, logout=573,