From 738692c5c175e0f423aac501c7e3222e5e51a08a Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Mon, 23 Sep 2024 14:08:57 +0200 Subject: [PATCH 1/5] moved link storage to metadata --- .../ragstack_knowledge_store/graph_store.py | 222 +++++++----------- 1 file changed, 89 insertions(+), 133 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index b392f65b1..a01cb3b48 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -11,11 +11,12 @@ TYPE_CHECKING, Any, Sequence, + Set, Union, cast, ) -from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session +from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session, SimpleStatement from cassio.config import check_resolve_keyspace, check_resolve_session from typing_extensions import assert_never @@ -31,7 +32,7 @@ CONTENT_ID = "content_id" -CONTENT_COLUMNS = "content_id, kind, text_content, links_blob, metadata_blob" +CONTENT_COLUMNS = "content_id, text_content, metadata_blob" SELECT_CQL_TEMPLATE = ( "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};" @@ -46,11 +47,18 @@ class Node: """Text contained by the node.""" id: str | None = None """Unique ID for the node. 
Will be generated by the GraphStore if not set.""" + embedding: list[float] = field(default_factory=list) + """Vector embedding of the text""" metadata: dict[str, Any] = field(default_factory=dict) """Metadata for the node.""" links: set[Link] = field(default_factory=set) - """Links for the node.""" + """All the links for the node.""" + def incoming_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["in", "bidir"])]) + + def outgoing_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["out", "bidir"])]) class SetupMode(Enum): """Mode used to create the Cassandra table.""" @@ -114,18 +122,31 @@ def _deserialize_links(json_blob: str | None) -> set[Link]: for link in cast(list[dict[str, Any]], json.loads(json_blob or "")) } +def _metadata_s_link_key(link: Link) -> str: + return "link_from_" + json.dumps({"kind": link.kind, "tag": link.tag}) + +def _metadata_s_link_value() -> str: + return "link_from" def _row_to_node(row: Any) -> Node: - metadata = _deserialize_metadata(row.metadata_blob) - links = _deserialize_links(row.links_blob) + if hasattr(row, "metadata_blob"): + metadata_blob = getattr(row, "metadata_blob") + metadata = _deserialize_metadata(metadata_blob) + links: set[Link] = _deserialize_links(metadata.get("links")) + metadata["links"] = links + else: + metadata = {} + links = set() return Node( - id=row.content_id, - text=row.text_content, + id=getattr(row, CONTENT_ID, ""), + embedding=getattr(row, "text_embedding", []), + text=getattr(row, "text_content", ""), metadata=metadata, links=links, ) + _CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*") @@ -201,9 +222,8 @@ def __init__( self._insert_passage = session.prepare( f""" INSERT INTO {keyspace}.{node_table} ( - content_id, kind, text_content, text_embedding, link_to_tags, - link_from_tags, links_blob, metadata_blob, metadata_s - ) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?, ?) 
+ content_id, text_content, text_embedding, metadata_blob, metadata_s + ) VALUES (?, ?, ?, ?, ?) """ # noqa: S608 ) @@ -217,7 +237,7 @@ def __init__( self._query_ids_and_link_to_tags_by_id = session.prepare( f""" - SELECT content_id, link_to_tags + SELECT content_id, metadata_blob FROM {keyspace}.{node_table} WHERE content_id = ? """ # noqa: S608 @@ -233,13 +253,9 @@ def _apply_schema(self) -> None: self._session.execute(f""" CREATE TABLE IF NOT EXISTS {self.table_name()} ( content_id TEXT, - kind TEXT, text_content TEXT, text_embedding VECTOR, - link_to_tags SET>, - link_from_tags SET>, - links_blob TEXT, metadata_blob TEXT, metadata_s MAP, @@ -254,12 +270,6 @@ def _apply_schema(self) -> None: USING 'StorageAttachedIndex'; """) - self._session.execute(f""" - CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_link_from_tags - ON {self.table_name()}(link_from_tags) - USING 'StorageAttachedIndex'; - """) - self._session.execute(f""" CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_s_index ON {self.table_name()}(ENTRIES(metadata_s)) @@ -296,31 +306,23 @@ def add_nodes( link_to_tags = set() # link to these tags link_from_tags = set() # link from these tags - for tag in links: - if tag.direction in {"in", "bidir"}: - # An incoming link should be linked *from* nodes with the given - # tag. 
- link_from_tags.add((tag.kind, tag.tag)) - if tag.direction in {"out", "bidir"}: - link_to_tags.add((tag.kind, tag.tag)) - metadata_s = { k: self._coerce_string(v) for k, v in metadata.items() if _is_metadata_field_indexed(k, self._metadata_indexing_policy) } + for tag in links: + if tag.direction in {"in", "bidir"}: + metadata_s[_metadata_s_link_key(link=tag)] =_metadata_s_link_value() + metadata_blob = _serialize_metadata(metadata) - links_blob = _serialize_links(links) cq.execute( self._insert_passage, parameters=( node_id, text, text_embedding, - link_to_tags, - link_from_tags, - links_blob, metadata_blob, metadata_s, ), @@ -586,18 +588,8 @@ def traversal_search( # # ... - traversal_query = self._get_search_cql( - columns="content_id, link_to_tags", - has_limit=True, - metadata_keys=list(metadata_filter.keys()), - has_embedding=True, - ) - visit_nodes_query = self._get_search_cql( - columns="content_id AS target_content_id", - has_link_from_tags=True, - metadata_keys=list(metadata_filter.keys()), - ) + with self._concurrent_queries() as cq: # Map from visited ID to depth @@ -607,7 +599,7 @@ def traversal_search( # for tags that we've already traversed. visited_tags: dict[tuple[str, str], int] = {} - def visit_nodes(d: int, nodes: Sequence[Any]) -> None: + def visit_nodes(d: int, rows: Sequence[Any]) -> None: nonlocal visited_ids nonlocal visited_tags @@ -617,32 +609,35 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None: # Iterate over nodes, tracking the *new* outgoing kind tags for this # depth. This is tags that are either new, or newly discovered at a # lower depth. - outgoing_tags = set() - for node in nodes: - content_id = node.content_id + outgoing_tags: Set[Link] = set() + for row in rows: + content_id = row.content_id # Add visited ID. 
If it is closer it is a new node at this depth: if d <= visited_ids.get(content_id, depth): visited_ids[content_id] = d # If we can continue traversing from this node, - if d < depth and node.link_to_tags: + if d < depth: + node = _row_to_node(row=row) # Record any new (or newly discovered at a lower depth) # tags to the set to traverse. - for kind, value in node.link_to_tags: - if d <= visited_tags.get((kind, value), depth): + for link in node.outgoing_links(): + if d <= visited_tags.get((link.kind, link.tag), depth): # Record that we'll query this tag at the # given depth, so we don't fetch it again # (unless we find it an earlier depth) - visited_tags[(kind, value)] = d - outgoing_tags.add((kind, value)) + visited_tags[(link.kind, link.tag)] = d + outgoing_tags.add(link) if outgoing_tags: # If there are new tags to visit at the next depth, query for the # node IDs. - for kind, value in outgoing_tags: - params = self._get_search_params( - link_from_tags=(kind, value), metadata=metadata_filter + for link in outgoing_tags: + visit_nodes_query, params = self._get_search_cql_and_params( + columns="content_id AS target_content_id", + metadata=metadata_filter, + link_keys=[_metadata_s_link_key(link)] ) cq.execute( query=visit_nodes_query, @@ -668,17 +663,17 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None: callback=lambda rows, d=d: visit_nodes(d + 1, rows), ) - query_embedding = self._embedding.embed_query(query) - params = self._get_search_params( + initial_query, params = self._get_search_cql_and_params( + columns="content_id, metadata_blob", limit=k, metadata=metadata_filter, - embedding=query_embedding, + embedding=self._embedding.embed_query(query), ) cq.execute( - traversal_query, + initial_query, parameters=params, - callback=lambda nodes: visit_nodes(0, nodes), + callback=lambda initial_rows: visit_nodes(0, initial_rows), ) return self._nodes_with_ids(visited_ids.keys()) @@ -848,21 +843,13 @@ def _coerce_string(value: Any) -> str: def 
_extract_where_clause_cql( self, - has_id: bool = False, metadata_keys: Sequence[str] = (), - has_link_from_tags: bool = False, ) -> str: wc_blocks: list[str] = [] - if has_id: - wc_blocks.append("content_id == ?") - - if has_link_from_tags: - wc_blocks.append("link_from_tags CONTAINS (?, ?)") - for key in sorted(metadata_keys): if _is_metadata_field_indexed(key, self._metadata_indexing_policy): - wc_blocks.append(f"metadata_s['{key}'] = ?") + wc_blocks.append(f"metadata_s['{key}'] = %s") else: msg = "Non-indexed metadata fields cannot be used in queries." raise ValueError(msg) @@ -875,14 +862,9 @@ def _extract_where_clause_cql( def _extract_where_clause_params( self, metadata: dict[str, Any], - link_from_tags: tuple[str, str] | None = None, ) -> list[Any]: params: list[Any] = [] - if link_from_tags is not None: - params.append(link_from_tags[0]) - params.append(link_from_tags[1]) - for key, value in sorted(metadata.items()): if _is_metadata_field_indexed(key, self._metadata_indexing_policy): params.append(self._coerce_string(value=value)) @@ -892,22 +874,28 @@ def _extract_where_clause_params( return params - def _get_search_cql( + def _get_search_cql_and_params( self, - has_limit: bool = False, - columns: str | None = CONTENT_COLUMNS, - metadata_keys: Sequence[str] = (), - has_id: bool = False, - has_embedding: bool = False, - has_link_from_tags: bool = False, - ) -> PreparedStatement: - where_clause = self._extract_where_clause_cql( - has_id=has_id, - metadata_keys=metadata_keys, - has_link_from_tags=has_link_from_tags, - ) - limit_clause = " LIMIT ?" if has_limit else "" - order_clause = " ORDER BY text_embedding ANN OF ?" 
if has_embedding else "" + columns: str, + limit: int | None = None, + metadata: dict[str, Any] | None = None, + embedding: list[float] | None = None, + link_keys: list[str] | None = None, + ) -> tuple[PreparedStatement|SimpleStatement, tuple[Any, ...]]: + if link_keys is not None: + if metadata is None: + metadata = {} + else: + # don't add link search to original metadata dict + metadata = metadata.copy() + for link_key in link_keys: + metadata[link_key] = _metadata_s_link_value() + + metadata_keys = list(metadata.keys()) if metadata else [] + + where_clause = self._extract_where_clause_cql(metadata_keys=metadata_keys) + limit_clause = " LIMIT ?" if limit is not None else "" + order_clause = " ORDER BY text_embedding ANN OF ?" if embedding is not None else "" select_cql = SELECT_CQL_TEMPLATE.format( columns=columns, @@ -917,50 +905,18 @@ def _get_search_cql( limit_clause=limit_clause, ) - if select_cql in self._prepared_query_cache: - return self._prepared_query_cache[select_cql] - - prepared_query = self._session.prepare(select_cql) - prepared_query.consistency_level = ConsistencyLevel.ONE - self._prepared_query_cache[select_cql] = prepared_query - - return prepared_query - - def _get_search_params( - self, - limit: int | None = None, - metadata: dict[str, Any] | None = None, - embedding: list[float] | None = None, - link_from_tags: tuple[str, str] | None = None, - ) -> tuple[PreparedStatement, tuple[Any, ...]]: - where_params = self._extract_where_clause_params( - metadata=metadata or {}, link_from_tags=link_from_tags - ) - + where_params = self._extract_where_clause_params(metadata=metadata or {}) limit_params = [limit] if limit is not None else [] order_params = [embedding] if embedding is not None else [] - return tuple(list(where_params) + order_params + limit_params) + params = tuple(list(where_params) + order_params + limit_params) - def _get_search_cql_and_params( - self, - limit: int | None = None, - columns: str | None = CONTENT_COLUMNS, - metadata: 
dict[str, Any] | None = None, - embedding: list[float] | None = None, - link_from_tags: tuple[str, str] | None = None, - ) -> tuple[PreparedStatement, tuple[Any, ...]]: - query = self._get_search_cql( - has_limit=limit is not None, - columns=columns, - metadata_keys=list(metadata.keys()) if metadata else (), - has_embedding=embedding is not None, - has_link_from_tags=link_from_tags is not None, - ) - params = self._get_search_params( - limit=limit, - metadata=metadata, - embedding=embedding, - link_from_tags=link_from_tags, - ) - return query, params + if len(metadata_keys) > 0: + return SimpleStatement(query_string=select_cql, fetch_size=100), params + elif select_cql in self._prepared_query_cache: + return self._prepared_query_cache[select_cql], params + else: + prepared_query = self._session.prepare(select_cql) + prepared_query.consistency_level = ConsistencyLevel.ONE + self._prepared_query_cache[select_cql] = prepared_query + return prepared_query, params \ No newline at end of file From 524ba225706b9983beaba6949f2363d6b32bc949 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Mon, 23 Sep 2024 16:48:31 +0200 Subject: [PATCH 2/5] updated mmr traversal --- .../ragstack_knowledge_store/graph_store.py | 148 ++++++++---------- 1 file changed, 61 insertions(+), 87 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index a01cb3b48..65d85d2c5 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -22,7 +22,6 @@ from ._mmr_helper import MmrHelper from .concurrency import ConcurrentQueries -from .content import Kind from .links import Link if TYPE_CHECKING: @@ -52,7 +51,7 @@ class Node: metadata: dict[str, Any] = field(default_factory=dict) """Metadata for the node.""" links: set[Link] = field(default_factory=set) - """All the links for the node.""" + """Links for the node.""" def 
incoming_links(self) -> set[Link]: return set([l for l in self.links if (l.direction in ["in", "bidir"])]) @@ -150,13 +149,6 @@ def _row_to_node(row: Any) -> Node: _CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*") -@dataclass -class _Edge: - target_content_id: str - target_text_embedding: list[float] - target_link_to_tags: set[tuple[str, str]] - - class GraphStore: """A hybrid vector-and-graph store backed by Cassandra. @@ -415,60 +407,45 @@ def mmr_traversal_search( ) # For each unselected node, stores the outgoing tags. - outgoing_tags: dict[str, set[tuple[str, str]]] = {} + outgoing_links: dict[str, set[Link]] = {} + visited_links: set[Link] = set() - # Fetch the initial candidates and add them to the helper and - # outgoing_tags. - columns = "content_id, text_embedding, link_to_tags" - adjacent_query = self._get_search_cql( - has_limit=True, - columns=columns, - metadata_keys=list(metadata_filter.keys()), - has_embedding=True, - has_link_from_tags=True, - ) - - visited_tags: set[tuple[str, str]] = set() def fetch_neighborhood(neighborhood: Sequence[str]) -> None: + nonlocal outgoing_links + nonlocal visited_links + # Put the neighborhood into the outgoing tags, to avoid adding it # to the candidate set in the future. - outgoing_tags.update({content_id: set() for content_id in neighborhood}) + outgoing_links.update({content_id: set() for content_id in neighborhood}) - # Initialize the visited_tags with the set of outgoing from the + # Initialize the visited_links with the set of outgoing from the # neighborhood. This prevents re-visiting them. - visited_tags = self._get_outgoing_tags(neighborhood) + visited_links = self._get_outgoing_links(neighborhood) # Call `self._get_adjacent` to fetch the candidates. 
- adjacents = self._get_adjacent( - visited_tags, - adjacent_query=adjacent_query, + adjacent_nodes = self._get_adjacent( + links=visited_links, query_embedding=query_embedding, k_per_tag=adjacent_k, metadata_filter=metadata_filter, ) new_candidates = {} - for adjacent in adjacents: - if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags + for adjacent_node in adjacent_nodes: + if adjacent_node.id not in outgoing_links: + outgoing_links[adjacent_node.id] = ( + adjacent_node.outgoing_links() ) - new_candidates[adjacent.target_content_id] = ( - adjacent.target_text_embedding + new_candidates[adjacent_node.id] = ( + adjacent_node.embedding ) helper.add_candidates(new_candidates) def fetch_initial_candidates() -> None: - initial_candidates_query = self._get_search_cql( - has_limit=True, - columns=columns, - metadata_keys=list(metadata_filter.keys()), - has_embedding=True, - ) - - params = self._get_search_params( + initial_candidates_query, params = self._get_search_cql_and_params( + columns = "content_id, text_embedding, metadata_blob", limit=fetch_k, metadata=metadata_filter, embedding=query_embedding, @@ -479,9 +456,10 @@ def fetch_initial_candidates() -> None: ) candidates = {} for row in fetched: - if row.content_id not in outgoing_tags: - candidates[row.content_id] = row.text_embedding - outgoing_tags[row.content_id] = set(row.link_to_tags or []) + if row.content_id not in outgoing_links: + node = _row_to_node(row=row) + candidates[node.id] = node.embedding + outgoing_links[node.id] = set(node.outgoing_links()) helper.add_candidates(candidates) if initial_roots: @@ -509,34 +487,33 @@ def fetch_initial_candidates() -> None: # those. # Find the tags linked to from the selected ID. - link_to_tags = outgoing_tags.pop(selected_id) + link_to_tags = outgoing_links.pop(selected_id) # Don't re-visit already visited tags. 
- link_to_tags.difference_update(visited_tags) + link_to_tags.difference_update(visited_links) # Find the nodes with incoming links from those tags. - adjacents = self._get_adjacent( - link_to_tags, - adjacent_query=adjacent_query, + adjacent_nodes = self._get_adjacent( + links=link_to_tags, query_embedding=query_embedding, k_per_tag=adjacent_k, metadata_filter=metadata_filter, ) # Record the link_to_tags as visited. - visited_tags.update(link_to_tags) + visited_links.update(link_to_tags) new_candidates = {} - for adjacent in adjacents: - if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags + for adjacent_node in adjacent_nodes: + if adjacent_node.id not in outgoing_links: + outgoing_links[adjacent_node.id] = ( + adjacent_node.outgoing_links() ) - new_candidates[adjacent.target_content_id] = ( - adjacent.target_text_embedding + new_candidates[adjacent_node.id] = ( + adjacent_node.embedding ) if next_depth < depths.get( - adjacent.target_content_id, depth + 1 + adjacent_node.id, depth + 1 ): # If this is a new shortest depth, or there was no # previous depth, update the depths. This ensures that @@ -548,7 +525,7 @@ def fetch_initial_candidates() -> None: # a shorter path via nodes selected later. This is # currently "intended", but may be worth experimenting # with. - depths[adjacent.target_content_id] = next_depth + depths[adjacent_node.id] = next_depth helper.add_candidates(new_candidates) return self._nodes_with_ids(helper.selected_ids) @@ -597,11 +574,11 @@ def traversal_search( # Map from visited tag `(kind, tag)` to depth. Allows skipping queries # for tags that we've already traversed. - visited_tags: dict[tuple[str, str], int] = {} + visited_links: dict[Link, int] = {} def visit_nodes(d: int, rows: Sequence[Any]) -> None: nonlocal visited_ids - nonlocal visited_tags + nonlocal visited_links # Visit nodes at the given depth. # Each node has `content_id` and `link_to_tags`. 
@@ -609,7 +586,7 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None: # Iterate over nodes, tracking the *new* outgoing kind tags for this # depth. This is tags that are either new, or newly discovered at a # lower depth. - outgoing_tags: Set[Link] = set() + outgoing_links: Set[Link] = set() for row in rows: content_id = row.content_id @@ -623,17 +600,17 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None: # Record any new (or newly discovered at a lower depth) # tags to the set to traverse. for link in node.outgoing_links(): - if d <= visited_tags.get((link.kind, link.tag), depth): + if d <= visited_links.get(link, depth): # Record that we'll query this tag at the # given depth, so we don't fetch it again # (unless we find it an earlier depth) - visited_tags[(link.kind, link.tag)] = d - outgoing_tags.add(link) + visited_links[link] = d + outgoing_links.add(link) - if outgoing_tags: + if outgoing_links: # If there are new tags to visit at the next depth, query for the # node IDs. - for link in outgoing_tags: + for link in outgoing_links: visit_nodes_query, params = self._get_search_cql_and_params( columns="content_id AS target_content_id", metadata=metadata_filter, @@ -707,21 +684,21 @@ def get_node(self, content_id: str) -> Node: """Get a node by its id.""" return self._nodes_with_ids(ids=[content_id])[0] - def _get_outgoing_tags( + def _get_outgoing_links( self, source_ids: Iterable[str], - ) -> set[tuple[str, str]]: + ) -> set[Link]: """Return the set of outgoing tags for the given source ID(s). Args: source_ids: The IDs of the source nodes to retrieve outgoing tags for. 
""" - tags = set() + links = set() def add_sources(rows: Iterable[Any]) -> None: for row in rows: - if row.link_to_tags: - tags.update(row.link_to_tags) + node = _row_to_node(row=row) + links.update(node.outgoing_links()) with self._concurrent_queries() as cq: for source_id in source_ids: @@ -731,21 +708,19 @@ def add_sources(rows: Iterable[Any]) -> None: callback=add_sources, ) - return tags + return links def _get_adjacent( self, - tags: set[tuple[str, str]], - adjacent_query: PreparedStatement, + links: set[Link], query_embedding: list[float], k_per_tag: int | None = None, metadata_filter: dict[str, Any] | None = None, - ) -> Iterable[_Edge]: + ) -> Iterable[Node]: """Return the target nodes with incoming links from any of the given tags. Args: tags: The tags to look for links *from*. - adjacent_query: Prepared query for adjacent nodes. query_embedding: The query embedding. Used to rank target nodes. k_per_tag: The number of target nodes to fetch for each outgoing tag. metadata_filter: Optional metadata to filter the results. @@ -753,28 +728,27 @@ def _get_adjacent( Returns: List of adjacent edges. """ - targets: dict[str, _Edge] = {} + targets: dict[str, Node] = {} def add_targets(rows: Iterable[Any]) -> None: + nonlocal targets + # TODO: Figure out how to use the "kind" on the edge. # This is tricky, since we currently issue one query for anything # adjacent via any kind, and we don't have enough information to # determine which kind(s) a given target was reached from. 
for row in rows: if row.content_id not in targets: - targets[row.content_id] = _Edge( - target_content_id=row.content_id, - target_text_embedding=row.text_embedding, - target_link_to_tags=set(row.link_to_tags or []), - ) + targets[row.content_id] = _row_to_node(row=row) with self._concurrent_queries() as cq: - for kind, value in tags: - params = self._get_search_params( + for link in links: + adjacent_query, params = self._get_search_cql_and_params( + columns = "content_id, text_embedding, metadata_blob", limit=k_per_tag or 10, metadata=metadata_filter, embedding=query_embedding, - link_from_tags=(kind, value), + link_keys=[_metadata_s_link_key(link=link)] ) cq.execute( @@ -919,4 +893,4 @@ def _get_search_cql_and_params( prepared_query = self._session.prepare(select_cql) prepared_query.consistency_level = ConsistencyLevel.ONE self._prepared_query_cache[select_cql] = prepared_query - return prepared_query, params \ No newline at end of file + return prepared_query, params From db989964a236f2c090d03cd406ae020468119b72 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Mon, 23 Sep 2024 17:10:40 +0200 Subject: [PATCH 3/5] minor tweaks --- .../ragstack_knowledge_store/graph_store.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 65d85d2c5..c4a61427e 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -459,7 +459,7 @@ def fetch_initial_candidates() -> None: if row.content_id not in outgoing_links: node = _row_to_node(row=row) candidates[node.id] = node.embedding - outgoing_links[node.id] = set(node.outgoing_links()) + outgoing_links[node.id] = node.outgoing_links() helper.add_candidates(candidates) if initial_roots: @@ -610,11 +610,11 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None: if outgoing_links: # If there 
are new tags to visit at the next depth, query for the # node IDs. - for link in outgoing_links: + for outgoing_link in outgoing_links: visit_nodes_query, params = self._get_search_cql_and_params( columns="content_id AS target_content_id", metadata=metadata_filter, - link_keys=[_metadata_s_link_key(link)] + outgoing_link=outgoing_link, ) cq.execute( query=visit_nodes_query, @@ -748,7 +748,7 @@ def add_targets(rows: Iterable[Any]) -> None: limit=k_per_tag or 10, metadata=metadata_filter, embedding=query_embedding, - link_keys=[_metadata_s_link_key(link=link)] + outgoing_link=link, ) cq.execute( @@ -854,16 +854,15 @@ def _get_search_cql_and_params( limit: int | None = None, metadata: dict[str, Any] | None = None, embedding: list[float] | None = None, - link_keys: list[str] | None = None, + outgoing_link: Link | None = None, ) -> tuple[PreparedStatement|SimpleStatement, tuple[Any, ...]]: - if link_keys is not None: + if outgoing_link is not None: if metadata is None: metadata = {} else: # don't add link search to original metadata dict metadata = metadata.copy() - for link_key in link_keys: - metadata[link_key] = _metadata_s_link_value() + metadata[_metadata_s_link_key(link=outgoing_link)] = _metadata_s_link_value() metadata_keys = list(metadata.keys()) if metadata else [] From 1f0c23260665c6c63ccc830657a1b49c237158b3 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Mon, 23 Sep 2024 18:24:48 +0200 Subject: [PATCH 4/5] more tweaks --- .../ragstack_knowledge_store/graph_store.py | 171 +++++++++--------- 1 file changed, 82 insertions(+), 89 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index c4a61427e..55dceee2b 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -31,10 +31,8 @@ CONTENT_ID = "content_id" -CONTENT_COLUMNS = "content_id, text_content, metadata_blob" - 
SELECT_CQL_TEMPLATE = ( - "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};" + "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause}" ) @@ -221,13 +219,13 @@ def __init__( self._query_by_id = session.prepare( f""" - SELECT {CONTENT_COLUMNS} + SELECT content_id, text_content, metadata_blob FROM {keyspace}.{node_table} WHERE content_id = ? """ # noqa: S608 ) - self._query_ids_and_link_to_tags_by_id = session.prepare( + self._query_id_and_metadata_by_id = session.prepare( f""" SELECT content_id, metadata_blob FROM {keyspace}.{node_table} @@ -247,10 +245,8 @@ def _apply_schema(self) -> None: content_id TEXT, text_content TEXT, text_embedding VECTOR, - metadata_blob TEXT, metadata_s MAP, - PRIMARY KEY (content_id) ) """) @@ -279,24 +275,24 @@ def add_nodes( """Add nodes to the graph store.""" node_ids: list[str] = [] texts: list[str] = [] - metadatas: list[dict[str, Any]] = [] - nodes_links: list[set[Link]] = [] + metadata_list: list[dict[str, Any]] = [] + incoming_links_list: list[set[Link]] = [] for node in nodes: if not node.id: node_ids.append(secrets.token_hex(8)) else: node_ids.append(node.id) texts.append(node.text) - metadatas.append(node.metadata) - nodes_links.append(node.links) + combined_metadata = node.metadata.copy() + combined_metadata["links"] = _serialize_links(node.links) + metadata_list.append(combined_metadata) + incoming_links_list.append(node.incoming_links()) text_embeddings = self._embedding.embed_texts(texts) with self._concurrent_queries() as cq: - tuples = zip(node_ids, texts, text_embeddings, metadatas, nodes_links) - for node_id, text, text_embedding, metadata, links in tuples: - link_to_tags = set() # link to these tags - link_from_tags = set() # link from these tags + tuples = zip(node_ids, texts, text_embeddings, metadata_list, incoming_links_list) + for node_id, text, text_embedding, metadata, incoming_links in tuples: metadata_s = { k: self._coerce_string(v) @@ -304,11 +300,11 @@ def 
add_nodes( if _is_metadata_field_indexed(k, self._metadata_indexing_policy) } - for tag in links: - if tag.direction in {"in", "bidir"}: - metadata_s[_metadata_s_link_key(link=tag)] =_metadata_s_link_value() + for incoming_link in incoming_links: + metadata_s[_metadata_s_link_key(link=incoming_link)] =_metadata_s_link_value() metadata_blob = _serialize_metadata(metadata) + cq.execute( self._insert_passage, parameters=( @@ -406,20 +402,20 @@ def mmr_traversal_search( score_threshold=score_threshold, ) - # For each unselected node, stores the outgoing tags. - outgoing_links: dict[str, set[Link]] = {} + # For each unselected node, stores the outgoing links. + outgoing_links_map: dict[str, set[Link]] = {} visited_links: set[Link] = set() def fetch_neighborhood(neighborhood: Sequence[str]) -> None: - nonlocal outgoing_links + nonlocal outgoing_links_map nonlocal visited_links - # Put the neighborhood into the outgoing tags, to avoid adding it + # Put the neighborhood into the outgoing links, to avoid adding it # to the candidate set in the future. - outgoing_links.update({content_id: set() for content_id in neighborhood}) + outgoing_links_map.update({content_id: set() for content_id in neighborhood}) - # Initialize the visited_links with the set of outgoing from the + # Initialize the visited_links with the set of outgoing links from the # neighborhood. This prevents re-visiting them. 
visited_links = self._get_outgoing_links(neighborhood) @@ -427,23 +423,21 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None: adjacent_nodes = self._get_adjacent( links=visited_links, query_embedding=query_embedding, - k_per_tag=adjacent_k, + k_per_link=adjacent_k, metadata_filter=metadata_filter, ) - new_candidates = {} + new_candidates: dict[str, list[float]] = {} for adjacent_node in adjacent_nodes: - if adjacent_node.id not in outgoing_links: - outgoing_links[adjacent_node.id] = ( - adjacent_node.outgoing_links() - ) - - new_candidates[adjacent_node.id] = ( - adjacent_node.embedding - ) + if adjacent_node.id not in outgoing_links_map: + outgoing_links_map[adjacent_node.id] = adjacent_node.outgoing_links() + new_candidates[adjacent_node.id] = adjacent_node.embedding helper.add_candidates(new_candidates) def fetch_initial_candidates() -> None: + nonlocal outgoing_links_map + nonlocal visited_links + initial_candidates_query, params = self._get_search_cql_and_params( columns = "content_id, text_embedding, metadata_blob", limit=fetch_k, @@ -451,15 +445,15 @@ def fetch_initial_candidates() -> None: embedding=query_embedding, ) - fetched = self._session.execute( + rows = self._session.execute( query=initial_candidates_query, parameters=params ) - candidates = {} - for row in fetched: - if row.content_id not in outgoing_links: + candidates: dict[str, list[float]] = {} + for row in rows: + if row.content_id not in outgoing_links_map: node = _row_to_node(row=row) + outgoing_links_map[node.id] = node.outgoing_links() candidates[node.id] = node.embedding - outgoing_links[node.id] = node.outgoing_links() helper.add_candidates(candidates) if initial_roots: @@ -482,39 +476,33 @@ def fetch_initial_candidates() -> None: # If the next nodes would not exceed the depth limit, find the # adjacent nodes. # - # TODO: For a big performance win, we should track which tags we've + # TODO: For a big performance win, we should track which links we've # already incorporated. 
We don't need to issue adjacent queries for # those. - # Find the tags linked to from the selected ID. - link_to_tags = outgoing_links.pop(selected_id) + # Find the links linked to from the selected ID. + selected_outgoing_links = outgoing_links_map.pop(selected_id) - # Don't re-visit already visited tags. - link_to_tags.difference_update(visited_links) + # Don't re-visit already visited links. + selected_outgoing_links.difference_update(visited_links) - # Find the nodes with incoming links from those tags. + # Find the nodes with incoming links from those links. adjacent_nodes = self._get_adjacent( - links=link_to_tags, + links=selected_outgoing_links, query_embedding=query_embedding, - k_per_tag=adjacent_k, + k_per_link=adjacent_k, metadata_filter=metadata_filter, ) - # Record the link_to_tags as visited. - visited_links.update(link_to_tags) + # Record the selected_outgoing_links as visited. + visited_links.update(selected_outgoing_links) new_candidates = {} for adjacent_node in adjacent_nodes: - if adjacent_node.id not in outgoing_links: - outgoing_links[adjacent_node.id] = ( - adjacent_node.outgoing_links() - ) - new_candidates[adjacent_node.id] = ( - adjacent_node.embedding - ) - if next_depth < depths.get( - adjacent_node.id, depth + 1 - ): + if adjacent_node.id not in outgoing_links_map: + outgoing_links_map[adjacent_node.id] = adjacent_node.outgoing_links() + new_candidates[adjacent_node.id] = adjacent_node.embedding + if next_depth < depths.get(adjacent_node.id, depth + 1): # If this is a new shortest depth, or there was no # previous depth, update the depths. This ensures that # when we discover a node we will have the shortest @@ -556,12 +544,12 @@ def traversal_search( """ # Depth 0: # Query for `k` nodes similar to the question. - # Retrieve `content_id` and `link_to_tags`. + # Retrieve `content_id` and `outgoing_links()`. # # Depth 1: - # Query for nodes that have an incoming tag in the `link_to_tags` set. 
+ # Query for nodes that have an incoming link in the `outgoing_links()` set. # Combine node IDs. - # Query for `link_to_tags` of those "new" node IDs. + # Query for `outgoing_links()` of those "new" node IDs. # # ... @@ -572,8 +560,8 @@ def traversal_search( # Map from visited ID to depth visited_ids: dict[str, int] = {} - # Map from visited tag `(kind, tag)` to depth. Allows skipping queries - # for tags that we've already traversed. + # Map from visited link to depth. Allows skipping queries + # for links that we've already traversed. visited_links: dict[Link, int] = {} def visit_nodes(d: int, rows: Sequence[Any]) -> None: @@ -581,10 +569,9 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None: nonlocal visited_links # Visit nodes at the given depth. - # Each node has `content_id` and `link_to_tags`. - # Iterate over nodes, tracking the *new* outgoing kind tags for this - # depth. This is tags that are either new, or newly discovered at a + # Iterate over nodes, tracking the *new* outgoing links for this + # depth. These are links that are either new, or newly discovered at a # lower depth. outgoing_links: Set[Link] = set() for row in rows: @@ -598,17 +585,17 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None: if d < depth: node = _row_to_node(row=row) # Record any new (or newly discovered at a lower depth) - # tags to the set to traverse. + # links to the set to traverse. for link in node.outgoing_links(): if d <= visited_links.get(link, depth): - # Record that we'll query this tag at the + # Record that we'll query this link at the # given depth, so we don't fetch it again # (unless we find it an earlier depth) visited_links[link] = d outgoing_links.add(link) if outgoing_links: - # If there are new tags to visit at the next depth, query for the + # If there are new links to visit at the next depth, query for the # node IDs. 
for outgoing_link in outgoing_links: visit_nodes_query, params = self._get_search_cql_and_params( @@ -622,20 +609,19 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None: callback=lambda rows, d=d: visit_targets(d, rows), ) - def visit_targets(d: int, targets: Sequence[Any]) -> None: + def visit_targets(d: int, rows: Sequence[Any]) -> None: nonlocal visited_ids - # target_content_id, tag=(kind,value) - new_nodes_at_next_depth = set() - for target in targets: - content_id = target.target_content_id + new_node_ids_at_next_depth = set() + for row in rows: + content_id = row.target_content_id if d < visited_ids.get(content_id, depth): - new_nodes_at_next_depth.add(content_id) + new_node_ids_at_next_depth.add(content_id) - if new_nodes_at_next_depth: - for node_id in new_nodes_at_next_depth: + if new_node_ids_at_next_depth: + for node_id in new_node_ids_at_next_depth: cq.execute( - self._query_ids_and_link_to_tags_by_id, + self._query_id_and_metadata_by_id, parameters=(node_id,), callback=lambda rows, d=d: visit_nodes(d + 1, rows), ) @@ -663,7 +649,10 @@ def similarity_search( ) -> Iterable[Node]: """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501 query, params = self._get_search_cql_and_params( - embedding=embedding, limit=k, metadata=metadata_filter + columns=f"{CONTENT_ID}, text_content, metadata_blob", + embedding=embedding, + limit=k, + metadata=metadata_filter, ) for row in self._session.execute(query, params): @@ -675,7 +664,11 @@ def metadata_search( n: int = 5, ) -> Iterable[Node]: """Retrieve nodes based on their metadata.""" - query, params = self._get_search_cql_and_params(metadata=metadata, limit=n) + query, params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_content, metadata_blob", + metadata=metadata, + limit=n, + ) for row in self._session.execute(query, params): yield _row_to_node(row) @@ -688,10 +681,10 @@ def _get_outgoing_links( self, source_ids: Iterable[str], ) -> set[Link]: - 
"""Return the set of outgoing tags for the given source ID(s). + """Return the set of outgoing links for the given source ID(s). Args: - source_ids: The IDs of the source nodes to retrieve outgoing tags for. + source_ids: The IDs of the source nodes to retrieve outgoing links for. """ links = set() @@ -703,7 +696,7 @@ def add_sources(rows: Iterable[Any]) -> None: with self._concurrent_queries() as cq: for source_id in source_ids: cq.execute( - self._query_ids_and_link_to_tags_by_id, + self._query_id_and_metadata_by_id, (source_id,), callback=add_sources, ) @@ -714,15 +707,15 @@ def _get_adjacent( self, links: set[Link], query_embedding: list[float], - k_per_tag: int | None = None, + k_per_link: int | None = None, metadata_filter: dict[str, Any] | None = None, ) -> Iterable[Node]: - """Return the target nodes with incoming links from any of the given tags. + """Return the target nodes with incoming links from any of the given links. Args: - tags: The tags to look for links *from*. + links: The links to look for. query_embedding: The query embedding. Used to rank target nodes. - k_per_tag: The number of target nodes to fetch for each outgoing tag. + k_per_link: The number of target nodes to fetch for each link. metadata_filter: Optional metadata to filter the results. 
Returns: @@ -745,7 +738,7 @@ def add_targets(rows: Iterable[Any]) -> None: for link in links: adjacent_query, params = self._get_search_cql_and_params( columns = "content_id, text_embedding, metadata_blob", - limit=k_per_tag or 10, + limit=k_per_link or 10, metadata=metadata_filter, embedding=query_embedding, outgoing_link=link, From 228f5e6f6c9f105ba46b5102dc79ea79b118668c Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Thu, 26 Sep 2024 16:19:35 +0200 Subject: [PATCH 5/5] more updates --- .../ragstack_knowledge_store/graph_store.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 55dceee2b..24cf8c326 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -142,7 +142,16 @@ def _row_to_node(row: Any) -> Node: links=links, ) +def _get_metadata_filter( + metadata: dict[str, Any] | None = None, + outgoing_link: Link | None = None, +) -> dict[str, Any]: + if outgoing_link is None: + return metadata + metadata_filter = {} if metadata is None else metadata.copy() + metadata_filter[_metadata_s_link_key(link=outgoing_link)] = _metadata_s_link_value() + return metadata_filter _CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*") @@ -841,6 +850,9 @@ def _extract_where_clause_params( return params + + + def _get_search_cql_and_params( self, columns: str, @@ -849,15 +861,9 @@ def _get_search_cql_and_params( embedding: list[float] | None = None, outgoing_link: Link | None = None, ) -> tuple[PreparedStatement|SimpleStatement, tuple[Any, ...]]: - if outgoing_link is not None: - if metadata is None: - metadata = {} - else: - # don't add link search to original metadata dict - metadata = metadata.copy() - metadata[_metadata_s_link_key(link=outgoing_link)] = _metadata_s_link_value() + metadata_filter = 
_get_metadata_filter(metadata=metadata, outgoing_link=outgoing_link)

-        metadata_keys = list(metadata.keys()) if metadata else []
+        metadata_keys = list(metadata_filter.keys()) if metadata_filter else []
 
         where_clause = self._extract_where_clause_cql(metadata_keys=metadata_keys)
         limit_clause = " LIMIT ?" if limit is not None else ""
@@ -871,7 +877,7 @@
             limit_clause=limit_clause,
         )
 
-        where_params = self._extract_where_clause_params(metadata=metadata or {})
+        where_params = self._extract_where_clause_params(metadata=metadata_filter or {})
 
         limit_params = [limit] if limit is not None else []
         order_params = [embedding] if embedding is not None else []
@@ -886,3 +892,4 @@
             prepared_query.consistency_level = ConsistencyLevel.ONE
             self._prepared_query_cache[select_cql] = prepared_query
         return prepared_query, params
+