-
Notifications
You must be signed in to change notification settings - Fork 348
/
Copy pathcolumns.py
175 lines (134 loc) · 3.64 KB
/
columns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "altair==5.5.0",
# "marimo",
# "matplotlib==3.10.0",
# "pandas==2.2.3",
# "polars==1.20.0",
# "scikit-learn==1.6.1",
# ]
# ///
import marimo
# NOTE(review): presumably the marimo version that serialized this notebook
# -- confirm before editing by hand.
__generated_with = "0.10.15"
# The app is created with the "columns" width; individual cells choose their
# column through `@app.cell(column=N)`.
app = marimo.App(width="columns")
@app.cell(column=0)
def _():
    # Placed in column 0. Exposes marimo under the alias `mo`, which the
    # markdown and UI cells of this notebook receive as a parameter.
    import marimo as mo
    return (mo,)
@app.cell
async def _():
    import sys

    # When a `pyodide` module is present in sys.modules, altair is installed
    # through micropip before it is imported below.
    if "pyodide" in sys.modules:
        import micropip

        await micropip.install("altair")

    import altair as alt
    # NOTE(review): `micropip` is only bound inside the branch above, so
    # evaluating this return statement as plain Python outside Pyodide would
    # raise UnboundLocalError -- confirm the runtime never executes it directly.
    return alt, micropip, sys
@app.cell
def _():
    # Importing the top-level `sklearn` package does not load its submodules,
    # so every submodule the notebook touches is imported explicitly here.
    import sklearn
    import sklearn.datasets
    # The notebook calls sklearn.decomposition.PCA; without this line that
    # attribute is only reachable if another import happens to load the
    # submodule as a side effect.
    import sklearn.decomposition
    import sklearn.manifold
    return (sklearn,)
@app.cell
def _():
    # Exposes polars as `pl`; the embedding cell uses it to build a DataFrame.
    import polars as pl
    return (pl,)
@app.cell
def _(alt):
    def scatter(df):
        """Return a 500x500 Altair circle chart of ``df``.

        The quantitative ``x`` and ``y`` columns are plotted on axes whose
        domains are both fixed to (-2.5, 2.5); points are colored by the
        nominal ``digit`` column.
        """
        axis_domain = (-2.5, 2.5)
        x_channel = alt.X("x:Q").scale(domain=axis_domain)
        y_channel = alt.Y("y:Q").scale(domain=axis_domain)
        color_channel = alt.Color("digit:N")

        circles = alt.Chart(df).mark_circle()
        encoded = circles.encode(x=x_channel, y=y_channel, color=color_channel)
        return encoded.properties(width=500, height=500)

    return (scatter,)
@app.cell
def _(raw_digits):
    def show_images(indices, max_images=10):
        """Draw the digits at ``indices`` side by side as grayscale images.

        Only the first ``max_images`` entries of ``indices`` are drawn. The
        rows of ``raw_digits`` are viewed as 8x8 images and each selected one
        gets its own tick-less axes in a single-row figure, which is returned.
        """
        import matplotlib.pyplot as plt

        shown = indices[:max_images]
        pictures = raw_digits.reshape((-1, 8, 8))[shown]
        # squeeze=False always yields a 2-D array of axes, so one loop covers
        # both the single-image and the multi-image case.
        fig, axes = plt.subplots(1, len(shown), squeeze=False)
        fig.set_size_inches(12.5, 1.5)
        for picture, axis in zip(pictures, axes.flat):
            axis.imshow(picture, cmap="gray")
            axis.set_yticks([])
            axis.set_xticks([])
        plt.tight_layout()
        return fig

    return (show_images,)
@app.cell(column=1, hide_code=True)
def _(mo):
    # Title cell, placed in column 1 with its code hidden.
    mo.md("""# Embedding Visualizer""")
    return
@app.cell(hide_code=True)
def _(mo):
    # Introductory markdown: says what the embedding shows and that selecting
    # points with the mouse drills down into them.
    mo.md(
        """
Here's a PCA **embedding of numerical digits**: each point represents a
digit, with similar digits close to each other. The data is from the UCI
ML handwritten digits dataset.
This notebook will automatically drill down into points you **select with
your mouse**; try it!
"""
    )
    return
@app.cell
def _(sklearn):
    # With return_X_y=True the loader's result is unpacked into two values:
    # the digit data and their labels. Other cells reshape `raw_digits` into
    # 8x8 images and use `raw_labels` as the `digit` column.
    raw_digits, raw_labels = sklearn.datasets.load_digits(return_X_y=True)
    return raw_digits, raw_labels
@app.cell
def _(pl, raw_digits, raw_labels, sklearn):
    # Project the digits onto two whitened PCA components.
    _pca = sklearn.decomposition.PCA(n_components=2, whiten=True)
    X_embedded = _pca.fit_transform(raw_digits)

    # One row per digit: its 2-D coordinates, its label, and its row number
    # in `raw_digits` (used later to look the image back up).
    _row_count = X_embedded.shape[0]
    _columns = {
        "x": X_embedded[:, 0],
        "y": X_embedded[:, 1],
        "digit": raw_labels,
        "index": list(range(_row_count)),
    }
    embedding = pl.DataFrame(_columns)
    return X_embedded, embedding
@app.cell
def _(embedding, mo, scatter):
    # Wrap the scatter plot in mo.ui.altair_chart; other cells read the
    # current selection through `chart.value`.
    chart = mo.ui.altair_chart(scatter(embedding))
    # Bare expression kept as the final statement before the return.
    chart
    return (chart,)
@app.cell
def _(chart, mo):
    # Table built from the chart's current value. It is not shown here; the
    # preview cell embeds `table` in its markdown and reads `table.value`.
    table = mo.ui.table(chart.value)
    return (table,)
@app.cell
def _(chart, mo, show_images, table):
    # Nothing to preview while the chart selection is empty.
    mo.stop(not len(chart.value))

    # Rows picked in the table take precedence; with no table selection the
    # chart selection is used. show_images draws at most its default of 10
    # images from whichever source is chosen.
    _chosen = table.value if len(table.value) else chart.value
    selected_images = show_images(list(_chosen["index"]))

    mo.md(
        f"""
**Here's a preview of the images you've selected**:
{mo.as_html(selected_images)}
Here's all the data you've selected.
{table}
"""
    )
    return (selected_images,)
# Run the app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()