diff --git a/copyparty/up2k.py b/copyparty/up2k.py index 4fc6755e..40725cf3 100644 --- a/copyparty/up2k.py +++ b/copyparty/up2k.py @@ -2546,6 +2546,7 @@ class Up2k(object): svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) sabs = svn.canonical(srem, False) + curs: set["sqlite3.Cursor"] = set() if not srem: raise Pebkac(400, "mv: cannot move a mountpoint") @@ -2553,7 +2554,13 @@ class Up2k(object): st = bos.lstat(sabs) if stat.S_ISREG(st.st_mode) or stat.S_ISLNK(st.st_mode): with self.mutex: - return self._mv_file(uname, svp, dvp) + try: + ret = self._mv_file(uname, svp, dvp, curs) + finally: + for v in curs: + v.connection.commit() + + return ret jail = svn.get_dbv(srem)[0] permsets = [[True, False, True]] @@ -2572,20 +2579,29 @@ class Up2k(object): # the actual check (avoid toctou) raise Pebkac(400, "mv: source folder contains other volumes") - for fn in files: - svpf = "/".join(x for x in [dbv.vpath, vrem, fn[0]] if x) - if not svpf.startswith(svp + "/"): # assert - raise Pebkac(500, "mv: bug at {}, top {}".format(svpf, svp)) + with self.mutex: + try: + for fn in files: + self.db_act = time.time() + svpf = "/".join(x for x in [dbv.vpath, vrem, fn[0]] if x) + if not svpf.startswith(svp + "/"): # assert + raise Pebkac(500, "mv: bug at {}, top {}".format(svpf, svp)) - dvpf = dvp + svpf[len(svp) :] - with self.mutex: - self._mv_file(uname, svpf, dvpf) + dvpf = dvp + svpf[len(svp) :] + self._mv_file(uname, svpf, dvpf, curs) + finally: + for v in curs: + v.connection.commit() + + curs.clear() rmdirs(self.log_func, scandir, True, sabs, 1) rmdirs_up(os.path.dirname(sabs)) return "k" - def _mv_file(self, uname: str, svp: str, dvp: str) -> str: + def _mv_file( + self, uname: str, svp: str, dvp: str, curs: set["sqlite3.Cursor"] + ) -> str: svn, srem = self.asrv.vfs.get(svp, uname, True, False, True) svn, srem = svn.get_dbv(srem) @@ -2643,11 +2659,11 @@ class Up2k(object): self._forget_file(svn.realpath, srem, c1, w, c1 != c2) self._relink(w, svn.realpath, srem, dabs) - c1.connection.commit() + curs.add(c1) if c2: self.db_add(c2, w, drd, dfn, ftime, fsize, ip or "", at or 0) - c2.connection.commit() + curs.add(c2) else: self.log("not found in src db: [{}]".format(svp))