diff --git a/copyparty/u2idx.py b/copyparty/u2idx.py index 45b479c5..5eb13adc 100644 --- a/copyparty/u2idx.py +++ b/copyparty/u2idx.py @@ -53,6 +53,11 @@ class U2idx(object): self.log("your python does not have sqlite3; searching will be disabled") return + if self.args.srch_icase: + self._open_db = self._open_db_icase + else: + self._open_db = self._open_db_std + assert sqlite3 # type: ignore # !rm self.active_id = "" @@ -69,6 +74,16 @@ class U2idx(object): def log(self, msg: str, c: Union[int, str] = 0) -> None: self.log_func("u2idx", msg, c) + def _open_db_std(self, *args, **kwargs): + assert sqlite3 # type: ignore # !rm + kwargs["check_same_thread"] = False + return sqlite3.connect(*args, **kwargs) + + def _open_db_icase(self, *args, **kwargs): + db = self._open_db_std(*args, **kwargs) + db.create_function("casefold", 1, lambda x: x.casefold() if x else x) + return db + def shutdown(self) -> None: if not HAVE_SQLITE3: return @@ -148,8 +163,7 @@ class U2idx(object): uri = "" try: uri = "{}?mode=ro&nolock=1".format(Path(db_path).as_uri()) - db = sqlite3.connect(uri, timeout=2, uri=True, check_same_thread=False) - cur = db.cursor() + cur = self._open_db(uri, timeout=2, uri=True).cursor() cur.execute('pragma table_info("up")').fetchone() self.log("ro: %r" % (db_path,)) except: @@ -160,7 +174,7 @@ class U2idx(object): if not cur: # on windows, this steals the write-lock from up2k.deferred_init -- # seen on win 10.0.17763.2686, py 3.10.4, sqlite 3.37.2 - cur = sqlite3.connect(db_path, timeout=2, check_same_thread=False).cursor() + cur = self._open_db(db_path, timeout=2).cursor() self.log("opened %r" % (db_path,)) self.cur[ptop] = cur @@ -173,6 +187,8 @@ class U2idx(object): if not HAVE_SQLITE3: return [], [], False + icase = self.args.srch_icase + q = "" v: Union[str, int] = "" va: list[Union[str, int]] = [] @@ -232,9 +248,13 @@ class U2idx(object): elif v == "path": v = "trim(?||up.rd,'/')" va.append("\nrd") + if icase: + v = "casefold(%s)" % (v,) elif v == "name": v = "up.fn" + if icase: + v = "casefold(%s)" % (v,) elif v == "tags" or ptn_mt.match(v): have_mt = True @@ -285,6 +305,12 @@ class U2idx(object): tail = "||'%'" v = v[:-1] + if icase and "casefold(" in q: + try: + v = unicode(v).casefold() + except: + v = unicode(v).lower() + q += " {}?{} ".format(head, tail) va.append(v) is_key = True