Skip to content

Commit

Permalink
refactor: make all the connectors an extension
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Oct 19, 2024
1 parent bc01afe commit 1ed56f0
Show file tree
Hide file tree
Showing 66 changed files with 3,948 additions and 1,956 deletions.
30 changes: 0 additions & 30 deletions docs/connectors.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ PandasAI provides connectors for the following SQL databases:
- DataBricks
- GoogleBigQuery
- Yahoo Finance
- Airtable

Additionally, PandasAI provides a generic SQL connector that can be used to connect to any SQL database.

Expand Down Expand Up @@ -243,32 +242,3 @@ yahoo_connector = YahooFinanceConnector("MSFT")
df = SmartDataframe(yahoo_connector)
df.chat("What is the closing price for yesterday?")
```

## Airtable Connector

The Airtable connector allows you to connect to tables in an Airtable base by simply passing the `base_id`, `token` and `table` of the table you want to analyze.

To use the Airtable connector, you only need to import it into your Python code and pass it to an `Agent`, `SmartDataframe` or `SmartDatalake` object:

```python
from pandasai.connectors import AirtableConnector
from pandasai import SmartDataframe


airtable_connectors = AirtableConnector(
config={
"token": "AIRTABLE_API_TOKEN",
"table":"AIRTABLE_TABLE_NAME",
"base_id":"AIRTABLE_BASE_ID",
"where" : [
# this is optional and filters the data to
# reduce the size of the dataframe
["Status" ,"=","In progress"]
]
}
)

df = SmartDataframe(airtable_connectors)

df.chat("How many rows are there in data ?")
```
2 changes: 1 addition & 1 deletion docs/judge-agent.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ print(agent.chat("return total stars count"))

```python
from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.llm.openai import OpenAI
from pandasai_openai import OpenAI

# can be used with all LLMs
llm = OpenAI("openai_key")
Expand Down
2 changes: 1 addition & 1 deletion examples/judge_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pandasai.agent.agent import Agent
from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.llm.openai import OpenAI
from pandasai_openai import OpenAI

os.environ["PANDASAI_API_KEY"] = "$2a****************************"

Expand Down
2 changes: 1 addition & 1 deletion examples/security_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pandasai.agent.agent import Agent
from pandasai.ee.agents.advanced_security_agent import AdvancedSecurityAgent
from pandasai.llm.openai import OpenAI
from pandasai_openai import OpenAI

os.environ["PANDASAI_API_KEY"] = "$2a****************************"

Expand Down
4 changes: 2 additions & 2 deletions examples/table_relations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pandasai.agent.base import Agent
from pandasai.connectors.sql import PostgreSQLConnector
from pandasai_sql.sql import PostgreSQLConnector
from pandasai.ee.connectors.relations import ForeignKey, PrimaryKey
from pandasai.llm.openai import OpenAI
from pandasai_openai import OpenAI

llm = OpenAI("sk-*************")

Expand Down
11 changes: 11 additions & 0 deletions extensions/connectors/sql/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SQL Extension for PandasAI

This extension integrates SQL connectors with PandasAI, providing support for various SQL databases (MySQL, PostgreSQL, CockroachDB, SQLite).

## Installation

You can install this extension using poetry:

```bash
poetry add pandasai-sql
```
38 changes: 38 additions & 0 deletions extensions/connectors/sql/pandasai_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from .sql import SQLConnector, SqliteConnector, SQLConnectorConfig
import importlib


def load_from_mysql(connection_info, query):
    """Run ``query`` against a MySQL database and return the result.

    ``pymysql`` and ``pandas`` are resolved with ``importlib`` when the
    function is called rather than when the module is imported.

    Args:
        connection_info: Mapping providing the keys ``host``, ``user``,
            ``password``, ``database`` and ``port``.
        query: SQL text passed unchanged to ``pandas.read_sql``.

    Returns:
        Whatever ``pandas.read_sql`` returns for ``query`` (a DataFrame).

    Raises:
        KeyError: If ``connection_info`` lacks one of the keys listed above.
    """
    pymysql = importlib.import_module("pymysql")
    pd = importlib.import_module("pandas")

    conn = pymysql.connect(
        host=connection_info["host"],
        user=connection_info["user"],
        password=connection_info["password"],
        database=connection_info["database"],
        port=connection_info["port"],
    )
    try:
        return pd.read_sql(query, conn)
    finally:
        # The connection is private to this call; close it on success and on
        # failure so repeated calls do not leak server connections.
        conn.close()


def load_from_postgres(connection_info, query):
    """Run ``query`` against a PostgreSQL database and return the result.

    ``psycopg2`` and ``pandas`` are resolved with ``importlib`` when the
    function is called rather than when the module is imported.

    Args:
        connection_info: Mapping providing the keys ``host``, ``user``,
            ``password``, ``database`` and ``port``. The ``database`` value
            is passed to the driver as ``dbname``.
        query: SQL text passed unchanged to ``pandas.read_sql``.

    Returns:
        Whatever ``pandas.read_sql`` returns for ``query`` (a DataFrame).

    Raises:
        KeyError: If ``connection_info`` lacks one of the keys listed above.
    """
    psycopg2 = importlib.import_module("psycopg2")
    pd = importlib.import_module("pandas")

    conn = psycopg2.connect(
        host=connection_info["host"],
        user=connection_info["user"],
        password=connection_info["password"],
        dbname=connection_info["database"],
        port=connection_info["port"],
    )
    try:
        return pd.read_sql(query, conn)
    finally:
        # Close explicitly: a psycopg2 connection used as a context manager
        # only ends the transaction, it does not release the connection.
        conn.close()


# Public API of this package: the connector classes imported from `.sql`
# plus the two query-loading helpers defined in this module.
__all__ = [
"SQLConnector",
"SqliteConnector",
"SQLConnectorConfig",
"load_from_mysql",
"load_from_postgres",
]
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from pandasai.exceptions import MaliciousQueryError
from pandasai.helpers.path import find_project_root

from ..constants import DEFAULT_FILE_PERMISSIONS
from .base import BaseConnector, BaseConnectorConfig
from pandasai.constants import DEFAULT_FILE_PERMISSIONS
from pandasai.connectors.base import BaseConnector, BaseConnectorConfig


class SQLBaseConnectorConfig(BaseConnectorConfig):
Expand Down Expand Up @@ -55,6 +55,7 @@ class SQLConnector(BaseConnector):
SQL connectors are used to connect to SQL databases in different dialects.
"""

is_sql_connector = True
_engine = None
_connection: Connection = None
_rows_count: int = None
Expand Down Expand Up @@ -654,61 +655,3 @@ def cs_table_name(self):
def execute_direct_sql_query(self, sql_query):
    """
    Transpile the query from the MySQL dialect to the PostgreSQL dialect
    with sqlglot, then delegate execution to the parent implementation.

    Args:
        sql_query: SQL text to transpile and execute.

    Returns:
        Whatever the parent class's ``execute_direct_sql_query`` returns.
    """
    transpiled = sqlglot.transpile(sql_query, read="mysql", write="postgres")
    # Only the first transpiled statement is forwarded.
    return super().execute_direct_sql_query(transpiled[0])


class OracleConnector(SQLConnector):
    """
    Oracle connectors are used to connect to Oracle databases.
    """

    def __init__(
        self,
        config: Union[SQLConnectorConfig, dict],
        **kwargs,
    ):
        """
        Initialize the Oracle connector with the given configuration.

        Args:
            config (Union[SQLConnectorConfig, dict]): The configuration for
                the Oracle connector. A dict has its ``dialect`` and
                ``driver`` entries forced to ``"oracle"`` / ``"cx_oracle"``
                and is then passed through ``_populate_config_from_env``
                together with the ``ORACLE_*`` variable names.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        if isinstance(config, dict):
            # Item assignment is only valid on the dict form, so it lives
            # inside this branch instead of running unconditionally.
            config["dialect"] = "oracle"
            config["driver"] = "cx_oracle"

            oracle_env_vars = {
                "host": "ORACLE_HOST",
                "port": "ORACLE_PORT",
                "database": "ORACLE_DATABASE",
                "username": "ORACLE_USERNAME",
                "password": "ORACLE_PASSWORD",
            }
            config = self._populate_config_from_env(config, oracle_env_vars)
        # NOTE(review): a non-dict config is passed through untouched;
        # presumably it already carries the Oracle dialect/driver - confirm.

        super().__init__(config, **kwargs)

    @cache
    def head(self, n: int = 5) -> pd.DataFrame:
        """
        Return the head of the data source that the connector is connected to.
        This information is passed to the LLM to provide the schema of the data source.

        Args:
            n (int): Row limit passed to the generated query. Defaults to 5.

        Returns:
            DataFrame: The head of the data source.
        """

        if self.logger:
            self.logger.log(
                f"Getting head of {self.config.table} "
                f"using dialect {self.config.dialect}"
            )

        # Build a query limited to n rows, ordered by dbms_random.value
        query = self._build_query(limit=n, order="dbms_random.value")

        # Return the head of the data source
        return pd.read_sql(query, self._connection)

    @property
    def cs_table_name(self):
        """Return the configured table name wrapped in double quotes."""
        return f'"{self.config.table}"'
Loading

0 comments on commit 1ed56f0

Please sign in to comment.