-
-
Notifications
You must be signed in to change notification settings - Fork 232
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
399b882
commit 3b06087
Showing
9 changed files
with
350 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* eslint-disable @typescript-eslint/no-explicit-any */ | ||
import { z, ZodObject } from "zod"; | ||
import { openai } from "./_openai"; | ||
import zodToJsonSchema from "zod-to-json-schema"; | ||
import OpenAI from "openai"; | ||
|
||
type Schemas<T extends Record<string, ZodObject<any>>> = T; | ||
|
||
export async function llmMany<T extends Record<string, ZodObject<any>>>( | ||
content: string, | ||
schemas: Schemas<T> | ||
) { | ||
try { | ||
// if the user passes a key "message" in schemas, throw an error | ||
if (schemas.message) throw new Error("Cannot use key 'message' in schemas"); | ||
|
||
const completion = await openai.chat.completions.create({ | ||
messages: [ | ||
{ | ||
role: "user", | ||
content, | ||
}, | ||
], | ||
tools: Object.entries(schemas).map(([key, schema]) => ({ | ||
type: "function", | ||
function: { | ||
name: key, | ||
parameters: zodToJsonSchema(schema), | ||
}, | ||
})), | ||
model: "gpt-3.5-turbo-1106", | ||
// model: "gpt-4-1106-preview", | ||
}); | ||
|
||
const choice = completion.choices[0]; | ||
|
||
if (!choice) throw new Error("No choices returned"); | ||
|
||
// Must return the full thing, message and multiple tool calls | ||
return simplifyChoice(choice) as SimplifiedChoice<T>; | ||
} catch (error) { | ||
console.error(error); | ||
const message = (error as Error)?.message || "Error with prompt"; | ||
throw new Error(message); | ||
} | ||
} | ||
|
||
type SimplifiedChoice<T extends Record<string, ZodObject<any>>> = { | ||
message: string; | ||
toolCalls: Array< | ||
{ | ||
[K in keyof T]: { | ||
name: K; | ||
args: z.infer<T[K]>; | ||
}; | ||
}[keyof T] | ||
>; | ||
}; | ||
|
||
function simplifyChoice(choice: OpenAI.Chat.Completions.ChatCompletion.Choice) { | ||
return { | ||
message: choice.message.content || "", | ||
toolCalls: | ||
choice.message.tool_calls?.map((toolCall) => ({ | ||
name: toolCall.function.name, | ||
// Wish this were type-safe! | ||
args: JSON.parse(toolCall.function.arguments ?? "{}"), | ||
})) || [], | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import { VercelApiHandler } from "@vercel/node"; | ||
import { llmMany } from "../_lib/_llm"; | ||
import { z } from "zod"; | ||
|
||
const nodeSchema = z.object({ | ||
// id: z.string(), | ||
// classes: z.string(), | ||
label: z.string(), | ||
}); | ||
|
||
const edgeSchema = z.object({ | ||
from: z.string(), | ||
to: z.string(), | ||
label: z.string().optional().default(""), | ||
}); | ||
|
||
const graphSchema = z.object({ | ||
nodes: z.array(nodeSchema), | ||
edges: z.array(edgeSchema), | ||
}); | ||
|
||
const handler: VercelApiHandler = async (req, res) => { | ||
const { graph, prompt } = req.body; | ||
if (!graph || !prompt) { | ||
throw new Error("Missing graph or prompt"); | ||
} | ||
|
||
const result = await llmMany( | ||
`You are an AI flowchart assistant. Help the create a flowchart or diagram. Here is the current state of the flowchart: | ||
${JSON.stringify(graph, null, 2)} | ||
Here is the user's message: | ||
${prompt}`, | ||
{ | ||
updateGraph: graphSchema, | ||
} | ||
); | ||
|
||
res.json(result); | ||
}; | ||
|
||
export default handler; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
import { MagicWand, Microphone } from "phosphor-react"; | ||
import { Button2, IconButton2 } from "../ui/Shared"; | ||
import * as Popover from "@radix-ui/react-popover"; | ||
import { Trans, t } from "@lingui/macro"; | ||
import { useCallback, useRef, useState } from "react"; | ||
import { useDoc } from "../lib/useDoc"; | ||
import { parse, stringify, Graph as GSGraph } from "graph-selector"; | ||
import { useMutation } from "react-query"; | ||
|
||
// The Graph type we send to AI is slightly different from internal representation | ||
type GraphForAI = { | ||
nodes: { | ||
label: string; | ||
id?: string; | ||
}[]; | ||
edges: { | ||
label: string; | ||
from: string; | ||
to: string; | ||
}[]; | ||
}; | ||
|
||
export function EditWithAI() { | ||
const [isOpen, setIsOpen] = useState(false); | ||
const { mutate: edit, isLoading } = useMutation({ | ||
mutationFn: async (body: { prompt: string; graph: GraphForAI }) => { | ||
// /api/prompt/edit | ||
const response = await fetch("/api/prompt/edit", { | ||
method: "POST", | ||
body: JSON.stringify(body), | ||
headers: { | ||
"Content-Type": "application/json", | ||
}, | ||
}); | ||
const data = await response.json(); | ||
return data as { | ||
message: string; | ||
toolCalls: { | ||
name: "updateGraph"; | ||
args: GraphForAI; | ||
}[]; | ||
}; | ||
}, | ||
onMutate: () => setIsOpen(false), | ||
onSuccess(data) { | ||
if (data.message) { | ||
window.alert(data.message); | ||
} | ||
|
||
for (const { name, args } of data.toolCalls) { | ||
switch (name) { | ||
case "updateGraph": { | ||
const newText = toGraphSelector(args); | ||
useDoc.setState({ text: newText }, false, "EditWithAI"); | ||
break; | ||
} | ||
} | ||
} | ||
}, | ||
}); | ||
const handleSubmit = useCallback( | ||
(e: React.FormEvent<HTMLFormElement>) => { | ||
e.preventDefault(); | ||
|
||
const formData = new FormData(e.currentTarget); | ||
const prompt = formData.get("prompt") as string; | ||
if (!prompt) return; | ||
|
||
const text = useDoc.getState().text; | ||
const _graph = parse(text); | ||
|
||
const graph: GraphForAI = { | ||
nodes: _graph.nodes.map((node) => { | ||
if (isCustomID(node.data.id)) { | ||
return { | ||
label: node.data.label, | ||
id: node.data.id, | ||
}; | ||
} | ||
|
||
return { | ||
label: node.data.label, | ||
}; | ||
}), | ||
edges: _graph.edges.map((edge) => { | ||
// Because generated edges internally use a custom ID, | ||
// we need to find the label, unless the user is using a custom ID | ||
|
||
let from = edge.source; | ||
if (!isCustomID(from)) { | ||
// find the from node | ||
const fromNode = _graph.nodes.find((node) => node.data.id === from); | ||
if (!fromNode) throw new Error("from node not found"); | ||
from = fromNode.data.label; | ||
} | ||
|
||
let to = edge.target; | ||
if (!isCustomID(to)) { | ||
// find the to node | ||
const toNode = _graph.nodes.find((node) => node.data.id === to); | ||
if (!toNode) throw new Error("to node not found"); | ||
to = toNode.data.label; | ||
} | ||
|
||
return { | ||
label: edge.data.label, | ||
from, | ||
to, | ||
}; | ||
}), | ||
}; | ||
|
||
edit({ prompt, graph }); | ||
}, | ||
[edit] | ||
); | ||
|
||
const formRef = useRef<HTMLFormElement>(null); | ||
|
||
return ( | ||
<Popover.Root open={isOpen} onOpenChange={setIsOpen}> | ||
<Popover.Trigger asChild> | ||
<Button2 | ||
leftIcon={ | ||
<MagicWand className="group-hover-tilt-shaking" size={18} /> | ||
} | ||
color="zinc" | ||
size="sm" | ||
rounded | ||
className="aria-[expanded=true]:bg-zinc-700" | ||
isLoading={isLoading} | ||
> | ||
<Trans>Edit with AI</Trans> | ||
</Button2> | ||
</Popover.Trigger> | ||
<Popover.Portal> | ||
<Popover.Content | ||
side="top" | ||
sideOffset={10} | ||
align="center" | ||
className="w-[300px] bg-white rounded shadow border p-2" | ||
> | ||
<form className="grid gap-2" onSubmit={handleSubmit} ref={formRef}> | ||
<div className="relative"> | ||
<textarea | ||
placeholder={t`Write your prompt here or press and hold the button to speak...`} | ||
className="text-xs w-full resize-none h-24 p-2 leading-normal" | ||
name="prompt" | ||
required | ||
onKeyDown={(e) => { | ||
if (!formRef.current) return; | ||
|
||
// submit form on Enter | ||
if (e.key === "Enter" && !e.shiftKey) { | ||
e.preventDefault(); | ||
formRef.current.requestSubmit(); | ||
} | ||
}} | ||
/> | ||
<IconButton2 size="xs" className="!absolute bottom-0 right-0"> | ||
<Microphone size={16} /> | ||
</IconButton2> | ||
</div> | ||
<Button2 size="sm" color="purple"> | ||
<Trans>Submit</Trans> | ||
</Button2> | ||
</form> | ||
</Popover.Content> | ||
</Popover.Portal> | ||
</Popover.Root> | ||
); | ||
} | ||
|
||
// Match any string like "n1", "n23", "n902834" | ||
export function isCustomID(id: string) { | ||
return !id.match(/^n\d+$/); | ||
} | ||
|
||
function toGraphSelector(graph: GraphForAI) { | ||
const g: GSGraph = { | ||
nodes: graph.nodes.map((node) => ({ | ||
data: { | ||
id: node.label, | ||
label: node.label, | ||
classes: "", | ||
}, | ||
})), | ||
edges: graph.edges.map((edge) => ({ | ||
source: edge.from, | ||
target: edge.to, | ||
data: { | ||
id: "", | ||
label: edge.label ?? "", | ||
classes: "", | ||
}, | ||
})), | ||
}; | ||
|
||
return stringify(g, { compact: true }); | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.