Add e2e test for llama-index (#86)

cbornet authored Dec 6, 2023
1 parent c1deb3d commit b5ca84b

Showing 7 changed files with 164 additions and 65 deletions.
23 changes: 22 additions & 1 deletion ragstack-e2e-tests/e2e_tests/conftest.py
@@ -1,9 +1,10 @@
 import logging
+import os
 import uuid
 from dataclasses import dataclass
 
 import pytest
-import os
+from astrapy.db import AstraDB as LibAstraDB
 
 
 def random_string():
@@ -57,6 +58,26 @@ def get_default_astra_ref() -> AstraRef:
     return get_astra_prod_ref()
 
 
+def delete_all_astra_collections(astra_ref: AstraRef):
+    """
+    Deletes all collections.
+    Current AstraDB has a limit of 5 collections, meaning orphaned collections
+    will cause subsequent tests to fail if the limit is reached.
+    """
+    raw_client = LibAstraDB(api_endpoint=astra_ref.api_endpoint, token=astra_ref.token)
+    collections = raw_client.get_collections().get("status").get("collections")
+    logging.info(f"Existing collections: {collections}")
+    for collection_info in collections:
+        logging.info(f"Deleting collection: {collection_info}")
+        raw_client.delete_collection(collection_info)
+
+
+def delete_astra_collection(astra_ref: AstraRef) -> None:
+    raw_client = LibAstraDB(api_endpoint=astra_ref.api_endpoint, token=astra_ref.token)
+    raw_client.delete_collection(astra_ref.collection)
+
+
 failed_report_lines = []
 all_report_lines = []
 tests_stats = {
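For reference, a minimal sketch of how the new conftest helpers are meant to bracket a test (a hypothetical test function, not part of this commit; it uses only names introduced above):

    from e2e_tests.conftest import (
        get_astra_prod_ref,
        delete_all_astra_collections,
        delete_astra_collection,
    )


    def test_with_clean_collections():  # hypothetical example
        ref = get_astra_prod_ref()
        # Clear orphaned collections first, since AstraDB currently
        # caps the database at 5 collections.
        delete_all_astra_collections(ref)
        try:
            ...  # exercise a vector store bound to ref.collection
        finally:
            # Drop only this test's collection on the way out.
            delete_astra_collection(ref)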
Empty file.
@@ -1,21 +1,15 @@
-import logging
-
+import cassio
+import pytest
-from astrapy.db import AstraDB as LibAstraDB
 from e2e_tests.conftest import (
     set_current_test_info_simple_rag,
     get_required_env,
     get_astra_dev_ref,
     get_astra_prod_ref,
+    delete_all_astra_collections,
+    delete_astra_collection,
 )
-from e2e_tests.chat_application import run_application
-from langchain.llms.huggingface_hub import HuggingFaceHub
-
-import pytest
-import cassio
-
-from langchain.schema.embeddings import Embeddings
-from langchain.schema.vectorstore import VectorStore
-from langchain.vectorstores import AstraDB, Cassandra
+from e2e_tests.langchain.chat_application import run_application
 from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatVertexAI, BedrockChat
 from langchain.embeddings import (
     OpenAIEmbeddings,
@@ -24,7 +18,11 @@
     HuggingFaceInferenceAPIEmbeddings,
 )
 from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings
+from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.schema.embeddings import Embeddings
 from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.vectorstore import VectorStore
+from langchain.vectorstores import AstraDB, Cassandra
 
 VECTOR_ASTRADB_PROD = "astradb-prod"
 VECTOR_ASTRADB_DEV = "astradb-dev"
@@ -41,69 +39,45 @@ def vector_dbs():
 def init_vector_db(impl, embedding: Embeddings) -> VectorStore:
     if impl == VECTOR_ASTRADB_DEV:
         ref = get_astra_dev_ref()
-        collection = ref.collection
-        token = ref.token
-        api_endpoint = ref.api_endpoint
 
         # Ensure collections from previous runs are cleared
-        delete_collections(api_endpoint, token)
+        delete_all_astra_collections(ref)
 
         return AstraDB(
-            collection_name=collection,
+            collection_name=ref.collection,
             embedding=embedding,
-            token=token,
-            api_endpoint=api_endpoint,
+            token=ref.token,
+            api_endpoint=ref.api_endpoint,
         )
     elif impl == VECTOR_ASTRADB_PROD:
         ref = get_astra_prod_ref()
-        collection = ref.collection
-        token = ref.token
-        api_endpoint = ref.api_endpoint
 
         # Ensure collections from previous runs are cleared
-        delete_collections(api_endpoint, token)
+        delete_all_astra_collections(ref)
 
         return AstraDB(
-            collection_name=collection,
+            collection_name=ref.collection,
             embedding=embedding,
-            token=token,
-            api_endpoint=api_endpoint,
+            token=ref.token,
+            api_endpoint=ref.api_endpoint,
         )
     elif impl == VECTOR_CASSANDRA:
         ref = get_astra_prod_ref()
-        table_name = ref.collection
-        token = ref.token
-        api_endpoint = ref.api_endpoint
 
         # Ensure collections from previous runs are cleared
-        delete_collections(api_endpoint, token)
+        delete_all_astra_collections(ref)
 
-        cassio.init(token=token, database_id=ref.id)
+        cassio.init(token=ref.token, database_id=ref.id)
         return Cassandra(
             embedding=embedding,
             session=None,
             keyspace="default_keyspace",
-            table_name=table_name,
+            table_name=ref.collection,
         )
     else:
         raise Exception("Unknown vector db implementation: " + impl)


-def delete_collections(api_endpoint, token):
-    """
-    Deletes all collections.
-    Current AstraDB has a limit of 5 collections, meaning orphaned collections
-    will cause subsequent tests to fail if the limit is reached.
-    """
-    raw_client = LibAstraDB(api_endpoint=api_endpoint, token=token)
-    collections = raw_client.get_collections().get("status").get("collections")
-    logging.info(f"Existing collections: {collections}")
-    for collection_info in collections:
-        logging.info(f"Deleting collection: {collection_info}")
-        raw_client.delete_collection(collection_info)
-
-
 def astra_delete_collection(
     api_endpoint: str, token: str, collection_name: str
 ) -> None:
@@ -113,25 +87,9 @@ def astra_delete_collection(

 def close_vector_db(impl: str, vector_store: VectorStore):
     if impl == VECTOR_ASTRADB_DEV:
-        ref = get_astra_dev_ref()
-        collection = ref.collection
-        token = ref.token
-        api_endpoint = ref.api_endpoint
-        astra_delete_collection(
-            api_endpoint=api_endpoint,
-            token=token,
-            collection_name=collection,
-        )
+        delete_astra_collection(get_astra_dev_ref())
     elif impl == VECTOR_ASTRADB_PROD or impl == VECTOR_CASSANDRA:
-        ref = get_astra_prod_ref()
-        collection = ref.collection
-        token = ref.token
-        api_endpoint = ref.api_endpoint
-        astra_delete_collection(
-            api_endpoint=api_endpoint,
-            token=token,
-            collection_name=collection,
-        )
+        delete_astra_collection(get_astra_prod_ref())
     else:
         raise Exception("Unknown vector db implementation: " + impl)

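For reference, a sketch of how init_vector_db and close_vector_db pair up after this change (a hypothetical driver, not part of this commit; OpenAIEmbeddings stands in for any Embeddings implementation):

    from langchain.embeddings import OpenAIEmbeddings


    def exercise_vector_db(impl: str):  # hypothetical example
        embedding = OpenAIEmbeddings()
        # init_vector_db also clears orphaned collections from earlier runs.
        vector_store = init_vector_db(impl, embedding)
        try:
            vector_store.add_texts(["hello world"])
        finally:
            # close_vector_db drops the collection used by this run.
            close_vector_db(impl, vector_store)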
Empty file.
120 changes: 120 additions & 0 deletions ragstack-e2e-tests/e2e_tests/llama_index/test_llama_index.py
@@ -0,0 +1,120 @@
from abc import ABC, abstractmethod

from e2e_tests.conftest import (
    AstraRef,
    get_required_env,
    get_astra_dev_ref,
    get_astra_prod_ref,
    delete_all_astra_collections,
    delete_astra_collection,
)
from llama_index import (
    VectorStoreIndex,
    StorageContext,
    ServiceContext,
    Document,
    OpenAIEmbedding,
)
from llama_index.llms import OpenAI
from llama_index.vector_stores import AstraDBVectorStore


class ContextMixin(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        ...

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class AstraDBVectorStoreContext(ContextMixin, ABC):
    def __init__(self, astra_ref: AstraRef):
        self.astra_ref = astra_ref
        self.vector_store = None

    def __enter__(self):
        delete_all_astra_collections(self.astra_ref)

        self.vector_store = AstraDBVectorStore(
            collection_name=self.astra_ref.collection,
            embedding_dimension=1536,
            token=self.astra_ref.token,
            api_endpoint=self.astra_ref.api_endpoint,
        )
        return self.vector_store

    def __exit__(self, exc_type, exc_value, traceback):
        delete_astra_collection(self.astra_ref)


class DevAstraDBVectorStoreContext(AstraDBVectorStoreContext):
    name = "astradb-dev"

    def __init__(self):
        super().__init__(get_astra_dev_ref())


class ProdAstraDBVectorStoreContext(AstraDBVectorStoreContext):
    name = "astradb-prod"

    def __init__(self):
        super().__init__(get_astra_prod_ref())


class OpenAILLMContext(ContextMixin):
    name = "openai"

    def __enter__(self):
        key = get_required_env("OPEN_AI_KEY")
        return OpenAI(api_key=key)


class OpenAIEmbeddingsContext(ContextMixin):
    name = "openai"

    def __enter__(self):
        key = get_required_env("OPEN_AI_KEY")
        return OpenAIEmbedding(api_key=key)


def test_openai():
    _run_test(
        ProdAstraDBVectorStoreContext(), OpenAILLMContext(), OpenAIEmbeddingsContext()
    )


def _run_test(vector_store_ctx, llm_ctx, embed_model_ctx):
    with vector_store_ctx as vector_store, llm_ctx as llm, embed_model_ctx as embed_model:
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)

        documents = [
            Document(
                text="MyFakeProductForTesting is a versatile testing tool designed to streamline the testing process for software developers, quality assurance professionals, and product testers. It provides a comprehensive solution for testing various aspects of applications and systems, ensuring robust performance and functionality."  # noqa: E501
            ),
            Document(
                text="MyFakeProductForTesting comes equipped with an advanced dynamic test scenario generator. This feature allows users to create realistic test scenarios by simulating various user interactions, system inputs, and environmental conditions. The dynamic nature of the generator ensures that tests are not only diverse but also adaptive to changes in the application under test."  # noqa: E501
            ),
            Document(
                text="The product includes an intelligent bug detection and analysis module. It not only identifies bugs and issues but also provides in-depth analysis and insights into the root causes. The system utilizes machine learning algorithms to categorize and prioritize bugs, making it easier for developers and testers to address critical issues first."  # noqa: E501
            ),
            Document(
                text="MyFakeProductForTesting first release happened in June 2020."
            ),
        ]

        index = VectorStoreIndex.from_documents(
            documents, storage_context=storage_context, service_context=service_context
        )

        query_engine = index.as_query_engine()
        response = query_engine.query(
            "When was MyFakeProductForTesting released for the first time?"
        ).response
        print(f"Got response {response}")
        assert "2020" in response

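The context-manager pattern above makes new combinations one-liners. For example, a hypothetical variant of the test targeting the dev database (not part of this commit, but composed entirely from classes defined in this file) would be:

    def test_openai_astradb_dev():  # hypothetical example
        _run_test(
            DevAstraDBVectorStoreContext(), OpenAILLMContext(), OpenAIEmbeddingsContext()
        )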