Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres : pgvector implemenation #1926

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Rename `_Predictor` to `Model`
- Allow developers to write `Listeners` and `Graph` in a single formalism
- Change unittesting framework to pure configuration (no patching configs)
- Adding `PostgresDataBackend` for `Pgvector` integration

#### Bug Fixes
- Fixed a bug in refresh_after_insert for listeners with select None
Expand Down
5 changes: 5 additions & 0 deletions superduperdb/backends/base/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pymongo import MongoClient

from superduperdb.backends.ibis.data_backend import IbisDataBackend
from superduperdb.backends.postgres.data_backend import PostgresDataBackend
from superduperdb.backends.local.artifacts import FileSystemArtifactStore
from superduperdb.backends.mongodb.artifacts import MongoArtifactStore
from superduperdb.backends.mongodb.data_backend import MongoDataBackend
Expand All @@ -10,10 +11,13 @@
from superduperdb.vector_search.atlas import MongoAtlasVectorSearcher
from superduperdb.vector_search.in_memory import InMemoryVectorSearcher
from superduperdb.vector_search.lance import LanceVectorSearcher
from superduperdb.vector_search.postgres import PostgresVectorSearcher


data_backends = {
'mongodb': MongoDataBackend,
'ibis': IbisDataBackend,
'postgres' : PostgresDataBackend
}

artifact_stores = {
Expand All @@ -30,6 +34,7 @@
'lance': LanceVectorSearcher,
'in_memory': InMemoryVectorSearcher,
'mongodb+srv': MongoAtlasVectorSearcher,
'postgres': PostgresVectorSearcher
}

CONNECTIONS = {
Expand Down
106 changes: 106 additions & 0 deletions superduperdb/backends/postgres/data_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from superduperdb.backends.ibis.data_backend import IbisDataBackend


import typing as t
from warnings import warn

import ibis
import pandas
from ibis.backends.base import BaseBackend

from superduperdb.backends.ibis.db_helper import get_db_helper
from superduperdb.backends.ibis.field_types import FieldType, dtype
from superduperdb.backends.ibis.query import Table
from superduperdb.backends.local.artifacts import FileSystemArtifactStore
from superduperdb.backends.sqlalchemy.metadata import SQLAlchemyMetadata
from superduperdb.components.datatype import DataType
from superduperdb.components.schema import Schema

BASE64_PREFIX = 'base64:'
INPUT_KEY = '_input_id'




class PostgresDataBackend(IbisDataBackend):
makkarss929 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, conn: BaseBackend, name: str, in_memory: bool = False):
super().__init__(conn=conn, name=name)
self.in_memory = in_memory
self.dialect = getattr(conn, 'name', 'base')
self.db_helper = get_db_helper(self.dialect)

def url(self):
return self.conn.con.url + self.name

def build_artifact_store(self):
return FileSystemArtifactStore(conn='.superduperdb/artifacts/', name='ibis')

def build_metadata(self):
return SQLAlchemyMetadata(conn=self.conn.con, name='ibis')

def create_ibis_table(self, identifier: str, schema: Schema):
self.conn.create_table(identifier, schema=schema)

def insert(self, table_name, raw_documents):
for doc in raw_documents:
for k, v in doc.items():
doc[k] = self.db_helper.convert_data_format(v)
table_name, raw_documents = self.db_helper.process_before_insert(
table_name, raw_documents
)
if not self.in_memory:
self.conn.insert(table_name, raw_documents)
else:
self.conn.create_table(table_name, pandas.DataFrame(raw_documents))

def create_output_dest(
self, predict_id: str, datatype: t.Union[FieldType, DataType]
):
msg = (
"Model must have an encoder to create with the"
f" {type(self).__name__} backend."
)
assert datatype is not None, msg
if isinstance(datatype, FieldType):
output_type = dtype(datatype.identifier)
else:
output_type = datatype
fields = {
INPUT_KEY: dtype('string'),
'output': output_type,
}
return Table(
identifier=f'_outputs.{predict_id}',
schema=Schema(identifier=f'_schema/{predict_id}', fields=fields),
)

def create_table_and_schema(self, identifier: str, mapping: dict):
"""
Create a schema in the data-backend.
"""

try:
mapping = self.db_helper.process_schema_types(mapping)
t = self.conn.create_table(identifier, schema=ibis.schema(mapping))
except Exception as e:
if 'exists' in str(e):
warn("Table already exists, skipping...")
t = self.conn.table(identifier)
else:
raise e
return t

def drop(self, force: bool = False):
raise NotImplementedError(
"Dropping tables needs to be done in each DB natively"
)

def get_table_or_collection(self, identifier):
return self.conn.table(identifier)

def disconnect(self):
"""
Disconnect the client
"""

# TODO: implement me
15 changes: 15 additions & 0 deletions superduperdb/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None):

if metadata is None:
try:
print(metadata_stores)
jieguangzhou marked this conversation as resolved.
Show resolved Hide resolved
# try to connect to the data backend uri.
logging.info("Connecting to Metadata Client with URI: ", cfg.data_backend)
return _build_databackend_impl(
Expand Down Expand Up @@ -117,6 +118,19 @@ def _build_databackend_impl(uri, mapping, type: str = 'data_backend'):
name = uri.split('/')[-1]
conn = mongomock.MongoClient()
return mapping['mongodb'](conn, name)

elif uri.startswith('postgres://') or uri.startswith("postgresql://"):
name = uri.split('//')[0]
if type == 'data_backend':
ibis_conn = ibis.connect(uri)
print(mapping['postgres'])
return mapping['postgres'](ibis_conn, name)
else:
assert type == 'metadata'
from sqlalchemy import create_engine

sql_conn = create_engine(uri)
return mapping['sqlalchemy'](sql_conn, name)

elif uri.endswith('.csv'):
if type == 'metadata':
Expand All @@ -135,6 +149,7 @@ def _build_databackend_impl(uri, mapping, type: str = 'data_backend'):
name = uri.split('//')[0]
if type == 'data_backend':
ibis_conn = ibis.connect(uri)
print(mapping['ibis'])
jieguangzhou marked this conversation as resolved.
Show resolved Hide resolved
return mapping['ibis'](ibis_conn, name)
else:
assert type == 'metadata'
Expand Down
3 changes: 3 additions & 0 deletions superduperdb/base/superduper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def _auto_identify_connection_string(item: str, **kwargs) -> t.Any:
elif item.startswith('mongodb+srv://') and 'mongodb.net' in item:
kwargs['data_backend'] = item

elif item.startswith('postgres://') or item.startswith('postgresql://'):
kwargs['data_backend'] = item
Comment on lines +38 to +39
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this, the same reason

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will remove this


elif item.endswith('.csv'):
kwargs['data_backend'] = item

Expand Down
146 changes: 146 additions & 0 deletions superduperdb/vector_search/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import json
import typing as t
import numpy
from pgvector.psycopg import psycopg, register_vector


from superduperdb.vector_search.base import BaseVectorSearcher, VectorItem


class PostgresVectorSearcher(BaseVectorSearcher):
"""
Implementation of a vector index using the ``pgvector`` library.
:param identifier: Unique string identifier of index
:param dimensions: Dimension of the vector embeddings in the Lance dataset
:param uri: connection string to postgres
:param h: ``torch.Tensor``
:param index: list of IDs
:param measure: measure to assess similarity
"""

def __init__(
self,
identifier: str,
dimensions: int,
conninfo: str,
h: t.Optional[numpy.ndarray] = None,
index: t.Optional[t.List[str]] = None,
measure: t.Optional[str] = None,
):
self.connection = psycopg.connect(conninfo=conninfo)
self.dimensions = dimensions
self.identifier = identifier
if measure == "l2" or not measure:
self.measure_query = "embedding <-> '%s'"
elif measure == "dot":
self.measure_query = "(embedding <#> '%s') * -1"
elif measure == "cosine":
self.measure_query = "1 - (embedding <=> '%s')"
else:
raise NotImplementedError("Unrecognized measure format")
jieguangzhou marked this conversation as resolved.
Show resolved Hide resolved
with self.connection.cursor() as cursor:
cursor.execute('CREATE EXTENSION IF NOT EXISTS vector')
cursor.execute(
'CREATE TABLE IF NOT EXISTS %s (id varchar, embedding vector(%d))'
% (self.identifier, self.dimensions)
)
register_vector(self.connection)
if h:
self._create_or_append_to_dataset(h, index)
jieguangzhou marked this conversation as resolved.
Show resolved Hide resolved


def __len__(self):
with self.connection.cursor() as curr:
length = curr.execute(
'SELECT COUNT(*) FROM %s' % self.identifier
).fetchone()[0]
return length


def _create_or_append_to_dataset(self, vectors, ids):
with self.connection.cursor().copy(
'COPY %s (id, embedding) FROM STDIN WITH (FORMAT BINARY)' % self.identifier
) as copy:
copy.set_types(['varchar', 'vector'])
for id_vector, vector in zip(ids, vectors):
copy.write_row([id_vector, vector])
self.connection.commit()


def add(self, items: t.Sequence[VectorItem]) -> None:
"""
Add items to the index.
:param items: t.Sequence of VectorItems
"""
ids = [item.id for item in items]
vectors = [item.vector for item in items]
self._create_or_append_to_dataset(vectors, ids)


def delete(self, ids: t.Sequence[str]) -> None:
"""
Remove items from the index
:param ids: t.Sequence of ids of vectors.
"""
with self.connection.cursor() as curr:
for id_vector in ids:
curr.execute(
"DELETE FROM %s WHERE id = '%s'" % (self.identifier, id_vector)
)
self.connection.commit()


def find_nearest_from_id(
self,
_id,
n: int = 100,
within_ids: t.Sequence[str] = (),
) -> t.Tuple[t.List[str], t.List[float]]:
"""
Find the nearest vectors to the vector with the given id.
:param _id: id of the vector
:param n: number of nearest vectors to return
"""
with self.connection.cursor() as curr:
curr.execute(
"""
SELECT embedding
FROM %s
WHERE id = '%s'"""
% (self.identifier, _id)
)
h = curr.fetchone()[0]
return self.find_nearest_from_array(h, n, within_ids)

def find_nearest_from_array(
self,
h: numpy.typing.ArrayLike,
n: int = 100,
within_ids: t.Sequence[str] = (),
) -> t.Tuple[t.List[str], t.List[float]]:
"""
Find the nearest vectors to the given vector.
:param h: vector
:param n: number of nearest vectors to return
"""
h = self.to_numpy(h)[None, :]
if len(within_ids) == 0:
condition = "1=1"
else:
within_ids_str = ', '.join([f"'{i}'" for i in within_ids])
condition = f"id in ({within_ids_str})"
query_search_nearest = f"""
SELECT id, {self.measure_query} as distance
FROM %s
WHERE %s
ORDER BY distance
LIMIT %d
"""
with self.connection.cursor() as curr:
curr.execute(
query_search_nearest % (json.dumps(h), self.identifier, condition, n)
)
nearest_items = curr.fetchall()
ids = [row[0] for row in nearest_items]
scores = [row[1] for row in nearest_items]
return ids, scores