feat(Dataframe): pull method to fetch dataset from remote server #1446

Merged: 10 commits, Dec 5, 2024
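In short, load() now falls back to fetching a zipped dataset from the PandaAI server when it is missing locally, and DataFrame gains a pull() method to refresh one on demand. A minimal usage sketch of the new behavior, assuming both environment variables are set; the URL, key, and dataset path "acme/sales" are hypothetical:

import os

import pandasai as pai

# Hypothetical credentials; without both, the remote fallback raises PandasAIApiKeyError.
os.environ["PANDAAI_API_URL"] = "https://api.example.com"
os.environ["PANDAAI_API_KEY"] = "your-api-key"

# Loads ./datasets/acme/sales if present; otherwise pulls the zipped dataset
# from the server, extracts it there, and then loads it.
df = pai.load("acme/sales")

# Re-download the dataset archive and reload this DataFrame in place.
df.pull()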
29 changes: 29 additions & 0 deletions pandasai/__init__.py
@@ -3,9 +3,16 @@
 PandasAI is a wrapper around a LLM to make dataframes conversational
 """

+from io import BytesIO
 import os
 from typing import List
+from zipfile import ZipFile

+import pandas as pd
+
+from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
+from pandasai.helpers.path import find_project_root
+from pandasai.helpers.request import get_pandaai_session
 from .agent import Agent
 from .helpers.cache import Cache
 from .dataframe.base import DataFrame
@@ -74,6 +81,28 @@ def load(dataset_path: str, virtualized=False) -> DataFrame:
         DataFrame: A new PandasAI DataFrame instance with loaded data.
     """
     global _dataset_loader
+    dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path)
+    if not os.path.exists(dataset_full_path):
+        api_key = os.environ.get("PANDAAI_API_KEY", None)
Review comment: Check if api_key and api_url are None before using them, and raise a PandasAIApiKeyError if they are not set. This prevents a potential TypeError when constructing headers or making requests.
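A minimal sketch of the guard this comment asks for; it matches the check the diff adds just below:

import os

from pandasai.exceptions import PandasAIApiKeyError

api_key = os.environ.get("PANDAAI_API_KEY", None)
api_url = os.environ.get("PANDAAI_API_URL", None)
if not api_url or not api_key:
    # Fail fast instead of sending an "x-authorization: Bearer None" header
    # or building a session around endpoint_url=None.
    raise PandasAIApiKeyError(
        "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to pull dataset from the remote server"
    )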

+        api_url = os.environ.get("PANDAAI_API_URL", None)
+        if not api_url or not api_key:
+            raise PandasAIApiKeyError(
+                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to pull dataset from the remote server"
+            )
+
+        request_session = get_pandaai_session()
+
+        headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
+
+        file_data = request_session.get(
+            "/datasets/pull", headers=headers, params={"path": dataset_path}
+        )
+        if file_data.status_code != 200:
+            raise DatasetNotFound("Dataset not found!")
+
+        with ZipFile(BytesIO(file_data.content)) as zip_file:
+            zip_file.extractall(dataset_full_path)
+
     return _dataset_loader.load(dataset_path, virtualized)


29 changes: 26 additions & 3 deletions pandasai/data_loader/loader.py
@@ -35,16 +35,25 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
             self._cache_data(df, cache_file)

             table_name = self.schema["source"]["table"]
+            table_description = self.schema.get("description", None)

-            return DataFrame(df, schema=self.schema, name=table_name, path=dataset_path)
+            return DataFrame(
+                df,
+                schema=self.schema,
+                name=table_name,
+                description=table_description,
+                path=dataset_path,
+            )
         else:
             # Initialize new dataset loader for virtualization
             data_loader = self.copy()
             table_name = self.schema["source"]["table"]
+            table_description = self.schema.get("description", None)
             return VirtualDataFrame(
                 schema=self.schema,
                 data_loader=data_loader,
                 name=table_name,
+                description=table_description,
                 path=dataset_path,
             )

@@ -88,10 +97,24 @@ def _is_cache_valid(self, cache_file: str) -> bool:

     def _read_cache(self, cache_file: str) -> DataFrame:
         cache_format = self.schema["destination"]["format"]
+        table_name = self.schema["source"]["table"]
+        table_description = self.schema.get("description", None)
         if cache_format == "parquet":
-            return DataFrame(pd.read_parquet(cache_file))
+            return DataFrame(
+                pd.read_parquet(cache_file),
+                schema=self.schema,
+                path=self.dataset_path,
+                name=table_name,
+                description=table_description,
+            )
         elif cache_format == "csv":
-            return DataFrame(pd.read_csv(cache_file))
+            return DataFrame(
+                pd.read_csv(cache_file),
+                schema=self.schema,
+                path=self.dataset_path,
+                name=table_name,
+                description=table_description,
+            )
         else:
             raise ValueError(f"Unsupported cache format: {cache_format}")
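
For reference, a minimal sketch of the schema fields _read_cache now reads; the values are illustrative and the real schema carries more keys:

schema = {
    "source": {"table": "sales"},            # "table" becomes the DataFrame name
    "destination": {"format": "parquet"},    # selects pd.read_parquet vs. pd.read_csv
    "description": "Monthly sales figures",  # optional; None when absent
}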

57 changes: 49 additions & 8 deletions pandasai/dataframe/base.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
+from io import BytesIO
 import os
 import re
+from zipfile import ZipFile
 import pandas as pd
 from typing import TYPE_CHECKING, List, Optional, Union, Dict, ClassVar

@@ -9,13 +11,13 @@

 from pandasai.config import Config
 import hashlib
-from pandasai.exceptions import PandasAIApiKeyError
+from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
 from pandasai.helpers.dataframe_serializer import (
     DataframeSerializer,
     DataframeSerializerType,
 )
 from pandasai.helpers.path import find_project_root
-from pandasai.helpers.request import Session
+from pandasai.helpers.request import get_pandaai_session


 if TYPE_CHECKING:
@@ -220,14 +222,9 @@ def save(
         print(f"Dataset saved successfully to path: {dataset_directory}")

     def push(self):
-        api_url = os.environ.get("PANDAAI_API_URL", None)
         api_key = os.environ.get("PANDAAI_API_KEY", None)
-        if not api_url or not api_key:
-            raise PandasAIApiKeyError(
-                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to push dataset to the remote server"
-            )

-        request_session = Session(endpoint_url=api_url, api_key=api_key)
+        request_session = get_pandaai_session()

         params = {
             "path": self.path,
@@ -255,3 +252,47 @@ def push(self):
             params=params,
             headers=headers,
         )
+
+    def pull(self):
+        api_key = os.environ.get("PANDAAI_API_KEY", None)
+
+        if not api_key:
+            raise PandasAIApiKeyError(
+                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to pull dataset from the remote server"
+            )
+
+        request_session = get_pandaai_session()
+
+        headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
+
+        file_data = request_session.get(
+            "/datasets/pull", headers=headers, params={"path": self.path}
+        )
+        if file_data.status_code != 200:
+            raise DatasetNotFound("Remote dataset not found to pull!")
+
+        with ZipFile(BytesIO(file_data.content)) as zip_file:
+            for file_name in zip_file.namelist():
+                target_path = os.path.join(
+                    find_project_root(), "datasets", self.path, file_name
+                )
+
+                # Check if the file already exists
+                if os.path.exists(target_path):
+                    print(f"Replacing existing file: {target_path}")
+
+                # Ensure the target directory exists
+                os.makedirs(os.path.dirname(target_path), exist_ok=True)
+
+                # Extract the file
+                with open(target_path, "wb") as f:
+                    f.write(zip_file.read(file_name))
+
+        # Reload the DataFrame
+        from pandasai import DatasetLoader
+
+        dataset_loader = DatasetLoader()
+        df = dataset_loader.load(self.path, virtualized=not isinstance(self, DataFrame))
+        self.__init__(
+            df, schema=df.schema, name=df.name, description=df.description, path=df.path
+        )
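
For reference, a sketch of where pull() writes each archive member; the project root, dataset path, and file name are hypothetical stand-ins:

import os

project_root = "/work/project"  # stand-in for find_project_root()
dataset_path = "acme/sales"     # stand-in for self.path
file_name = "schema.yaml"       # example member of the pulled archive

target_path = os.path.join(project_root, "datasets", dataset_path, file_name)
print(target_path)  # /work/project/datasets/acme/sales/schema.yaml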
13 changes: 11 additions & 2 deletions pandasai/exceptions.py
@@ -209,8 +209,9 @@ class PandasAIApiKeyError(Exception):
         Exception (Exception): PandasAIApiKeyError
     """

-    def __init__(self):
-        message = PANDASBI_SETUP_MESSAGE
+    def __init__(self, message: str = None):
+        if not message:
+            message = PANDASBI_SETUP_MESSAGE
         super().__init__(message)


@@ -264,3 +265,11 @@ class MaliciousCodeGenerated(Exception):
     Args:
         Exception (Exception): MaliciousCodeGenerated
     """
+
+
+class DatasetNotFound(Exception):
+    """
+    Raised when the requested dataset is not found.
+    Args:
+        Exception (Exception): DatasetNotFound
+    """
20 changes: 18 additions & 2 deletions pandasai/helpers/request.py
@@ -35,7 +35,7 @@ def __init__(
         self._logger = logger or Logger()

     def get(self, path=None, **kwargs):
-        return self.make_request("GET", path, **kwargs)["data"]
+        return self.make_request("GET", path, **kwargs)

     def post(self, path=None, **kwargs):
         return self.make_request("POST", path, **kwargs)
@@ -79,7 +79,12 @@ def make_request(
                 **kwargs,
             )

-            data = response.json()
+            try:
+                data = response.json()
+            except ValueError:
+                if response.status_code == 200:
+                    return response
+
             if response.status_code not in [200, 201]:
                 if "message" in data:
                     raise PandasAIApiCallError(data["message"])
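
One caveat worth noting: when the body is not JSON and the status code is not 200, data is never bound, so the membership check above would raise NameError. A defensive variant of the parsing step, sketched as a hypothetical helper rather than as part of the diff:

import requests

def parse_api_response(response: requests.Response):
    """Hypothetical helper mirroring make_request's parsing step."""
    try:
        data = response.json()
    except ValueError:
        if response.status_code == 200:
            return response
        # Non-JSON error body: fall through with an empty payload so the
        # status-code handling still raises a meaningful API error.
        data = {}
    return data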
@@ -91,3 +96,14 @@
         except requests.exceptions.RequestException as e:
             self._logger.log(f"Request failed: {traceback.format_exc()}", logging.ERROR)
             raise PandasAIApiCallError(f"Request failed: {e}") from e
+
+
+def get_pandaai_session():
+    api_url = os.environ.get("PANDAAI_API_URL", None)
+    api_key = os.environ.get("PANDAAI_API_KEY", None)
+    if not api_url or not api_key:
+        raise PandasAIApiKeyError(
+            "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to push/pull datasets to/from the remote server"
+        )
+
+    return Session(endpoint_url=api_url, api_key=api_key)
4 changes: 2 additions & 2 deletions pandasai/vectorstores/bamboo_vectorstore.py
@@ -58,7 +58,7 @@ def get_relevant_qa_documents(self, question: str, k: int = None) -> List[dict]:
         try:
             docs = self._session.get(
                 "/training-data/qa/relevant-qa", params={"query": question, "count": k}
-            )
+            )["data"]
             return docs["docs"]
         except Exception:
             self._logger.log("Querying without using training data.", logging.ERROR)
@@ -77,7 +77,7 @@ def get_relevant_docs_documents(
             docs = self._session.get(
                 "/training-docs/docs/relevant-docs",
                 params={"query": question, "count": k},
-            )
+            )["data"]
             return docs["docs"]
         except Exception:
             self._logger.log("Querying without using training docs.", logging.ERROR)