diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py index c7114bac..86a42206 100644 --- a/copyparty/authsrv.py +++ b/copyparty/authsrv.py @@ -913,7 +913,7 @@ class AuthSrv(object): self._reload() return True - broker.ask("_reload_blocking", False).get() + broker.ask("reload", False, True).get() return True def _map_volume_idp( @@ -2389,7 +2389,7 @@ class AuthSrv(object): self._reload() return True, "new password OK" - broker.ask("_reload_blocking", False, False).get() + broker.ask("reload", False, False).get() return True, "new password OK" def setup_chpw(self, acct: dict[str, str]) -> None: diff --git a/copyparty/httpcli.py b/copyparty/httpcli.py index 808cb008..075dba1b 100644 --- a/copyparty/httpcli.py +++ b/copyparty/httpcli.py @@ -4591,7 +4591,7 @@ class HttpCli(object): if self.args.no_reload: raise Pebkac(403, "the reload feature is disabled in server config") - x = self.conn.hsrv.broker.ask("reload") + x = self.conn.hsrv.broker.ask("reload", True, True) return self.redirect("", "?h", x.get(), "return to", False) def tx_stack(self) -> bool: @@ -4879,7 +4879,7 @@ class HttpCli(object): cur.connection.commit() if reload: - self.conn.hsrv.broker.ask("_reload_blocking", False, False).get() + self.conn.hsrv.broker.ask("reload", False, False).get() self.conn.hsrv.broker.ask("up2k.wake_rescanner").get() self.redirect(self.args.SRS + "?shares") @@ -4970,7 +4970,7 @@ class HttpCli(object): cur.execute(q, (skey, fn)) cur.connection.commit() - self.conn.hsrv.broker.ask("_reload_blocking", False, False).get() + self.conn.hsrv.broker.ask("reload", False, False).get() self.conn.hsrv.broker.ask("up2k.wake_rescanner").get() fn = quotep(fns[0]) if len(fns) == 1 else "" diff --git a/copyparty/svchub.py b/copyparty/svchub.py index 4ad5d6e5..cf5db4a1 100644 --- a/copyparty/svchub.py +++ b/copyparty/svchub.py @@ -112,7 +112,7 @@ class SvcHub(object): self.stopping = False self.stopped = False self.reload_req = False - self.reloading = 0 + self.reload_mutex = threading.Lock() self.stop_cond = threading.Condition() self.nsigs = 3 self.retcode = 0 @@ -1004,20 +1004,8 @@ class SvcHub(object): except: self.log("root", "ssdp startup failed;\n" + min_ex(), 3) - def reload(self) -> str: - with self.up2k.mutex: - if self.reloading: - return "cannot reload; already in progress" - self.reloading = 1 - - Daemon(self._reload, "reloading") - return "reload initiated" - - def _reload(self, rescan_all_vols: bool = True, up2k: bool = True) -> None: - with self.up2k.mutex: - if self.reloading != 1: - return - self.reloading = 2 + def reload(self, rescan_all_vols: bool, up2k: bool) -> None: + with self.reload_mutex: self.log("root", "reloading config") self.asrv.reload(9 if up2k else 4) if up2k: @@ -1025,20 +1013,6 @@ class SvcHub(object): else: self.log("root", "reload done") self.broker.reload() - self.reloading = 0 - - def _reload_blocking(self, rescan_all_vols: bool = True, up2k: bool = True) -> None: - while True: - with self.up2k.mutex: - if self.reloading < 2: - self.reloading = 1 - break - time.sleep(0.05) - - # try to handle multiple pending IdP reloads at once: - time.sleep(0.2) - - self._reload(rescan_all_vols=rescan_all_vols, up2k=up2k) def _reload_sessions(self) -> None: with self.asrv.mutex: @@ -1052,7 +1026,7 @@ class SvcHub(object): if self.reload_req: self.reload_req = False - self.reload() + self.reload(True, True) self.shutdown() diff --git a/copyparty/up2k.py b/copyparty/up2k.py index ced5062d..28ed66f2 100644 --- a/copyparty/up2k.py +++ b/copyparty/up2k.py @@ -89,6 +89,8 @@ zsg = "avif,avifs,bmp,gif,heic,heics,heif,heifs,ico,j2p,j2k,jp2,jpeg,jpg,jpx,png CV_EXTS = set(zsg.split(",")) +SBUSY = "cannot receive uploads right now;\nserver busy with %s.\nPlease wait; the client will retry..." + HINT_HISTPATH = "you could try moving the database to another location (preferably an SSD or NVME drive) using either the --hist argument (global option for all volumes), or the hist volflag (just for this volume)" @@ -125,12 +127,22 @@ class Up2k(object): self.args = hub.args self.log_func = hub.log + self.vfs = self.asrv.vfs + self.acct = self.asrv.acct + self.iacct = self.asrv.iacct + self.grps = self.asrv.grps + self.salt = self.args.warksalt self.r_hash = re.compile("^[0-9a-zA-Z_-]{44}$") self.gid = 0 + self.gt0 = 0 + self.gt1 = 0 self.stop = False self.mutex = threading.Lock() + self.reload_mutex = threading.Lock() + self.reload_flag = 0 + self.reloading = False self.blocked: Optional[str] = None self.pp: Optional[ProgressPrinter] = None self.rescan_cond = threading.Condition() @@ -203,7 +215,38 @@ class Up2k(object): Daemon(self.deferred_init, "up2k-deferred-init") + def unpp(self) -> None: + self.gt1 = time.time() + if self.pp: + self.pp.end = True + self.pp = None + def reload(self, rescan_all_vols: bool) -> None: + n = 2 if rescan_all_vols else 1 + with self.reload_mutex: + if self.reload_flag < n: + self.reload_flag = n + with self.rescan_cond: + self.rescan_cond.notify_all() + + def _reload_thr(self) -> None: + while self.pp: + time.sleep(0.1) + while True: + with self.reload_mutex: + if not self.reload_flag: + break + rav = self.reload_flag == 2 + self.reload_flag = 0 + gt1 = self.gt1 + with self.mutex: + self._reload(rav) + while gt1 == self.gt1 or self.pp: + time.sleep(0.1) + + self.reloading = False + + def _reload(self, rescan_all_vols: bool) -> None: """mutex(main) me""" self.log("reload #{} scheduled".format(self.gid + 1)) all_vols = self.asrv.vfs.all_vols @@ -228,10 +271,7 @@ class Up2k(object): with self.mutex, self.reg_mutex: self._drop_caches() - if self.pp: - self.pp.end = True - self.pp = None - + self.unpp() return if not self.pp and self.args.exit == "idx": @@ -311,8 +351,8 @@ class Up2k(object): def _active_uploads(self, uname: str) -> list[tuple[float, int, int, str]]: ret = [] - for vtop in self.asrv.vfs.aread[uname]: - vfs = self.asrv.vfs.all_vols.get(vtop) + for vtop in self.vfs.aread.get(uname) or []: + vfs = self.vfs.all_vols.get(vtop) if not vfs: # dbv only continue ptop = vfs.realpath @@ -485,6 +525,12 @@ class Up2k(object): if self.stop: return + with self.reload_mutex: + if self.reload_flag and not self.reloading: + self.reloading = True + zs = "up2k-reload-%d" % (self.gid,) + Daemon(self._reload_thr, zs) + now = time.time() if now < cooldown: # self.log("SR: cd - now = {:.2f}".format(cooldown - now), 5) @@ -521,7 +567,7 @@ class Up2k(object): raise with self.mutex: - for vp, vol in sorted(self.asrv.vfs.all_vols.items()): + for vp, vol in sorted(self.vfs.all_vols.items()): maxage = vol.flags.get("scan") if not maxage: continue @@ -554,7 +600,7 @@ class Up2k(object): if vols: cooldown = now + 10 - err = self.rescan(self.asrv.vfs.all_vols, vols, False, False) + err = self.rescan(self.vfs.all_vols, vols, False, False) if err: for v in vols: self.need_rescan.add(v) @@ -567,7 +613,7 @@ class Up2k(object): def _check_lifetimes(self) -> float: now = time.time() timeout = now + 9001 - for vp, vol in sorted(self.asrv.vfs.all_vols.items()): + for vp, vol in sorted(self.vfs.all_vols.items()): lifetime = vol.flags.get("lifetime") if not lifetime: continue @@ -621,7 +667,7 @@ class Up2k(object): maxage = self.args.shr_rt * 60 low = now - maxage - vn = self.asrv.vfs.nodes.get(self.args.shr.strip("/")) + vn = self.vfs.nodes.get(self.args.shr.strip("/")) active = vn and vn.nodes db = sqlite3.connect(self.args.shr_db, timeout=2) @@ -646,7 +692,7 @@ class Up2k(object): db.commit() if reload: - Daemon(self.hub._reload_blocking, "sharedrop", (False, False)) + Daemon(self.hub.reload, "sharedrop", (False, False)) q = "select min(t1) from sh where t1 > ?" (earliest,) = cur.execute(q, (1,)).fetchone() @@ -672,7 +718,7 @@ class Up2k(object): return 2 ret = 9001 - for _, vol in sorted(self.asrv.vfs.all_vols.items()): + for _, vol in sorted(self.vfs.all_vols.items()): rp = vol.realpath cur = self.cur.get(rp) if not cur: @@ -774,6 +820,8 @@ class Up2k(object): with self.mutex: gid = self.gid + self.gt0 = time.time() + nspin = 0 while True: nspin += 1 @@ -796,6 +844,11 @@ class Up2k(object): if gid: self.log("reload #%d running" % (gid,)) + self.vfs = self.asrv.vfs + self.acct = self.asrv.acct + self.iacct = self.asrv.iacct + self.grps = self.asrv.grps + vols = list(all_vols.values()) t0 = time.time() have_e2d = False @@ -859,7 +912,7 @@ class Up2k(object): self._drop_caches() for vol in vols: - if self.stop: + if self.stop or gid != self.gid: break en = set(vol.flags.get("mte", {})) @@ -990,7 +1043,7 @@ class Up2k(object): if self.mtag: Daemon(self._run_all_mtp, "up2k-mtp-scan", (gid,)) else: - self.pp = None + self.unpp() return have_e2d @@ -998,7 +1051,7 @@ class Up2k(object): self, ptop: str, flags: dict[str, Any] ) -> Optional[tuple["sqlite3.Cursor", str]]: """mutex(main,reg) me""" - histpath = self.asrv.vfs.histtab.get(ptop) + histpath = self.vfs.histtab.get(ptop) if not histpath: self.log("no histpath for [{}]".format(ptop)) return None @@ -1011,7 +1064,7 @@ class Up2k(object): return None vpath = "?" - for k, v in self.asrv.vfs.all_vols.items(): + for k, v in self.vfs.all_vols.items(): if v.realpath == ptop: vpath = k @@ -1178,7 +1231,7 @@ class Up2k(object): def _verify_db_cache(self, cur: "sqlite3.Cursor", vpath: str) -> None: # check if list of intersecting volumes changed since last use; drop caches if so prefix = (vpath + "/").lstrip("/") - zsl = [x for x in self.asrv.vfs.all_vols if x.startswith(prefix)] + zsl = [x for x in self.vfs.all_vols if x.startswith(prefix)] zsl = [x[len(prefix) :] for x in zsl] zsl.sort() zb = hashlib.sha1("\n".join(zsl).encode("utf-8", "replace")).digest() @@ -1223,7 +1276,7 @@ class Up2k(object): if d != vol and (d.vpath.startswith(vol.vpath + "/") or not vol.vpath) ] excl += [absreal(x) for x in excl] - excl += list(self.asrv.vfs.histtab.values()) + excl += list(self.vfs.histtab.values()) if WINDOWS: excl = [x.replace("/", "\\") for x in excl] else: @@ -1733,7 +1786,7 @@ class Up2k(object): excl = [ d[len(vol.vpath) :].lstrip("/") - for d in self.asrv.vfs.all_vols + for d in self.vfs.all_vols if d != vol.vpath and (d.startswith(vol.vpath + "/") or not vol.vpath) ] qexa: list[str] = [] @@ -1885,7 +1938,7 @@ class Up2k(object): def _drop_caches(self) -> None: """mutex(main,reg) me""" self.log("dropping caches for a full filesystem scan") - for vol in self.asrv.vfs.all_vols.values(): + for vol in self.vfs.all_vols.values(): reg = self.register_vpath(vol.realpath, vol.flags) if not reg: continue @@ -2113,7 +2166,7 @@ class Up2k(object): self._run_one_mtp(ptop, gid) vtop = "\n" - for vol in self.asrv.vfs.all_vols.values(): + for vol in self.vfs.all_vols.values(): if vol.realpath == ptop: vtop = vol.vpath if "running mtp" in self.volstate.get(vtop, ""): @@ -2123,7 +2176,7 @@ class Up2k(object): msg = "mtp finished in {:.2f} sec ({})" self.log(msg.format(td, s2hms(td, True))) - self.pp = None + self.unpp() if self.args.exit == "idx": self.hub.sigterm() @@ -2765,6 +2818,9 @@ class Up2k(object): ) -> dict[str, Any]: # busy_aps is u2fh (always undefined if -j0) so this is safe self.busy_aps = busy_aps + if self.reload_flag or self.reloading: + raise Pebkac(503, SBUSY % ("fs-reload",)) + got_lock = False try: # bit expensive; 3.9=10x 3.11=2x @@ -2773,8 +2829,7 @@ class Up2k(object): with self.reg_mutex: ret = self._handle_json(cj) else: - t = "cannot receive uploads right now;\nserver busy with {}.\nPlease wait; the client will retry..." - raise Pebkac(503, t.format(self.blocked or "[unknown]")) + raise Pebkac(503, SBUSY % (self.blocked or "[unknown]",)) except TypeError: if not PY2: raise @@ -2816,7 +2871,7 @@ class Up2k(object): if True: jcur = self.cur.get(ptop) reg = self.registry[ptop] - vfs = self.asrv.vfs.all_vols[cj["vtop"]] + vfs = self.vfs.all_vols[cj["vtop"]] n4g = bool(vfs.flags.get("noforget")) noclone = bool(vfs.flags.get("noclone")) rand = vfs.flags.get("rand") or cj.get("rand") @@ -2840,7 +2895,7 @@ class Up2k(object): alts: list[tuple[int, int, dict[str, Any], "sqlite3.Cursor", str, str]] = [] for ptop, cur in vols: - allv = self.asrv.vfs.all_vols + allv = self.vfs.all_vols cvfs = next((v for v in allv.values() if v.realpath == ptop), vfs) vtop = cj["vtop"] if cur == jcur else cvfs.vpath @@ -3083,7 +3138,7 @@ class Up2k(object): vp, job["host"], job["user"], - self.asrv.vfs.get_perms(job["vtop"], job["user"]), + self.vfs.get_perms(job["vtop"], job["user"]), job["lmod"], job["size"], job["addr"], @@ -3095,7 +3150,7 @@ class Up2k(object): self.log(t, 1) raise Pebkac(403, t) if hr.get("reloc"): - x = pathmod(self.asrv.vfs, dst, vp, hr["reloc"]) + x = pathmod(self.vfs, dst, vp, hr["reloc"]) if x: zvfs = vfs pdir, _, job["name"], (vfs, rem) = x @@ -3555,7 +3610,7 @@ class Up2k(object): wake_sr = False try: flt = job["life"] - vfs = self.asrv.vfs.all_vols[job["vtop"]] + vfs = self.vfs.all_vols[job["vtop"]] vlt = vfs.flags["lifetime"] if vlt and flt > 1 and flt < vlt: upt -= vlt - flt @@ -3731,7 +3786,7 @@ class Up2k(object): djoin(vtop, rd, fn), host, usr, - self.asrv.vfs.get_perms(djoin(vtop, rd, fn), usr), + self.vfs.get_perms(djoin(vtop, rd, fn), usr), ts, sz, ip, @@ -3841,12 +3896,12 @@ class Up2k(object): partial = "" if not unpost: permsets = [[True, False, False, True]] - vn, rem = self.asrv.vfs.get(vpath, uname, *permsets[0]) + vn, rem = self.vfs.get(vpath, uname, *permsets[0]) vn, rem = vn.get_dbv(rem) else: # unpost with missing permissions? verify with db permsets = [[False, True]] - vn, rem = self.asrv.vfs.get(vpath, uname, *permsets[0]) + vn, rem = self.vfs.get(vpath, uname, *permsets[0]) vn, rem = vn.get_dbv(rem) ptop = vn.realpath with self.mutex, self.reg_mutex: @@ -3951,7 +4006,7 @@ class Up2k(object): vpath, "", uname, - self.asrv.vfs.get_perms(vpath, uname), + self.vfs.get_perms(vpath, uname), stl.st_mtime, st.st_size, ip, @@ -3991,7 +4046,7 @@ class Up2k(object): vpath, "", uname, - self.asrv.vfs.get_perms(vpath, uname), + self.vfs.get_perms(vpath, uname), stl.st_mtime, st.st_size, ip, @@ -4015,7 +4070,7 @@ class Up2k(object): if svp == dvp or dvp.startswith(svp + "/"): raise Pebkac(400, "mv: cannot move parent into subfolder") - svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) + svn, srem = self.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) sabs = svn.canonical(srem, False) curs: set["sqlite3.Cursor"] = set() @@ -4082,7 +4137,7 @@ class Up2k(object): rem = ap[len(sabs) :].replace(os.sep, "/").lstrip("/") vp = vjoin(dvp, rem) try: - dvn, drem = self.asrv.vfs.get(vp, uname, False, True) + dvn, drem = self.vfs.get(vp, uname, False, True) bos.mkdir(dvn.canonical(drem)) except: pass @@ -4093,10 +4148,10 @@ class Up2k(object): self, uname: str, ip: str, svp: str, dvp: str, curs: set["sqlite3.Cursor"] ) -> str: """mutex(main) me; will mutex(reg)""" - svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) + svn, srem = self.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) - dvn, drem = self.asrv.vfs.get(dvp, uname, False, True) + dvn, drem = self.vfs.get(dvp, uname, False, True) dvn, drem = dvn.get_dbv(drem) sabs = svn.canonical(srem, False) @@ -4140,7 +4195,7 @@ class Up2k(object): svp, "", uname, - self.asrv.vfs.get_perms(svp, uname), + self.vfs.get_perms(svp, uname), ftime, fsize, ip, @@ -4180,7 +4235,7 @@ class Up2k(object): dvp, "", uname, - self.asrv.vfs.get_perms(dvp, uname), + self.vfs.get_perms(dvp, uname), ftime, fsize, ip, @@ -4293,7 +4348,7 @@ class Up2k(object): dvp, "", uname, - self.asrv.vfs.get_perms(dvp, uname), + self.vfs.get_perms(dvp, uname), ftime, fsize, ip, @@ -4606,7 +4661,7 @@ class Up2k(object): vp_chk, job["host"], job["user"], - self.asrv.vfs.get_perms(vp_chk, job["user"]), + self.vfs.get_perms(vp_chk, job["user"]), job["lmod"], job["size"], job["addr"], @@ -4618,7 +4673,7 @@ class Up2k(object): self.log(t, 1) raise Pebkac(403, t) if hr.get("reloc"): - x = pathmod(self.asrv.vfs, ap_chk, vp_chk, hr["reloc"]) + x = pathmod(self.vfs, ap_chk, vp_chk, hr["reloc"]) if x: zvfs = vfs pdir, _, job["name"], (vfs, rem) = x @@ -4725,7 +4780,7 @@ class Up2k(object): def _snap_reg(self, ptop: str, reg: dict[str, dict[str, Any]]) -> None: now = time.time() - histpath = self.asrv.vfs.histtab.get(ptop) + histpath = self.vfs.histtab.get(ptop) if not histpath: return @@ -4973,7 +5028,7 @@ class Up2k(object): else: fvp, fn = vsplit(fvp) - x = pathmod(self.asrv.vfs, "", req_vp, {"vp": fvp, "fn": fn}) + x = pathmod(self.vfs, "", req_vp, {"vp": fvp, "fn": fn}) if not x: t = "hook_fx(%s): failed to resolve %s based on %s" self.log(t % (act, fvp, req_vp))