From 5239e7ac0c12d7e835fac85c06a1568807dccecf Mon Sep 17 00:00:00 2001 From: ed Date: Thu, 18 Apr 2024 00:07:56 +0000 Subject: [PATCH] separate registry mutex for faster access also fix a harmless toctou in handle_json where clients could get stuck hanging for a bit longer than necessary --- copyparty/up2k.py | 100 ++++++++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 39 deletions(-) diff --git a/copyparty/up2k.py b/copyparty/up2k.py index a048c9da..3222618c 100644 --- a/copyparty/up2k.py +++ b/copyparty/up2k.py @@ -139,6 +139,7 @@ class Up2k(object): self.need_rescan: set[str] = set() self.db_act = 0.0 + self.reg_mutex = threading.Lock() self.registry: dict[str, dict[str, dict[str, Any]]] = {} self.flags: dict[str, dict[str, Any]] = {} self.droppable: dict[str, list[str]] = {} @@ -203,11 +204,15 @@ class Up2k(object): Daemon(self.deferred_init, "up2k-deferred-init") def reload(self, rescan_all_vols: bool) -> None: - """mutex me""" + """mutex(main) me""" self.log("reload #{} scheduled".format(self.gid + 1)) all_vols = self.asrv.vfs.all_vols - scan_vols = [k for k, v in all_vols.items() if v.realpath not in self.registry] + with self.reg_mutex: + scan_vols = [ + k for k, v in all_vols.items() if v.realpath not in self.registry + ] + if rescan_all_vols: scan_vols = list(all_vols.keys()) @@ -220,7 +225,7 @@ class Up2k(object): if self.stop: # up-mt consistency not guaranteed if init is interrupted; # drop caches for a full scan on next boot - with self.mutex: + with self.mutex, self.reg_mutex: self._drop_caches() if self.pp: @@ -289,7 +294,7 @@ class Up2k(object): return json.dumps(ret, indent=4) def get_unfinished_by_user(self, uname, ip) -> str: - if PY2 or not self.mutex.acquire(timeout=2): + if PY2 or not self.reg_mutex.acquire(timeout=2): return '[{"timeout":1}]' ret: list[tuple[int, str, int, int, int]] = [] @@ -318,7 +323,7 @@ class Up2k(object): ) ret.append(zt5) finally: - self.mutex.release() + self.reg_mutex.release() ret.sort(reverse=True) ret2 = [ @@ -328,7 +333,7 @@ class Up2k(object): return json.dumps(ret2, indent=0) def get_unfinished(self) -> str: - if PY2 or not self.mutex.acquire(timeout=0.5): + if PY2 or not self.reg_mutex.acquire(timeout=0.5): return "" ret: dict[str, tuple[int, int]] = {} @@ -350,17 +355,17 @@ class Up2k(object): ret[ptop] = (nbytes, nfiles) finally: - self.mutex.release() + self.reg_mutex.release() return json.dumps(ret, indent=4) def get_volsize(self, ptop: str) -> tuple[int, int]: - with self.mutex: + with self.reg_mutex: return self._get_volsize(ptop) def get_volsizes(self, ptops: list[str]) -> list[tuple[int, int]]: ret = [] - with self.mutex: + with self.reg_mutex: for ptop in ptops: ret.append(self._get_volsize(ptop)) @@ -388,7 +393,7 @@ class Up2k(object): def _rescan( self, all_vols: dict[str, VFS], scan_vols: list[str], wait: bool, fscan: bool ) -> str: - """mutex me""" + """mutex(main) me""" if not wait and self.pp: return "cannot initiate; scan is already in progress" @@ -670,7 +675,7 @@ class Up2k(object): self.log(msg, c=3) live_vols = [] - with self.mutex: + with self.mutex, self.reg_mutex: # only need to protect register_vpath but all in one go feels right for vol in vols: try: @@ -712,7 +717,7 @@ class Up2k(object): if self.args.re_dhash or [zv for zv in vols if "e2tsr" in zv.flags]: self.args.re_dhash = False - with self.mutex: + with self.mutex, self.reg_mutex: self._drop_caches() for vol in vols: @@ -850,6 +855,7 @@ class Up2k(object): def register_vpath( self, ptop: str, flags: dict[str, Any] ) -> Optional[tuple["sqlite3.Cursor", str]]: + """mutex(main,reg) me""" histpath = self.asrv.vfs.histtab.get(ptop) if not histpath: self.log("no histpath for [{}]".format(ptop)) @@ -1033,7 +1039,9 @@ class Up2k(object): dev = cst.st_dev if vol.flags.get("xdev") else 0 with self.mutex: - reg = self.register_vpath(top, vol.flags) + with self.reg_mutex: + reg = self.register_vpath(top, vol.flags) + assert reg and self.pp cur, db_path = reg @@ -1630,7 +1638,7 @@ class Up2k(object): def _build_tags_index(self, vol: VFS) -> tuple[int, int, bool]: ptop = vol.realpath - with self.mutex: + with self.mutex, self.reg_mutex: reg = self.register_vpath(ptop, vol.flags) assert reg and self.pp @@ -1651,6 +1659,7 @@ class Up2k(object): return ret def _drop_caches(self) -> None: + """mutex(main,reg) me""" self.log("dropping caches for a full filesystem scan") for vol in self.asrv.vfs.all_vols.values(): reg = self.register_vpath(vol.realpath, vol.flags) @@ -1826,7 +1835,7 @@ class Up2k(object): params: tuple[Any, ...], flt: int, ) -> tuple[tempfile.SpooledTemporaryFile[bytes], int]: - """mutex me""" + """mutex(main) me""" n = 0 c2 = cur.connection.cursor() tf = tempfile.SpooledTemporaryFile(1024 * 1024 * 8, "w+b", prefix="cpp-tq-") @@ -2160,7 +2169,7 @@ class Up2k(object): ip: str, at: float, ) -> int: - """will mutex""" + """will mutex(main)""" assert self.mtag try: @@ -2192,7 +2201,7 @@ class Up2k(object): abspath: str, tags: dict[str, Union[str, float]], ) -> int: - """mutex me""" + """mutex(main) me""" assert self.mtag if not bos.path.isfile(abspath): @@ -2477,28 +2486,33 @@ class Up2k(object): cur.connection.commit() - def _job_volchk(self, cj: dict[str, Any]) -> None: - if not self.register_vpath(cj["ptop"], cj["vcfg"]): - if cj["ptop"] not in self.registry: - raise Pebkac(410, "location unavailable") - def handle_json(self, cj: dict[str, Any], busy_aps: set[str]) -> dict[str, Any]: self.busy_aps = busy_aps + got_lock = False try: # bit expensive; 3.9=10x 3.11=2x if self.mutex.acquire(timeout=10): - self._job_volchk(cj) - self.mutex.release() + got_lock = True + with self.reg_mutex: + return self._handle_json(cj) else: t = "cannot receive uploads right now;\nserver busy with {}.\nPlease wait; the client will retry..." raise Pebkac(503, t.format(self.blocked or "[unknown]")) except TypeError: if not PY2: raise - with self.mutex: - self._job_volchk(cj) + with self.mutex, self.reg_mutex: + return self._handle_json(cj) + finally: + if got_lock: + self.mutex.release() + def _handle_json(self, cj: dict[str, Any]) -> dict[str, Any]: ptop = cj["ptop"] + if not self.register_vpath(ptop, cj["vcfg"]): + if ptop not in self.registry: + raise Pebkac(410, "location unavailable") + cj["name"] = sanitize_fn(cj["name"], "", [".prologue.html", ".epilogue.html"]) cj["poke"] = now = self.db_act = self.vol_act[ptop] = time.time() wark = self._get_wark(cj) @@ -2513,7 +2527,7 @@ class Up2k(object): # refuse out-of-order / multithreaded uploading if sprs False sprs = self.fstab.get(pdir) != "ng" - with self.mutex: + if True: jcur = self.cur.get(ptop) reg = self.registry[ptop] vfs = self.asrv.vfs.all_vols[cj["vtop"]] @@ -2951,7 +2965,7 @@ class Up2k(object): def handle_chunk( self, ptop: str, wark: str, chash: str ) -> tuple[int, list[int], str, float, bool]: - with self.mutex: + with self.mutex, self.reg_mutex: self.db_act = self.vol_act[ptop] = time.time() job = self.registry[ptop].get(wark) if not job: @@ -2994,7 +3008,7 @@ class Up2k(object): return chunksize, ofs, path, job["lmod"], job["sprs"] def release_chunk(self, ptop: str, wark: str, chash: str) -> bool: - with self.mutex: + with self.reg_mutex: job = self.registry[ptop].get(wark) if job: job["busy"].pop(chash, None) @@ -3002,7 +3016,7 @@ class Up2k(object): return True def confirm_chunk(self, ptop: str, wark: str, chash: str) -> tuple[int, str]: - with self.mutex: + with self.mutex, self.reg_mutex: self.db_act = self.vol_act[ptop] = time.time() try: job = self.registry[ptop][wark] @@ -3025,16 +3039,16 @@ class Up2k(object): if self.args.nw: self.regdrop(ptop, wark) - return ret, dst return ret, dst def finish_upload(self, ptop: str, wark: str, busy_aps: set[str]) -> None: self.busy_aps = busy_aps - with self.mutex: + with self.mutex, self.reg_mutex: self._finish_upload(ptop, wark) def _finish_upload(self, ptop: str, wark: str) -> None: + """mutex(main,reg) me""" try: job = self.registry[ptop][wark] pdir = djoin(job["ptop"], job["prel"]) @@ -3107,6 +3121,7 @@ class Up2k(object): cur.connection.commit() def regdrop(self, ptop: str, wark: str) -> None: + """mutex(main,reg) me""" olds = self.droppable[ptop] if wark: olds.append(wark) @@ -3201,6 +3216,7 @@ class Up2k(object): at: float, skip_xau: bool = False, ) -> None: + """mutex(main) me""" self.db_rm(db, rd, fn, sz) sql = "insert into up values (?,?,?,?,?,?,?)" @@ -3314,7 +3330,7 @@ class Up2k(object): vn, rem = self.asrv.vfs.get(vpath, uname, *permsets[0]) vn, rem = vn.get_dbv(rem) ptop = vn.realpath - with self.mutex: + with self.mutex, self.reg_mutex: abrt_cfg = self.flags.get(ptop, {}).get("u2abort", 1) addr = (ip or "\n") if abrt_cfg in (1, 2) else "" user = (uname or "\n") if abrt_cfg in (1, 3) else "" @@ -3418,7 +3434,7 @@ class Up2k(object): continue n_files += 1 - with self.mutex: + with self.mutex, self.reg_mutex: cur = None try: ptop = dbv.realpath @@ -3536,6 +3552,7 @@ class Up2k(object): def _mv_file( self, uname: str, svp: str, dvp: str, curs: set["sqlite3.Cursor"] ) -> str: + """mutex(main) me; will mutex(reg)""" svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) @@ -3616,7 +3633,9 @@ class Up2k(object): if c2 and c2 != c1: self._copy_tags(c1, c2, w) - has_dupes = self._forget_file(svn.realpath, srem, c1, w, is_xvol, fsize) + with self.reg_mutex: + has_dupes = self._forget_file(svn.realpath, srem, c1, w, is_xvol, fsize) + if not is_xvol: has_dupes = self._relink(w, svn.realpath, srem, dabs) @@ -3746,7 +3765,10 @@ class Up2k(object): drop_tags: bool, sz: int, ) -> bool: - """forgets file in db, fixes symlinks, does not delete""" + """ + mutex(main,reg) me + forgets file in db, fixes symlinks, does not delete + """ srd, sfn = vsplit(vrem) has_dupes = False self.log("forgetting {}".format(vrem)) @@ -4071,7 +4093,7 @@ class Up2k(object): self.do_snapshot() def do_snapshot(self) -> None: - with self.mutex: + with self.mutex, self.reg_mutex: for k, reg in self.registry.items(): self._snap_reg(k, reg) @@ -4222,7 +4244,7 @@ class Up2k(object): ) -> bool: ptop, vtop, flags, rd, fn, ip, at, usr, skip_xau = task # self.log("hashq {} pop {}/{}/{}".format(self.n_hashq, ptop, rd, fn)) - with self.mutex: + with self.mutex, self.reg_mutex: if not self.register_vpath(ptop, flags): return True @@ -4240,7 +4262,7 @@ class Up2k(object): wark = up2k_wark_from_hashlist(self.salt, inf.st_size, hashes) - with self.mutex: + with self.mutex, self.reg_mutex: self.idx_wark( self.flags[ptop], rd,