-
Notifications
You must be signed in to change notification settings - Fork 350
/
Copy pathgenerative_ui.py
95 lines (74 loc) · 2.27 KB
/
generative_ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "altair",
#     "ell-ai==0.0.14",
#     "marimo",
#     "openai==1.53.0",
#     "polars==1.12.0",
# ]
# ///
import marimo
# Version of marimo this notebook file was generated with.
__generated_with = "0.9.14"
# The notebook application; every `@app.cell` function below registers a cell on it.
app = marimo.App(width="medium")
@app.cell
def __():
    import polars as pl
    import marimo as mo
    import os

    # Halt this cell before the dataset is loaded when no OpenAI key is
    # configured. An empty or whitespace-only value counts as missing, since it
    # would only fail later with an authentication error.
    has_api_key = bool(os.environ.get("OPENAI_API_KEY", "").strip())
    mo.stop(
        not has_api_key,
        mo.md("Please set the `OPENAI_API_KEY` environment variable").callout(),
    )

    # Grab a dataset: the Fish CSV from the Hugging Face Hub, read by polars
    # through its `hf://` path support.
    df = pl.read_csv("hf://datasets/scikit-learn/Fish/Fish.csv")
    return df, has_api_key, mo, os, pl
@app.cell
def __(df, mo):
    import ell

    # Tool: circle-mark chart of two columns of `df`, optionally colored by a
    # third. ell derives the tool description and parameter schema from the
    # docstring and signature, so those are intentionally left unchanged.
    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart"""
        import altair as alt

        base = alt.Chart(df).mark_circle()
        encoded = base.encode(x=x_encoding, y=y_encoding, color=color)
        return encoded.properties(width=500)

    # Tool: run a model-written SQL query against `df` (registered under the
    # table name "data") and show the result as a table, labelled with the
    # query, with row selection disabled.
    @ell.tool()
    def filter_dataset(sql_query: str):
        """
        Filter a polars dataframe using SQL. Please only use fields from the schema.
        When referring to the table in SQL, call it 'data'.
        """
        result = df.sql(sql_query, table_name="data")
        query_label = f"```sql\n{sql_query}\n```"
        return mo.ui.table(
            result,
            label=query_label,
            selection=None,
            show_column_summaries=False,
        )

    return chart_data, ell, filter_dataset
@app.cell
def __(chart_data, df, ell, filter_dataset, mo):
    # ell uses the docstring as the system prompt and the returned string as
    # the user prompt; gpt-4o may answer by calling either tool.
    @ell.complex(model="gpt-4o", tools=[chart_data, filter_dataset])
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def my_model(messages):
        """Answer one chat turn, rendering any tool output the model requested.

        Every tool call in the model's response is executed. A single result
        (a chart or table) is returned as-is; several results are stacked
        vertically with `mo.vstack`. When the model made no tool calls, its
        text reply is returned instead.
        """
        # NOTE(review): `messages` is whatever `mo.ui.chat` passes (presumably
        # a list of chat messages, not a str), so its repr ends up in the
        # prompt — confirm this is the intended way to forward chat history.
        response = analyze_dataset(messages)
        if response.tool_calls:
            results = [tool_call() for tool_call in response.tool_calls]
            # Parallel tool calls must not be dropped: show all of them.
            return results[0] if len(results) == 1 else mo.vstack(results)
        return response.text

    # Last expression of the cell, so marimo renders the chat widget.
    mo.ui.chat(
        my_model,
        prompts=[
            "Can you chart two columns of your choosing?",
            "Can you find the min, max of all numeric fields?",
            "What is the sum of {{column}}?",
        ],
    )
    return analyze_dataset, my_model
# Allow running the notebook file directly as a script.
if __name__ == "__main__":
    app.run()