From 4535a81617032294737b1b17f810f794a728d4bc Mon Sep 17 00:00:00 2001 From: ed Date: Mon, 24 Oct 2022 13:44:19 +0200 Subject: [PATCH] smb: add up2k-indexing on write --- copyparty/smbd.py | 49 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/copyparty/smbd.py b/copyparty/smbd.py index 1e07655b..ef805a47 100644 --- a/copyparty/smbd.py +++ b/copyparty/smbd.py @@ -51,6 +51,7 @@ class SMB(object): self.args = hub.args self.asrv = hub.asrv self.log_func = hub.log + self.files: dict[int, tuple[float, str]] = {} handler = HLog(hub.log) lvl = logging.DEBUG if self.args.smb_dbg else logging.INFO @@ -72,6 +73,7 @@ class SMB(object): setattr(fos, k, getattr(os, k)) except: pass + fos.close = self._close fos.listdir = self._listdir fos.open = self._open fos.stat = self._stat @@ -144,16 +146,47 @@ class SMB(object): def _open( self, vpath: str, flags: int, chmod: int = 0o777, *a: Any, **ka: Any ) -> Any: - if not self.args.smbw: - ok = os.O_RDONLY - if ANYWIN: - ok |= os.O_BINARY + f_ro = os.O_RDONLY + if ANYWIN: + f_ro |= os.O_BINARY - if flags != ok: - logging.info("blocked write to %s", vpath) - raise Exception("read-only") + readonly = flags == f_ro - return bos.open(self._v2a("open", vpath, *a)[1], flags, chmod, *a, **ka) + if not self.args.smbw and readonly: + logging.info("blocked write to %s", vpath) + raise Exception("read-only") + + ret = bos.open(self._v2a("open", vpath, *a)[1], flags, chmod, *a, **ka) + if not readonly: + now = time.time() + nf = len(self.files) + if nf > 10: + oldest = min([x[0] for x in self.files.values()]) + cutoff = oldest + (now - oldest) / 2 + self.files = {k: v for k, v in self.files.items() if v[0] > cutoff} + logging.info("was tracking %d files, now %d", nf, len(self.files)) + + self.files[ret] = (now, vpath) + + return ret + + def _close(self, fd: int) -> None: + os.close(fd) + if fd not in self.files: + return + + _, vp = self.files.pop(fd) + vp, fn = os.path.split(vp) + vfs, rem = self.hub.asrv.vfs.get(vp, LEELOO_DALLAS, False, True) + vfs, rem = vfs.get_dbv(rem) + self.hub.up2k.hash_file( + vfs.realpath, + vfs.flags, + rem, + fn, + "1.7.6.2", + time.time(), + ) def _stat(self, vpath: str, *a: Any, **ka: Any) -> os.stat_result: return bos.stat(self._v2a("stat", vpath, *a)[1], *a, **ka)