86 changes: 86 additions & 0 deletions gptcache/manager/vector_data/usearch.py
@@ -0,0 +1,86 @@
import os
from typing import List

import numpy as np

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_usearch

import_usearch()
from usearch.index import Index # pylint: disable=C0413


class USearch(VectorBase):
"""vector store: Usearch

:param index_path: the path to Usearch index, defaults to 'index.usearch'.
:type index_path: str
:param dimension: the dimension of the vector, defaults to 0.
:type dimension: int
:param top_k: the number of the vectors results to return, defaults to 1.
:type top_k: int
:param metric: the distance mrtric. 'l2', 'haversine' or other, default = 'ip'
:type metric: str
:param dtype: the quantization dtype, 'f16' or 'f8' if needed, default = 'f32'
:type dtype: str
:param connectivity: the frequency of the connections in the graph, optional
:type connectivity: int
:param expansion_add: the recall of indexing, optional
:type expansion_add: int
:param expansion_search: the quality of search, optional
:type expansion_search: int
"""

    def __init__(
            self,
            index_file_path: str = 'index.usearch',
            dimension: int = 64,
            top_k: int = 1,
            metric: str = 'cos',
            dtype: str = 'f32',
            connectivity: int = 16,
            expansion_add: int = 128,
            expansion_search: int = 64,
    ):
        self._index_file_path = index_file_path
        self._dimension = dimension
        self._top_k = top_k
        self._index = Index(
            ndim=self._dimension,
            metric=metric,
            dtype=dtype,
            connectivity=connectivity,
            expansion_add=expansion_add,
            expansion_search=expansion_search,
        )
        # reload a previously saved index if one already exists at the given path
        if os.path.isfile(self._index_file_path):
            self._index.load(self._index_file_path)

    def mul_add(self, datas: List[VectorData]):
        # split the incoming VectorData objects into parallel lists of vectors and ids
        data_array, id_array = map(
            list, zip(*((data.data, data.id) for data in datas)))
        np_data = np.array(data_array).astype('float32')
        ids = np.array(id_array, dtype=np.longlong)
        self._index.add(ids, np_data)

    def search(self, data: np.ndarray, top_k: int = -1):
        if top_k == -1:
            top_k = self._top_k
        np_data = np.array(data).astype('float32').reshape(1, -1)
        ids, dist, _ = self._index.search(np_data, top_k)
        # return (distance, id) pairs for the single query vector
        return list(zip(dist[0], ids[0]))

    def rebuild(self, ids=None):
        return True

    def delete(self, ids):
        raise NotImplementedError

    def flush(self):
        # persist the index to disk so it can be reloaded on the next start
        self._index.save(self._index_file_path)

    def close(self):
        self.flush()

    def count(self):
        return len(self._index)
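
The class above can also be exercised on its own, outside the rest of GPTCache. The following is a minimal sketch, not part of the PR: it assumes the usearch package is installed, and the file path and sizes are made up for illustration. search() returns a list of (distance, id) pairs, at most top_k long, and close() persists the index to index_file_path via flush().

import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.usearch import USearch

store = USearch(
    index_file_path='./example.usearch',  # hypothetical path, used only in this sketch
    dimension=64,
    top_k=3,
)

# insert a handful of random vectors with integer ids
store.mul_add([VectorData(id=i, data=np.random.rand(64)) for i in range(100)])

# query with another random vector; the result is a list of (distance, id) pairs
for distance, vec_id in store.search(np.random.rand(64)):
    print(vec_id, distance)

store.close()  # close() flushes, i.e. saves the index to './example.usearch'
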
7 changes: 6 additions & 1 deletion gptcache/utils/__init__.py
@@ -6,6 +6,7 @@
"import_fasttext",
"import_huggingface",
"import_uform",
"import_usearch",
"import_torch",
"import_huggingface_hub",
"import_onnxruntime",
@@ -35,7 +36,7 @@
"softmax",
"import_paddle",
"import_paddlenlp"
]

import importlib.util
from typing import Optional
@@ -81,6 +82,10 @@ def import_uform():
_check_library("uform")


def import_usearch():
_check_library("usearch")


def import_torch():
_check_library("torch")

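
The new import_usearch helper mirrors the other import_* helpers in gptcache.utils: it delegates to the shared _check_library helper to guard the optional usearch dependency before any module that needs it is imported. A minimal sketch of the intended call pattern, taken from the top of the new usearch.py:

from gptcache.utils import import_usearch

import_usearch()  # verify the optional usearch dependency before importing from it
from usearch.index import Index  # pylint: disable=C0413
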
25 changes: 25 additions & 0 deletions tests/unit_tests/manager/test_usearch.py
@@ -0,0 +1,25 @@
import unittest
import numpy as np

from gptcache.manager.vector_data.usearch import USearch
from gptcache.manager.vector_data.base import VectorData


class TestUSearchDB(unittest.TestCase):
    def test_normal(self):
        size = 1000
        dim = 512
        top_k = 10

        db = USearch(
            index_file_path='./index.usearch',
            dimension=dim,
            top_k=top_k,
            metric='cos',
            dtype='f32',
        )
        db.mul_add([VectorData(id=i, data=np.random.rand(dim))
                    for i in range(size)])
        self.assertEqual(len(db.search(np.random.rand(dim))), top_k)
        self.assertEqual(db.count(), size)
        db.close()
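
One behaviour the test above does not cover is persistence: flush()/close() saves the index to index_file_path, and the constructor reloads an existing file from that path. A rough sketch of that round trip, not part of the PR, with a hypothetical path and made-up sizes:

import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.usearch import USearch

path = './persisted.usearch'  # hypothetical path, for illustration only
dim = 64

store = USearch(index_file_path=path, dimension=dim)
store.mul_add([VectorData(id=i, data=np.random.rand(dim)) for i in range(10)])
store.close()  # close() calls flush(), which saves the index to path

# a new instance pointed at the same file reloads the saved index in __init__
reloaded = USearch(index_file_path=path, dimension=dim)
assert reloaded.count() == 10
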