diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d442a6a8..d113b1f72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Rename `_Predictor` to `Model` - Allow developers to write `Listeners` and `Graph` in a single formalism - Change unittesting framework to pure configuration (no patching configs) +- Adding `PostgresDataBackend` for `Pgvector` integration #### Bug Fixes - Fixed a bug in refresh_after_insert for listeners with select None diff --git a/examples/pgvector.ipynb b/examples/pgvector.ipynb new file mode 100644 index 000000000..0de6762b2 --- /dev/null +++ b/examples/pgvector.ipynb @@ -0,0 +1,591 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "be320b36", + "metadata": {}, + "source": [ + "# Postgres + Pgvector (HNSW and IVFflat Indexing)" + ] + }, + { + "cell_type": "markdown", + "id": "3d6841b7", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "Before starting the implementation, make sure you have the required libraries installed by running the following commands:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6d47e1a", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# !pip install superduperdb\n", + "# !pip install vllm\n", + "# !pip install sentence_transformers numpy==1.24.4\n", + "# !pip install 'ibis-framework[postgres]'\n", + "# !pip install pgvector\n", + "# !pip install psycopg2 " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61b8392b", + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf .superduperdb/ && mkdir -p .superduperdb" + ] + }, + { + "cell_type": "markdown", + "id": "a44fba27", + "metadata": {}, + "source": [ + "## Connect to datastore \n", + "\n", + "First, we need to establish a connection to a Postgres datastore via SuperDuperDB. You can configure the `Postgres_URI` based on your specific setup. 
\n", + "Here are some examples of postgres URIs:\n", + "\n", + "* For testing (default connection): `postgres://test`\n", + "* Local postgres instance: `postgres://localhost:27017`\n", + "* postgres with authentication: `postgres://superduper:superduper@postgres:27017/documents`\n", + "* postgres Atlas: `postgres+srv://:@/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e7a535", + "metadata": {}, + "outputs": [], + "source": [ + "from superduperdb.base.config import VectorSearch, Compute" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f9fb3c1", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from superduperdb import superduper\n", + "from superduperdb.backends.ibis import Table\n", + "import os\n", + "from superduperdb.backends.ibis.field_types import dtype\n", + "from superduperdb.ext.pillow import pil_image\n", + "from superduperdb import Schema\n", + "\n", + "connection_uri = \"postgresql://postgres:test@localhost:8000/qa\"\n", + "\n", + "\n", + "# It just super dupers your database\n", + "db = superduper(\n", + " connection_uri,\n", + " metadata_store='sqlite:///.superduperdb/metadata.sqlite',\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6f93d7d", + "metadata": {}, + "outputs": [], + "source": [ + "!python -m superduperdb info" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f4c6a9d", + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "import re\n", + "\n", + "ROOT = '../docs/hr/content/docs/'\n", + "\n", + "STRIDE = 3 # stride in numbers of lines\n", + "WINDOW = 25 # length of window in numbers of lines\n", + "\n", + "files = sorted(glob.glob(f'{ROOT}/**/*.md', recursive=True))\n", + "\n", + "def get_chunk_link(chunk, file_name):\n", + " # Get the original link of the chunk\n", + " file_link = file_name[:-3].replace(ROOT, 'https://docs.superduperdb.com/docs/docs/')\n", + " # If the chunk has subtitles, the link to the first subtitle will be used first.\n", + " first_title = (re.findall(r'(^|\\n)## (.*?)\\n', chunk) or [(None, None)])[0][1]\n", + " if first_title:\n", + " # Convert subtitles and splice URLs\n", + " first_title = first_title.lower()\n", + " first_title = re.sub(r'[^a-zA-Z0-9]', '-', first_title)\n", + " file_link = file_link + '#' + first_title\n", + " return file_link\n", + "\n", + "def create_chunk_and_links(file, file_prefix=ROOT):\n", + " with open(file, 'r') as f:\n", + " lines = f.readlines()\n", + " if len(lines) > WINDOW:\n", + " chunks = ['\\n'.join(lines[i: i + WINDOW]) for i in range(0, len(lines), STRIDE)]\n", + " else:\n", + " chunks = ['\\n'.join(lines)]\n", + " return [{'txt': chunk, 'link': get_chunk_link(chunk, file)} for chunk in chunks]\n", + "\n", + "\n", + "all_chunks_and_links = sum([create_chunk_and_links(file) for file in files], [])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72785e18", + "metadata": {}, + "outputs": [], + "source": [ + "# Use !curl to download the 'superduperdb_docs.json' file\n", + "!curl -O https://datas-public.s3.amazonaws.com/superduperdb_docs.json\n", + "\n", + "import json\n", + "from IPython.display import Markdown\n", + "\n", + "# Open the downloaded JSON file and load its contents into the 'chunks' variable\n", + "with open('superduperdb_docs.json') as f:\n", + " all_chunks_and_links = json.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afba211d", + "metadata": {}, + "outputs": [], + "source": [ + 
"all_chunks_and_links[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cd41263", + "metadata": {}, + "outputs": [], + "source": [ + "new_all_chunks_and_links = list()\n", + "for i, e in enumerate(all_chunks_and_links):\n", + " e['id'] = i\n", + " new_all_chunks_and_links.append(e)" + ] + }, + { + "cell_type": "markdown", + "id": "97113997", + "metadata": {}, + "source": [ + "## Define Schema and Create table\n", + "\n", + "For this use-case, you need a table with images and another table with text. SuperDuperDB extends standard SQL functionality, allowing developers to define their own data types through the `Encoder` abstraction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b5a2243", + "metadata": {}, + "outputs": [], + "source": [ + "Schema(\n", + " 'questiondocs-schema',\n", + " fields={'id': dtype(str), 'txt': dtype(str), 'link': dtype(str)},\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35365286", + "metadata": {}, + "outputs": [], + "source": [ + "# \n", + "# Define the 'captions' table\n", + "table = Table(\n", + " 'questiondocs',\n", + " primary_id='id',\n", + " schema=Schema(\n", + " 'questiondocs-schema',\n", + " fields={'id': dtype(str), 'txt': dtype(str), 'link': dtype(str)},\n", + " )\n", + ")\n", + "\n", + "\n", + "\n", + "# Add the 'captions' and 'images' tables to the SuperDuperDB database\n", + "db.add(table)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24d34dd5", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa9bbad7", + "metadata": {}, + "outputs": [], + "source": [ + "new_all_chunks_and_links_df = pd.DataFrame(new_all_chunks_and_links)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ca8f699", + "metadata": {}, + "outputs": [], + "source": [ + "df = new_all_chunks_and_links_df.astype(str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de8596cf", + "metadata": {}, + "outputs": [], + "source": [ + "from superduperdb.base.document import Document as D\n" + ] + }, + { + "cell_type": "raw", + "id": "b67d92cf", + "metadata": {}, + "source": [ + "table.insert(df[['id', 'txt', 'link']])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "386d1595", + "metadata": {}, + "outputs": [], + "source": [ + "insert = table.insert(\n", + " [\n", + " D(\n", + " {\n", + " 'id': d['id'],\n", + " 'txt': d['txt'],\n", + " 'link': d['link'],\n", + " }\n", + " )\n", + " for i, d in df.iterrows()\n", + " ]\n", + " )\n", + "_ = db.execute(insert)" + ] + }, + { + "cell_type": "raw", + "id": "5aeb139c", + "metadata": { + "scrolled": true + }, + "source": [ + "_ = db.execute(table.insert(df[['id', 'txt', 'link']]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ee273e7", + "metadata": {}, + "outputs": [], + "source": [ + "q = table.select('txt', 'link')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "097b7ea6", + "metadata": {}, + "outputs": [], + "source": [ + "result = db.execute(q)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34ae0e4a", + "metadata": {}, + "outputs": [], + "source": [ + "result[0]" + ] + }, + { + "cell_type": "markdown", + "id": "6b48b0b3", + "metadata": {}, + "source": [ + "A `Model` is a wrapper around a self-built or ecosystem model, such as `torch`, `transformers`, `openai`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d03f428a", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from superduperdb import vector\n", + "vector(shape=(1024,))" + ] + }, + { + "cell_type": "markdown", + "id": "dbefeae8", + "metadata": {}, + "source": [ + "### Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "710c80e9", + "metadata": {}, + "outputs": [], + "source": [ + "import sentence_transformers\n", + "from superduperdb.ext.sentence_transformers import SentenceTransformer\n", + "from superduperdb.ext.numpy import array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7db75570", + "metadata": {}, + "outputs": [], + "source": [ + "model = SentenceTransformer(\n", + " identifier=\"embedding\",\n", + " object=sentence_transformers.SentenceTransformer(\"BAAI/bge-large-en-v1.5\"),\n", + " postprocess=lambda x: x.tolist(),\n", + " datatype=vector(shape=(1024,)),\n", + " predict_kwargs={\"show_progress_bar\": True},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ef3fb79", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "vector = model.predict_one('This is a test')\n", + "print('vector size: ', len(vector))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09ba0e62", + "metadata": {}, + "outputs": [], + "source": [ + "vector" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae2c8134", + "metadata": {}, + "outputs": [], + "source": [ + "# Import the Listener class from the superduperdb module\n", + "from superduperdb import Listener\n", + "\n", + "\n", + "# Create a Listener instance with the specified model, key, and selection criteria\n", + "listener1 = Listener(\n", + " model=model, # The model to be used for listening\n", + " key='txt', # The key field in the documents to be processed by the model\n", + " select=table.select('id', 'txt'), # The selection criteria for the documents\n", + " predict_kwargs={'max_chunk_size': 3000},\n", + " identifier='listener1'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "874921eb", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "db.add(listener1)" + ] + }, + { + "cell_type": "markdown", + "id": "b26a2742", + "metadata": {}, + "source": [ + "## HNSW (Hierarchical Navigable Small Worlds graph) Indexing\n", + "\n", + "1. HNSW Indexing - Multi layer graph structure\n", + "2. IVFFlat Indexing - is based on clustering\n", + "\n", + "\n", + "> Note : `indexing_measure` and `measure` both should use same similarity approaches. Otherwise it will go for sequential scanning." 
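
To make the note concrete, here is a small sketch pairing each `measure` with the pgvector distance operator used at query time and the matching `indexing_measure` operator class. The names follow `PostgresVectorSearcher.get_measure_query` and `PostgresIndexing` as introduced in this PR.

```python
# Pairing of `measure` values with the pgvector distance operator used at
# query time and the index operator class (`indexing_measure`) that must
# match it; a mismatched pair leaves the index unused (sequential scan).
# Names follow get_measure_query / PostgresIndexing in this PR.
MEASURE_TO_PGVECTOR = {
    'cosine': {'operator': '<=>', 'indexing_measure': 'vector_cosine_ops'},
    'l2':     {'operator': '<->', 'indexing_measure': 'vector_l2_ops'},
    'dot':    {'operator': '<#>', 'indexing_measure': 'vector_ip_ops'},
}


def index_is_usable(measure: str, indexing_measure: str) -> bool:
    """Return True if an index built with `indexing_measure` can serve
    nearest-neighbour queries for `measure`."""
    entry = MEASURE_TO_PGVECTOR.get(measure)
    return entry is not None and entry['indexing_measure'] == indexing_measure


assert index_is_usable('cosine', 'vector_cosine_ops')      # index is used
assert not index_is_usable('cosine', 'vector_l2_ops')      # sequential scan
```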
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24bdc026", + "metadata": {}, + "outputs": [], + "source": [ + "from superduperdb import VectorIndex\n", + "from superduperdb.vector_search.postgres import PostgresVectorSearcher, HNSW, IVFFlat\n", + "\n", + "hnsw_indexing = HNSW(m=16, ef_construction=64, ef_search=49)\n", + "ivfflat_indexing = IVFFlat(lists=100, probes=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0056bd2", + "metadata": {}, + "outputs": [], + "source": [ + "vi = VectorIndex(\n", + " identifier='my-index', # Unique identifier for the VectorIndex\n", + " indexing_listener=listener1, # Listener to be used for indexing documents\n", + " measure='cosine',\n", + " indexing = hnsw_indexing,\n", + " indexing_measure = 'vector_cosine_ops'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93636671", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "jobs, _ = db.add(vi)" + ] + }, + { + "cell_type": "markdown", + "id": "ea3293b8", + "metadata": {}, + "source": [ + "## Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1365988", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "%%time\n", + "from superduperdb.backends.ibis import Table\n", + "from superduperdb import Document as D\n", + "from IPython.display import *\n", + "\n", + "# Define the query for the search\n", + "query = 'Code snippet how to create a `VectorIndex` with a torchvision model'\n", + "# query = 'can you explain vector-indexes with `superduperdb`?'\n", + "\n", + "# Execute a search using SuperDuperDB to find documents containing the specified query\n", + "result = db.execute(\n", + " query=table.like(D({'txt': query}), vector_index='my-index', n=5).select('id', 'txt', 'link')\n", + ")\n", + "\n", + "# Display a horizontal rule to separate results\n", + "display(Markdown('---'))\n", + "\n", + "# Display each document's 'txt' field and separate them with a horizontal rule\n", + "for r in result:\n", + " display(Markdown(r['txt']))\n", + " display(r['link'])\n", + " display(Markdown('---'))" + ] + }, + { + "cell_type": "markdown", + "id": "f6a3c179", + "metadata": {}, + "source": [ + "## Future Works\n", + "1. `Ibis` doesn't support `pgvector`. and want to make it supportable for that `pgvector`." 
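
Because ibis cannot yet express pgvector operators, the `PostgresVectorSearcher` added in this PR falls back to raw SQL over `psycopg2`. The sketch below shows the shape of such a query for cosine distance; the connection URI and table name are placeholders following the `_outputs.<predict_id>` convention used above, and the exact SQL in `superduperdb/vector_search/postgres.py` differs in detail.

```python
# Hedged sketch of the raw nearest-neighbour query shape used when bypassing
# ibis: cosine distance (`<=>`) over an `_outputs.<predict_id>` table. The URI
# and table name are placeholders; the actual SQL in postgres.py differs.
import numpy as np
import psycopg2
from pgvector.psycopg2 import register_vector


def nearest_ids(uri: str, table: str, query_vector, n: int = 5):
    conn = psycopg2.connect(uri)
    register_vector(conn)  # adapt numpy arrays to the pgvector `vector` type
    try:
        with conn.cursor() as cur:
            cur.execute(
                f'SELECT _input_id, output <=> %s AS distance '
                f'FROM "{table}" ORDER BY distance ASC LIMIT %s',
                (np.array(query_vector), n),
            )
            return cur.fetchall()
    finally:
        conn.close()


# Hypothetical usage, reusing the connection string from this notebook:
# nearest_ids('postgresql://postgres:test@localhost:8000/qa',
#             '_outputs.<predict_id>', model.predict_one('a test query'), n=5)
```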
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/superduperdb/backends/base/backends.py b/superduperdb/backends/base/backends.py index 2d653f532..334a85087 100644 --- a/superduperdb/backends/base/backends.py +++ b/superduperdb/backends/base/backends.py @@ -10,6 +10,8 @@ from superduperdb.vector_search.atlas import MongoAtlasVectorSearcher from superduperdb.vector_search.in_memory import InMemoryVectorSearcher from superduperdb.vector_search.lance import LanceVectorSearcher +from superduperdb.vector_search.postgres import PostgresVectorSearcher + data_backends = { 'mongodb': MongoDataBackend, @@ -30,6 +32,7 @@ 'lance': LanceVectorSearcher, 'in_memory': InMemoryVectorSearcher, 'mongodb+srv': MongoAtlasVectorSearcher, + 'pg_vector': PostgresVectorSearcher } CONNECTIONS = { diff --git a/superduperdb/backends/ibis/query.py b/superduperdb/backends/ibis/query.py index 89ac746be..a9c4c0ac0 100644 --- a/superduperdb/backends/ibis/query.py +++ b/superduperdb/backends/ibis/query.py @@ -3,6 +3,9 @@ import re import types import typing as t +from superduperdb import CFG +from pgvector.psycopg2 import register_vector +import psycopg2 import pandas @@ -49,25 +52,56 @@ def _model_update_impl( outputs: t.Sequence[t.Any], flatten: bool = False, ): - if flatten: - raise NotImplementedError('Flatten not yet supported for ibis') + if CFG.cluster.vector_search.type == 'in_memory': + if flatten: + raise NotImplementedError('Flatten not yet supported for ibis') - if not outputs: - return - - table_records = [] - for ix in range(len(outputs)): - d = { - '_input_id': str(ids[ix]), - 'output': outputs[ix], - } - table_records.append(d) + if not outputs: + return - for r in table_records: - if isinstance(r['output'], dict) and '_content' in r['output']: - r['output'] = r['output']['_content']['bytes'] + table_records = [] + for ix in range(len(outputs)): + d = { + '_input_id': str(ids[ix]), + 'output': outputs[ix], + } + table_records.append(d) + + for r in table_records: + if isinstance(r['output'], dict) and '_content' in r['output']: + r['output'] = r['output']['_content']['bytes'] + + db.databackend.insert(f'_outputs.{predict_id}', table_records) + + elif CFG.cluster.vector_search.type == 'pg_vector': + # Connect to your PostgreSQL database + conn = psycopg2.connect(CFG.cluster.vector_search.uri) + register_vector(conn) + table_name = f'_outputs.{predict_id}' + with conn.cursor() as cursor: + cursor.execute('CREATE EXTENSION IF NOT EXISTS vector') + cursor.execute(f"""DROP TABLE IF EXISTS "{table_name}";""") + cursor.execute( + f"""CREATE TABLE "{table_name}" ( + _input_id VARCHAR PRIMARY KEY, + output vector(1024), + _fold VARCHAR + ); + """ + ) + for ix in range(len(outputs)): + try: + cursor.execute( + f"""INSERT INTO "{table_name}" (_input_id, output) VALUES (%s, %s);""", + [str(ids[ix]), outputs[ix]] + ) + except: + pass - db.databackend.insert(f'_outputs.{predict_id}', table_records) + # Commit the transaction + conn.commit() + # Close the connection + conn.close() class IbisBackendError(Exception): @@ -183,7 +217,7 @@ def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): if tables is None: 
tables = {} if table_id not in tables: - tables[table_id] = db.databackend.conn.table(table_id) + tables[table_id] = db.databackend.conn.tables.get(table_id) return self.query_linker.compile(db, tables=tables) def get_all_tables(self): diff --git a/superduperdb/base/build.py b/superduperdb/base/build.py index b91832651..b8fcef91a 100644 --- a/superduperdb/base/build.py +++ b/superduperdb/base/build.py @@ -117,6 +117,18 @@ def _build_databackend_impl(uri, mapping, type: str = 'data_backend'): name = uri.split('/')[-1] conn = mongomock.MongoClient() return mapping['mongodb'](conn, name) + + elif uri.startswith('postgres://') or uri.startswith("postgresql://"): + name = uri.split('//')[0] + if type == 'data_backend': + ibis_conn = ibis.connect(uri) + return mapping['ibis'](ibis_conn, name) + else: + assert type == 'metadata' + from sqlalchemy import create_engine + + sql_conn = create_engine(uri) + return mapping['sqlalchemy'](sql_conn, name) elif uri.endswith('.csv'): if type == 'metadata': diff --git a/superduperdb/base/superduper.py b/superduperdb/base/superduper.py index 9a2363be2..ed2e603e8 100644 --- a/superduperdb/base/superduper.py +++ b/superduperdb/base/superduper.py @@ -35,6 +35,9 @@ def _auto_identify_connection_string(item: str, **kwargs) -> t.Any: elif item.startswith('mongodb+srv://') and 'mongodb.net' in item: kwargs['data_backend'] = item + elif item.startswith('postgres://') or item.startswith('postgresql://'): + kwargs['data_backend'] = item + elif item.endswith('.csv'): kwargs['data_backend'] = item diff --git a/superduperdb/components/vector_index.py b/superduperdb/components/vector_index.py index 9552a14ff..b87c82c08 100644 --- a/superduperdb/components/vector_index.py +++ b/superduperdb/components/vector_index.py @@ -4,6 +4,10 @@ import numpy as np from overrides import override +from superduperdb import CFG +import psycopg2 +from pgvector.psycopg2 import register_vector + from superduperdb.base.datalayer import Datalayer from superduperdb.base.document import Document from superduperdb.components.component import Component @@ -16,6 +20,7 @@ from superduperdb.misc.special_dicts import MongoStyleDict from superduperdb.vector_search.base import VectorIndexMeasureType from superduperdb.vector_search.update_tasks import copy_vectors +from superduperdb.vector_search.postgres import PostgresIndexing, HNSW, IVFFlat KeyType = t.Union[str, t.List, t.Dict] if t.TYPE_CHECKING: @@ -42,6 +47,8 @@ class VectorIndex(Component): compatible_listener: t.Optional[Listener] = None measure: VectorIndexMeasureType = VectorIndexMeasureType.cosine metric_values: t.Optional[t.Dict] = dc.field(default_factory=dict) + indexing : t.Optional[HNSW | IVFFlat] = None, + indexing_measure : t.Optional[PostgresIndexing] = PostgresIndexing.cosine @override def on_load(self, db: Datalayer) -> None: @@ -54,6 +61,26 @@ def on_load(self, db: Datalayer) -> None: self.compatible_listener = t.cast( Listener, db.load('listener', self.compatible_listener) ) + if CFG.cluster.vector_search.type == "pg_vector": + conn = psycopg2.connect(CFG.cluster.vector_search.uri) + table_name = f"_outputs.{self.indexing_listener.predict_id}" + with conn.cursor() as cursor: + if self.indexing.name == 'hnsw': + + cursor.execute(f"""CREATE INDEX ON "{table_name}" + USING {self.indexing.name} (output {self.indexing_measure}) + WITH (m = {self.indexing.m}, ef_construction = {self.indexing.ef_construction});""") + + cursor.execute("""SET %s.ef_search = %s;""" % (self.indexing.name, self.indexing.ef_search)) + elif self.indexing.name == 
'ivfflat': + cursor.execute(f"""CREATE INDEX ON "{table_name}" + USING %s (output %s) + WITH (lists = %s);""" % (self.indexing.name, self.indexing_measure, self.indexing.lists)) + + cursor.execute("""SET %s.probes = %s;""" % (self.indexing.name, self.indexing.probes)) + conn.commit() + conn.close() + def get_vector( self, diff --git a/superduperdb/vector_search/interface.py b/superduperdb/vector_search/interface.py index 3b9e6c54d..91cc7fa50 100644 --- a/superduperdb/vector_search/interface.py +++ b/superduperdb/vector_search/interface.py @@ -16,15 +16,16 @@ def __init__(self, db: 'Datalayer', vector_searcher, vector_index: str): self.vector_index = vector_index if CFG.cluster.vector_search.uri is not None: - if not db.server_mode: - request_server( - service='vector_search', - endpoint='create/search', - args={ - 'vector_index': self.vector_index, - }, - type='get', - ) + if CFG.cluster.vector_search.type != 'pg_vector': + if not db.server_mode: + request_server( + service='vector_search', + endpoint='create/search', + args={ + 'vector_index': self.vector_index, + }, + type='get', + ) def __len__(self): return len(self.searcher) @@ -103,13 +104,14 @@ def find_nearest_from_array( :param n: number of nearest vectors to return """ if CFG.cluster.vector_search.uri is not None: - response = request_server( - service='vector_search', - data=h, - endpoint='query/search', - args={'vector_index': self.vector_index, 'n': n}, - ) - return response['ids'], response['scores'] + if CFG.cluster.vector_search.type != 'pg_vector': + response = request_server( + service='vector_search', + data=h, + endpoint='query/search', + args={'vector_index': self.vector_index, 'n': n}, + ) + return response['ids'], response['scores'] return self.searcher.find_nearest_from_array(h=h, n=n, within_ids=within_ids) diff --git a/superduperdb/vector_search/postgres.py b/superduperdb/vector_search/postgres.py new file mode 100644 index 000000000..f090955c3 --- /dev/null +++ b/superduperdb/vector_search/postgres.py @@ -0,0 +1,258 @@ +import json +import typing as t +import numpy +import dataclasses as dc +from pgvector.psycopg2 import register_vector +import psycopg2 + + +from superduperdb import CFG, logging +if t.TYPE_CHECKING: + from superduperdb.components.vector_index import VectorIndex +from superduperdb.components.model import APIModel, Model + + + +from superduperdb.vector_search.base import BaseVectorSearcher, VectorItem, VectorIndexMeasureType + +@dc.dataclass(kw_only=True) +class PostgresIndexing: + cosine = "vector_cosine_ops" + l2 = "vector_l2_ops" + inner_product = "vector_ip_ops" + +@dc.dataclass(kw_only=True) +class IVFFlat(PostgresIndexing): + """ + An IVFFlat index divides vectors into lists, and then searches a subset of those lists that are closest to the query vector. + It has faster build times and uses less memory than HNSW, but has lower query performance (in terms of speed-recall tradeoff). + + :param lists + :param probes + """ + def __init__(self, lists: t.Optional[int] = 100, probes: t.Optional[int] = 1): + self.name = "ivfflat" + self.lists = lists + self.probes = probes + +@dc.dataclass(kw_only=True) +class HNSW(PostgresIndexing): + """ + An HNSW index creates a multilayer graph. It has better query performance than IVFFlat (in terms of speed-recall tradeoff), + but has slower build times and uses more memory. Also, an index can be created without any data in the table + since there isn’t a training step like IVFFlat. 
+ + :param m: the max number of connections per layer + :param ef_construction: the size of the dynamic candidate list for constructing the graph + """ + def __init__(self, m: t.Optional[int] = 16, ef_construction: t.Optional[int] = 64, ef_search: t.Optional[int] = 40): + self.name = "hnsw" + self.m = m + self.ef_construction = ef_construction + self.ef_search: ef_search = ef_search + + +class PostgresVectorSearcher(BaseVectorSearcher): + """ + Implementation of a vector index using the ``pgvector`` library. + :param identifier: Unique string identifier of index + :param dimensions: Dimension of the vector embeddings in the Lance dataset + :param uri: connection string to postgres + :param h: ``torch.Tensor`` + :param index: list of IDs + :param measure: measure to assess similarity + """ + + def __init__( + self, + identifier: str, + dimensions: int, + uri: str, + h: t.Optional[numpy.ndarray] = None, + index: t.Optional[t.List[str]] = None, + measure: t.Optional[str] = VectorIndexMeasureType.cosine, + indexing : t.Optional[HNSW | IVFFlat] = None, + indexing_measure : t.Optional[PostgresIndexing] = PostgresIndexing.cosine + ): + self.connection = psycopg2.connect(uri) + self.dimensions = dimensions + self.identifier = identifier + self.measure = measure + self.measure_query = self.get_measure_query() + self.indexing = indexing + self.indexing_measure = indexing_measure + with self.connection.cursor() as cursor: + cursor.execute('CREATE EXTENSION IF NOT EXISTS vector') + cursor.execute( + 'CREATE TABLE IF NOT EXISTS "%s" (id varchar, txt VARCHAR, output vector(%d))' + % (self.identifier, self.dimensions) + ) + self.connection.commit() + if h: + self._create_or_append_to_dataset(h, index) + + + def __len__(self): + with self.connection.cursor() as curr: + length = curr.execute( + 'SELECT COUNT(*) FROM %s' % self.identifier + ).fetchone()[0] + return length + + def get_measure_query(self): + if self.measure.value == "l2": + return "output <-> '%s'" + elif self.measure.value == "dot": + return "(output <#> '%s') * -1" + elif self.measure.value == "cosine": + return "(output <=> '%s')" + else: + raise NotImplementedError("Unrecognized measure format") + + + def _create_or_append_to_dataset(self, vectors, ids): + with self.connection.cursor() as cursor: + for id_, vector in zip(ids, vectors): + try: + cursor.execute( + "INSERT INTO %s (id, output) VALUES (%s, '%s');" % (self.identifier, id_, vector) + ) + except Exception as e: + pass + self.connection.commit() + + def _create_index(self): + print("_create_index") + with self.connection.cursor() as cursor: + if self.indexing.name == 'hnsw': + print("hnsw") + cursor.execute("""CREATE INDEX ON %s + USING %s (output %s) + WITH (m = %s, ef_construction = %s);""" % (self.identifier, self.indexing.name, self.indexing_measure, self.indexing.m, self.indexing.ef_construction)) + + cursor.execute("""SET %s.ef_search = %s;""" % (self.indexing.name, self.indexing.ef_search)) + elif self.indexing.name == 'ivfflat': + cursor.execute("""CREATE INDEX ON %s + USING %s (output %s) + WITH (lists = %s);""" % (self.identifier, self.indexing.name, self.indexing_measure, self.indexing.lists)) + + cursor.execute("""SET %s.probes = %s;""" % (self.indexing.name, self.indexing.probes)) + print("_create_index") + self.connection.commit() + + + def add(self, items: t.Sequence[VectorItem]) -> None: + """ + Add items to the index. 
+ :param items: t.Sequence of VectorItems + """ + ids = [item.id for item in items] + vectors = [item.vector for item in items] + self._create_or_append_to_dataset(vectors, ids) + + if self.indexing: + self._create_index() + + + def delete(self, ids: t.Sequence[str]) -> None: + """ + Remove items from the index + :param ids: t.Sequence of ids of vectors. + """ + with self.connection.cursor() as curr: + for id_vector in ids: + curr.execute( + "DELETE FROM %s WHERE id = '%s'" % (self.identifier, id_vector) + ) + self.connection.commit() + + + def find_nearest_from_id( + self, + _id, + n: int = 100, + within_ids: t.Sequence[str] = (), + ) -> t.Tuple[t.List[str], t.List[float]]: + """ + Find the nearest vectors to the vector with the given id. + :param _id: id of the vector + :param n: number of nearest vectors to return + """ + with self.connection.cursor() as curr: + curr.execute( + """ + SELECT output + FROM %s + WHERE id = '%s'""" + % (self.identifier, _id) + ) + h = curr.fetchone()[0] + return self.find_nearest_from_array(h, n, within_ids) + + def find_nearest_from_array( + self, + h: numpy.typing.ArrayLike, + n: int = 100, + within_ids: t.Sequence[str] = (), + ) -> t.Tuple[t.List[str], t.List[float]]: + """ + Find the nearest vectors to the given vector. + :param h: vector + :param n: number of nearest vectors to return + """ + # h = self.to_numpy(h)[None, :] + if len(within_ids) == 0: + condition = "1=1" + else: + within_ids_str = ', '.join([f"'{i}'" for i in within_ids]) + condition = f"id in ({within_ids_str})" + query_search_nearest = f""" + SELECT _input_id, {self.measure_query} as distance + FROM "%s" + WHERE %s + ORDER BY distance ASC + LIMIT %d + """ + + with self.connection.cursor() as curr: + curr.execute( + query_search_nearest % (list(h), self.identifier, condition, n) + ) + nearest_items = curr.fetchall() + ids = [row[0] for row in nearest_items] + scores = [row[1] for row in nearest_items] + return ids, scores + + @classmethod + def from_component(cls, vi: 'VectorIndex'): + from superduperdb.components.listener import Listener + from superduperdb.components.model import ObjectModel + + assert isinstance(vi.indexing_listener, Listener) + collection = vi.indexing_listener.select.table_or_collection.identifier + + + indexing_key = vi.indexing_listener.key + + assert isinstance( + indexing_key, str + ), 'Only single key is support for atlas search' + if indexing_key.startswith('_outputs'): + indexing_key = indexing_key.split('.')[1] + assert isinstance(vi.indexing_listener.model, Model) or isinstance( + vi.indexing_listener.model, APIModel + ) + assert isinstance(collection, str), 'Collection is required to be a string' + indexing_model = vi.indexing_listener.model.identifier + + indexing_version = vi.indexing_listener.model.version + + output_path = f'_outputs.{vi.indexing_listener.predict_id}' + print(output_path) + + return PostgresVectorSearcher( + uri=CFG.cluster.vector_search.uri, + identifier=output_path, + dimensions=vi.dimensions, + measure=VectorIndexMeasureType.cosine, + ) \ No newline at end of file diff --git a/superduperdb/vector_search/update_tasks.py b/superduperdb/vector_search/update_tasks.py index c5f7abd3a..cb9dd71ee 100644 --- a/superduperdb/vector_search/update_tasks.py +++ b/superduperdb/vector_search/update_tasks.py @@ -6,6 +6,7 @@ from superduperdb.base.serializable import Serializable from superduperdb.misc.special_dicts import MongoStyleDict from superduperdb.vector_search.base import VectorItem +from superduperdb import CFG def delete_vectors( @@ 
-40,48 +41,49 @@ def copy_vectors( :param db: A ``DB`` instance. """ - vi = db.vector_indices[vector_index] - if isinstance(query, dict): - # ruff: noqa: E501 - query: CompoundSelect = Serializable.decode(query) # type: ignore[no-redef] - assert isinstance(query, CompoundSelect) - if not ids: - select = query - else: - select = query.select_using_ids(ids) - docs = db.select(select) - docs = [doc.unpack() for doc in docs] - key = vi.indexing_listener.key - if '_outputs.' in key: - key = key.split('.')[1] - # TODO: Refactor the below logic - vectors = [] - if isinstance(db.databackend, MongoDataBackend): - vectors = [ - { - 'vector': MongoStyleDict(doc)[ - f'_outputs.{vi.indexing_listener.predict_id}' - ], - 'id': str(doc['_id']), - } - for doc in docs - ] - elif isinstance(db.databackend, IbisDataBackend): - docs = db.execute(select.outputs(vi.indexing_listener.predict_id)) - from superduperdb.backends.ibis.data_backend import INPUT_KEY + if CFG.cluster.vector_search.type != 'pg_vector': + vi = db.vector_indices[vector_index] + if isinstance(query, dict): + # ruff: noqa: E501 + query: CompoundSelect = Serializable.decode(query) # type: ignore[no-redef] + assert isinstance(query, CompoundSelect) + if not ids: + select = query + else: + select = query.select_using_ids(ids) + docs = db.select(select) + docs = [doc.unpack() for doc in docs] + key = vi.indexing_listener.key + if '_outputs.' in key: + key = key.split('.')[1] + # TODO: Refactor the below logic + vectors = [] + if isinstance(db.databackend, MongoDataBackend): + vectors = [ + { + 'vector': MongoStyleDict(doc)[ + f'_outputs.{vi.indexing_listener.predict_id}' + ], + 'id': str(doc['_id']), + } + for doc in docs + ] + elif isinstance(db.databackend, IbisDataBackend): + docs = db.execute(select.outputs(vi.indexing_listener.predict_id)) + from superduperdb.backends.ibis.data_backend import INPUT_KEY - vectors = [ - { - 'vector': doc[f'_outputs.{vi.indexing_listener.predict_id}'], - 'id': str(doc[INPUT_KEY]), - } - for doc in docs - ] - for r in vectors: - if hasattr(r['vector'], 'numpy'): - r['vector'] = r['vector'].numpy() + vectors = [ + { + 'vector': doc[f'_outputs.{vi.indexing_listener.predict_id}'], + 'id': str(doc[INPUT_KEY]), + } + for doc in docs + ] + for r in vectors: + if hasattr(r['vector'], 'numpy'): + r['vector'] = r['vector'].numpy() - if vectors: - db.fast_vector_searchers[vi.identifier].add( - [VectorItem(**vector) for vector in vectors] - ) + if vectors: + db.fast_vector_searchers[vi.identifier].add( + [VectorItem(**vector) for vector in vectors] + ) diff --git a/test/integration/backends/postgres/test_pg_vector.py b/test/integration/backends/postgres/test_pg_vector.py new file mode 100644 index 000000000..032d04338 --- /dev/null +++ b/test/integration/backends/postgres/test_pg_vector.py @@ -0,0 +1,115 @@ +import random +import warnings +import tempfile +import ibis + +import lorem +import psycopg2 +import pytest + +import superduperdb as s +from superduperdb import CFG, superduper +from superduperdb.backends.ibis.data_backend import IbisDataBackend +from superduperdb.base.datalayer import Datalayer +from superduperdb.backends.sqlalchemy.metadata import SQLAlchemyMetadata +from superduperdb.backends.local.artifacts import FileSystemArtifactStore +from superduperdb.backends.ibis.query import Table +from superduperdb.base.document import Document +from superduperdb.components.listener import Listener +from superduperdb.components.model import ObjectModel +from superduperdb.components.vector_index import VectorIndex, vector 
+from superduperdb.components.schema import Schema +from superduperdb.backends.ibis.field_types import dtype + + +@pytest.fixture +def postgres_conn(): + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_db = f'{tmp_dir}/mydb.sqlite' + yield ibis.connect('postgres://' + str(tmp_db)), tmp_dir + +@pytest.fixture +def test_db(postgres_conn): + connection, tmp_dir = postgres_conn + yield make_ibis_db(connection, connection, tmp_dir) + + +def make_ibis_db(db_conn, metadata_conn, tmp_dir, in_memory=False): + return Datalayer( + databackend=IbisDataBackend(conn=db_conn, name='ibis', in_memory=in_memory), + metadata=SQLAlchemyMetadata(conn=metadata_conn.con, name='ibis'), + artifact_store=FileSystemArtifactStore(conn=tmp_dir, name='ibis'), + ) + + +def random_vector_model(x): + return [random.random() for _ in range(16)] + + +@pytest.fixture() +def pgvector_search_config(): + previous = s.CFG.vector_search + s.CFG.vector_search = s.CFG.data_backend + yield + s.CFG.vector_search = previous + + +@pytest.mark.skipif(DO_SKIP, reason='Only pgvector deployments relevant.') +def test_setup_pgvector_vector_search(pgvector_search_config): + model = ObjectModel( + identifier='test-model', object=random_vector_model, encoder=vector(shape=(16,)) + ) + db = superduper() + schema = Schema( + identifier='docs-schema', + fields={ + 'text': dtype('str', schema=schema), + }, + ) + table = Table('docs', schema=schema) + + vector_indexes = db.vector_indices + + assert not vector_indexes + + db.execute( + table.insert_many( + [Document({'text': lorem.sentence()}) for _ in range(50)] + ) + ) + db.add( + VectorIndex( + 'test-vector-index', + indexing_listener=Listener( + model=model, + key='text', + select=table.select('text'), + ), + ) + ) + + assert 'test-vector-index' in db.show('vector_index') + assert 'test-vector-index' in db.vector_indices + + +@pytest.mark.skipif(DO_SKIP, reason='Only pgvector deployments relevant.') +def test_use_pgvector_vector_search(pgvector_search_config): + db = superduper() + schema = Schema( + identifier='docs-schema', + fields={ + 'text': dtype('str', schema=schema), + }, + ) + table = Table('docs', schema=schema) + + query = table.like( + Document({'text': 'This is a test'}), n=5, vector_index='test-vector-index' + ).find() + + it = 0 + for r in db.execute(query): + print(r) + it += 1 + + assert it == 5
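
For completeness, a hedged sketch of driving the new `PostgresVectorSearcher` directly, outside the `Datalayer`, using only the constructor, `add`, and `find_nearest_from_array` API added in this diff. The connection URI and identifier are placeholders, and the sketch assumes the searcher's table columns and the `_input_id` referenced in its nearest-neighbour query are reconciled in the final implementation.

```python
# Hedged sketch: exercising PostgresVectorSearcher directly with the API
# added in this diff. URI/identifier are placeholders; it assumes the
# searcher's table schema and the `_input_id` column used in its
# nearest-neighbour query are reconciled in the final implementation.
import random

from superduperdb.vector_search.base import VectorItem
from superduperdb.vector_search.postgres import HNSW, PostgresVectorSearcher

searcher = PostgresVectorSearcher(
    identifier='demo_vectors',
    dimensions=16,
    uri='postgresql://postgres:test@localhost:8000/qa',
    indexing=HNSW(m=16, ef_construction=64, ef_search=40),
)

# Insert a few random vectors; `add` also builds the HNSW index because
# `indexing` was supplied.
items = [
    VectorItem(id=str(i), vector=[random.random() for _ in range(16)])
    for i in range(100)
]
searcher.add(items)

# Query for the five nearest neighbours of a fresh random vector.
ids, scores = searcher.find_nearest_from_array(
    [random.random() for _ in range(16)], n=5
)
print(list(zip(ids, scores)))
```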