Commit

changed test to try dse and astra
epinzur committed Feb 12, 2024
1 parent e2143f6 commit b22530e
Showing 4 changed files with 135 additions and 114 deletions.
87 changes: 25 additions & 62 deletions ragstack-e2e-tests/e2e_tests/llama_index/conftest.py
@@ -1,76 +1,39 @@
import logging
from typing import List

import pytest

from llama_index.embeddings import OpenAIEmbedding
from llama_index.llms import OpenAI

from e2e_tests.conftest import (
    get_required_env,
    is_astra,
    get_vector_store_handler,
)
from llama_index import (
    ServiceContext,
    StorageContext,
)
from llama_index.embeddings import BaseEmbedding
from llama_index.llms import OpenAI, LLM
from llama_index.node_parser import SimpleNodeParser
from llama_index.vector_stores import AstraDBVectorStore

from e2e_tests.test_utils import skip_test_due_to_implementation_not_supported
from e2e_tests.test_utils.astradb_vector_store_handler import AstraDBVectorStoreHandler
from e2e_tests.test_utils.vector_store_handler import VectorStoreImplementation

from e2e_tests.test_utils.vector_store_handler import (
    VectorStoreImplementation,
)

class Environment:
    def __init__(
        self, vectorstore: AstraDBVectorStore, llm: LLM, embedding: BaseEmbedding
    ):
        self.vectorstore = vectorstore
        self.llm = llm
        self.embedding = embedding
        self.service_context = ServiceContext.from_defaults(
            embed_model=self.embedding, llm=self.llm
        )
        # The huge chunk_size effectively disables splitting, so each document
        # is ingested as a single node.
        basic_node_parser = SimpleNodeParser.from_defaults(
            chunk_size=100000000, include_prev_next_rel=False, include_metadata=True
        )
        self.service_context_no_splitting = ServiceContext.from_defaults(
            embed_model=self.embedding,
            llm=self.llm,
            transformations=[basic_node_parser],
        )
        self.storage_context = StorageContext.from_defaults(vector_store=vectorstore)

@pytest.fixture
def openai_llm():
    return "openai", OpenAI(api_key=get_required_env("OPEN_AI_KEY"))

@pytest.fixture(scope="package")
def environment() -> Environment:
    if not is_astra:
        skip_test_due_to_implementation_not_supported("astradb")
    embeddings = MockEmbeddings()
    handler = AstraDBVectorStoreHandler(VectorStoreImplementation.ASTRADB)
    vector_db = handler.before_test().new_llamaindex_vector_store(embedding_dimension=3)
    llm = OpenAI(
        api_key=get_required_env("OPEN_AI_KEY"),
        model="gpt-3.5-turbo-16k",
        streaming=False,
        temperature=0,
    )
    yield Environment(vectorstore=vector_db, llm=llm, embedding=embeddings)
    handler.after_test()

@pytest.fixture
def openai_embedding():
    # 1536 matches the output dimension of OpenAI's default embedding model
    # (text-embedding-ada-002).
    return "openai", 1536, OpenAIEmbedding(api_key=get_required_env("OPEN_AI_KEY"))

class MockEmbeddings(BaseEmbedding):
    def _get_query_embedding(self, query: str) -> List[float]:
        return self.mock_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self.mock_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        return self.mock_embedding(text)

    @staticmethod
    def mock_embedding(text: str):
        res = [len(text) / 2, len(text) / 5, len(text) / 10]
        logging.debug("mock_embedding for " + text + " : " + str(res))
        return res

@pytest.fixture
def astra_db():
    handler = get_vector_store_handler(VectorStoreImplementation.ASTRADB)
    context = handler.before_test()
    yield context
    handler.after_test()

@pytest.fixture
def cassandra():
    handler = get_vector_store_handler(VectorStoreImplementation.CASSANDRA)
    context = handler.before_test()
    yield context
    handler.after_test()
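
The astra_db and cassandra fixtures both yield the handler's test context, so a test can select its backend by fixture name at run time. A minimal sketch of that pattern, assuming a store created from the context as elsewhere in this commit (the test body itself is hypothetical):

@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
def test_sketch(vector_store, request):
    # Indirect lookup: resolve whichever fixture matches the parameter value.
    context = request.getfixturevalue(vector_store)
    store = context.new_llamaindex_vector_store(embedding_dimension=1536)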
71 changes: 70 additions & 1 deletion ragstack-e2e-tests/e2e_tests/llama_index/test_astra.py
@@ -1,18 +1,53 @@
import logging
from typing import List

import pytest
from httpx import ConnectError, HTTPStatusError

from e2e_tests.conftest import (
    get_required_env,
    is_astra,
)
from llama_index import (
    ServiceContext,
    StorageContext,
    VectorStoreIndex,
    Document,
)
from llama_index.embeddings import BaseEmbedding
from llama_index.llms import OpenAI, LLM
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import NodeWithScore
from llama_index.vector_stores import (
    AstraDBVectorStore,
    MetadataFilters,
    ExactMatchFilter,
)

from e2e_tests.llama_index.conftest import Environment
from e2e_tests.test_utils import skip_test_due_to_implementation_not_supported
from e2e_tests.test_utils.astradb_vector_store_handler import AstraDBVectorStoreHandler
from e2e_tests.test_utils.vector_store_handler import VectorStoreImplementation

class Environment:
    def __init__(
        self, vectorstore: AstraDBVectorStore, llm: LLM, embedding: BaseEmbedding
    ):
        self.vectorstore = vectorstore
        self.llm = llm
        self.embedding = embedding
        self.service_context = ServiceContext.from_defaults(
            embed_model=self.embedding, llm=self.llm
        )
        # The huge chunk_size effectively disables splitting, so each document
        # is ingested as a single node.
        basic_node_parser = SimpleNodeParser.from_defaults(
            chunk_size=100000000, include_prev_next_rel=False, include_metadata=True
        )
        self.service_context_no_splitting = ServiceContext.from_defaults(
            embed_model=self.embedding,
            llm=self.llm,
            transformations=[basic_node_parser],
        )
        self.storage_context = StorageContext.from_defaults(vector_store=vectorstore)


def test_basic_vector_search(environment: Environment):
@@ -184,3 +219,37 @@ def test_vector_search_with_metadata(environment: Environment):
    # commented out: the delete is not working; maybe it is a problem with document ids?
    # documents = index.as_retriever().retrieve("RAGStack")
    # assert len(documents) == 0
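
For context on the commented-out assertion, a minimal sketch of the delete path being debugged, assuming deletion is keyed on the ref_doc_id assigned at ingestion (the id below is hypothetical, not taken from this repository):

# ref_doc_id must match the id of the Document inserted earlier (hypothetical id).
index.delete_ref_doc("document-1")
assert len(index.as_retriever().retrieve("RAGStack")) == 0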


@pytest.fixture
def environment() -> Environment:
    if not is_astra:
        skip_test_due_to_implementation_not_supported("astradb")
    embeddings = MockEmbeddings()
    handler = AstraDBVectorStoreHandler(VectorStoreImplementation.ASTRADB)
    vector_db = handler.before_test().new_llamaindex_vector_store(embedding_dimension=3)
    llm = OpenAI(
        api_key=get_required_env("OPEN_AI_KEY"),
        model="gpt-3.5-turbo-16k",
        streaming=False,
        temperature=0,
    )
    yield Environment(vectorstore=vector_db, llm=llm, embedding=embeddings)
    handler.after_test()


class MockEmbeddings(BaseEmbedding):
    def _get_query_embedding(self, query: str) -> List[float]:
        return self.mock_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self.mock_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        return self.mock_embedding(text)

    @staticmethod
    def mock_embedding(text: str):
        res = [len(text) / 2, len(text) / 5, len(text) / 10]
        logging.debug("mock_embedding for " + text + " : " + str(res))
        return res
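
The mock embedder depends only on string length, which keeps these tests deterministic and offline; its three components also match the embedding_dimension=3 passed to new_llamaindex_vector_store in the fixture above. A quick worked example: "RAGStack" has 8 characters, so mock_embedding returns [8 / 2, 8 / 5, 8 / 10] = [4.0, 1.6, 0.8].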
30 changes: 0 additions & 30 deletions ragstack-e2e-tests/e2e_tests/llama_index/test_compatibility_rag.py
@@ -9,12 +9,10 @@
    Document,
)
from llama_index.embeddings import (
    OpenAIEmbedding,
    AzureOpenAIEmbedding,
    BedrockEmbedding,
)
from llama_index.llms import (
    OpenAI,
    AzureOpenAI,
    Vertex,
    Bedrock,
@@ -28,43 +26,15 @@
from e2e_tests.conftest import (
    set_current_test_info,
    get_required_env,
    get_vector_store_handler,
)
from vertexai.vision_models import MultiModalEmbeddingModel, Image

from e2e_tests.test_utils import get_local_resource_path
from e2e_tests.test_utils.vector_store_handler import (
    VectorStoreImplementation,
    VectorStoreTestContext,
)


@pytest.fixture
def astra_db():
    handler = get_vector_store_handler(VectorStoreImplementation.ASTRADB)
    context = handler.before_test()
    yield context
    handler.after_test()

@pytest.fixture
def cassandra():
    handler = get_vector_store_handler(VectorStoreImplementation.CASSANDRA)
    context = handler.before_test()
    yield context
    handler.after_test()

@pytest.fixture
def openai_llm():
    return "openai", OpenAI(api_key=get_required_env("OPEN_AI_KEY"))

@pytest.fixture
def openai_embedding():
    return "openai", 1536, OpenAIEmbedding(api_key=get_required_env("OPEN_AI_KEY"))


@pytest.fixture
def azure_openai_llm():
    return "azure-openai", AzureOpenAI(
61 changes: 40 additions & 21 deletions ragstack-e2e-tests/e2e_tests/llama_index/test_llama_parse.py
@@ -5,41 +5,60 @@
except ImportError:
    pytest.skip("llama_parse is not supported, skipping tests", allow_module_level=True)

from llama_index import VectorStoreIndex
from llama_index import (
    VectorStoreIndex,
    StorageContext,
    ServiceContext,
)

from e2e_tests.llama_index.conftest import Environment
from e2e_tests.conftest import set_current_test_info
from e2e_tests.test_utils import get_local_resource_path
from e2e_tests.test_utils.vector_store_handler import (
    VectorStoreTestContext,
)


def test_llamaparse_as_text_with_vector_search(environment: Environment):
print("test_llamaparse_with_vector_search")
@pytest.fixture
def llama_parse_text():
return "text", LlamaParse(result_type="text")

    file_path = get_local_resource_path("tree.pdf")
    documents = LlamaParse(result_type="text").load_data(file_path)

    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=environment.storage_context,
        service_context=environment.service_context,
    )

@pytest.fixture
def llama_parse_markdown():
    return "markdown", LlamaParse(result_type="markdown")

    # Verify that the document is in the vector store
    retriever = index.as_retriever()
    assert len(retriever.retrieve("What was Eldenroot?")) > 0

@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
@pytest.mark.parametrize(
    "llama_parse_instance",
    ["llama_parse_text", "llama_parse_markdown"],
)
def test_llama_parse(vector_store, llama_parse_instance, request):
    vector_store_context: VectorStoreTestContext = request.getfixturevalue(vector_store)
    lp_type, lp = request.getfixturevalue(llama_parse_instance)
    # Fixture functions cannot be called directly; resolve them via the request.
    _, llm = request.getfixturevalue("openai_llm")
    _, embedding_dimensions, embedding = request.getfixturevalue("openai_embedding")

def test_llamaparse_as_markdown_with_vector_search(environment: Environment):
    print("test_llamaparse_with_vector_search")

    set_current_test_info(
        "llama_index::llama_parse",
        f"{lp_type},{vector_store}",
    )
    vector_store = vector_store_context.new_llamaindex_vector_store(
        embedding_dimension=embedding_dimensions
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embedding)

    file_path = get_local_resource_path("tree.pdf")
    documents = LlamaParse(result_type="markdown").load_data(file_path)
    documents = lp.load_data(file_path)

    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=environment.storage_context,
        service_context=environment.service_context,
        documents, storage_context=storage_context, service_context=service_context
    )

    # Verify that the document is in the vector store
    retriever = index.as_retriever()
    assert len(retriever.retrieve("What was Eldenroot?")) > 0
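
Because the two parametrize decorators stack, pytest expands test_llama_parse into four cases (text/markdown crossed with cassandra/astra_db), with generated ids along the lines of test_llama_parse[llama_parse_text-astra_db]. Since the backends are resolved lazily through request.getfixturevalue, each case only sets up the one vector store it actually uses.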
