Source code for cyto_ml.data.vectorstore

import logging
import os
import sqlite3
import struct
from abc import ABCMeta, abstractmethod
from typing import List, Optional

import chromadb
import chromadb.api.models.Collection
import numpy as np
import sqlite_vec
from chromadb.config import Settings
from chromadb.errors import UniqueConstraintError

from cyto_ml.data.db_config import SQLITE_SCHEMA

logging.basicConfig(level=logging.INFO)
# TODO make this sensibly configurable, not confusingly hardcoded
STORE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../../vectors")


[docs] def serialize_f32(vector: List[float]) -> bytes: """serializes a list of floats into a compact "raw bytes" format https://github.com/asg017/sqlite-vec/blob/main/examples/simple-python/demo.py """ return struct.pack("%sf" % len(vector), *vector)
[docs] def deserialize(packed: bytes) -> List[float]: """Inverse of the serialisation method suggested above (e.g. for clustering)""" size = int(len(packed) / 4) return struct.unpack("%sf" % size, packed)
[docs] class VectorStore(metaclass=ABCMeta):
[docs] @abstractmethod def add(self, url: str, embeddings: List[float]) -> None: pass
[docs] @abstractmethod def get(self, url: str) -> List[float]: pass
[docs] @abstractmethod def closest(self, embeddings: List) -> List[float]: pass
[docs] @abstractmethod def embeddings(self) -> List[List]: pass
[docs] @abstractmethod def ids(self) -> List[str]: pass
[docs] class ChromadbStore(VectorStore): client = chromadb.PersistentClient( path=STORE, settings=Settings( anonymized_telemetry=False, ), ) def __init__(self, db_name: str): try: collection = self.client.create_collection( name=db_name, metadata={"hnsw:space": "cosine"}, # default similarity ) except UniqueConstraintError as err: collection = self.client.get_collection(db_name) logging.info(err) self.store = collection
[docs] def add(self, url: str, embeddings: List[float]) -> None: """Add vector to Chromadb""" self.store.add( documents=[url], # we use image location in s3 rather than text content embeddings=[embeddings], # wants a list of lists ids=[url], # wants a list of ids )
[docs] def get(self, url: str) -> list: """Retrieve vector from Chromadb""" record = self.store.get([url], include=["embeddings"]) return record["embeddings"][0]
[docs] def closest(self, url: str, n_results: int = 25) -> List: """Get the N closest identifiers by cosine distance""" embeddings = self.get(url) results = self.store.query(query_embeddings=[embeddings], n_results=n_results) return results["ids"][0] # by index because API assumes query always multiple inputs
[docs] def embeddings(self) -> List[List]: result = self.store.get(include=["embeddings"]) return np.array(result["embeddings"])
[docs] def ids(self) -> List[str]: return self.store.get().get("ids", [])
[docs] class PostgresStore(VectorStore): def __init__(self, db_name: str): self.db_name = db_name
[docs] def add(self, url: str, embeddings: List[float]) -> None: # Implementation for adding vector to Postgres pass
[docs] def get(self, url: str) -> List[float]: # Implementation for retrieving vector from Postgres pass
[docs] def closest(self, embeddings: list, n_results: int = 25) -> List: pass
[docs] def embeddings(self) -> List[List]: pass
[docs] def ids(self) -> List[str]: pass
[docs] class SQLiteVecStore(VectorStore): def __init__(self, db_name: str, embedding_len: Optional[int] = 512, check_same_thread: bool = True): self._check_same_thread = check_same_thread self.embedding_len = embedding_len self.load_ext(db_name) self.load_schema()
[docs] def load_ext(self, db_name: str) -> None: """Load the sqlite extension into our db if needed""" # db_name could be ':memory:' for testing, or a path db = sqlite3.connect(db_name, check_same_thread=self._check_same_thread) db.enable_load_extension(True) sqlite_vec.load(db) db.enable_load_extension(False) self.db = db
[docs] def load_schema(self) -> None: """Load our db schema if needed; Default embedding length is 2048, set at init. Consider SQLAlchemy for this, or a CLI-based way of loading from a file; a list of CREATE TABLE statements feels like a kludge. """ for statement in SQLITE_SCHEMA: query = statement.format(self.embedding_len) try: self.db.execute(query) except sqlite3.OperationalError as err: if "already exists" in str(err): pass else: raise
[docs] def add(self, url: str, embeddings: List[float], classification: Optional[str] = "") -> None: """Add image embeddings to storage. Two tables: * one regular one which holds metadata, with embeddings as floats * one "virtual table" for indexing it by ID with encoded embeddings""" cursor = self.db.cursor() cursor.execute( "INSERT INTO images(url, embedding, classification) VALUES (?, ?, ?)", [url, serialize_f32(embeddings), classification], ) row_id = cursor.lastrowid cursor.execute( "INSERT INTO images_vec(id, embedding) VALUES (?, ?)", [row_id, serialize_f32(embeddings)], ) self.db.commit()
[docs] def get(self, url: str) -> List[float]: result = self.db.execute("select embedding from images where url = ?", [url]).fetchone() if len(result): return result[0] else: return None
[docs] def closest(self, url: str, n_results: int = 25) -> List: """Find and return the N closest examples by distance Accepts an image URL, returns a list ordered by distance """ # See https://til.simonwillison.net/sqlite/sqlite-vec # https://github.com/asg017/sqlite-vec/issues/41 - "limit ?" not guaranteed # Note - stopped returning distance for consistency, but might be useful try: doc_id = self.db.execute("select id from images where url = ?", [url]).fetchone()[0] except IndexError: return None query = """ with image_embedding as ( select embedding as first_embedding from images_vec where id = ? ) select images.url, vec_distance_cosine(images_vec.embedding, first_embedding) as distance from images_vec, image_embedding, images where images_vec.id = images.id order by distance limit ?""" results = self.db.execute(query, [doc_id, n_results]).fetchall() return results # [i for j in results for i in j]
[docs] def labelled(self, label: str, n_results: int = 50) -> List[str]: labelled = self.db.execute( """select url from images where classification = ? limit ?""", (label, n_results) ).fetchall() return [i for j in labelled for i in j]
[docs] def classes(self) -> List[str]: classes = self.db.execute("""select distinct classification from images""").fetchall() return [i for j in classes for i in j]
[docs] def embeddings(self) -> List[List]: embeddings = self.db.execute("""select embedding from images""").fetchall() return [deserialize(i) for j in embeddings for i in j]
[docs] def ids(self) -> List[str]: urls = self.db.execute("""select url from images""").fetchall() return [i for j in urls for i in j]
[docs] def vector_store( store_type: Optional[str] = "chromadb", db_name: Optional[str] = "test_collection", **kwargs ) -> VectorStore: if store_type == "chromadb": return ChromadbStore(db_name, **kwargs) elif store_type == "postgres": return PostgresStore(db_name, **kwargs) elif store_type == "sqlite": return SQLiteVecStore(db_name, **kwargs) else: raise ValueError(f"Unknown store type: {store_type}")