Skip to content

Commit

Permalink
Basic ai editing working
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-gordon committed Jan 10, 2024
1 parent 399b882 commit 3b06087
Show file tree
Hide file tree
Showing 9 changed files with 350 additions and 106 deletions.
70 changes: 70 additions & 0 deletions api/_lib/_llm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { z, ZodObject } from "zod";
import { openai } from "./_openai";
import zodToJsonSchema from "zod-to-json-schema";
import OpenAI from "openai";

type Schemas<T extends Record<string, ZodObject<any>>> = T;

export async function llmMany<T extends Record<string, ZodObject<any>>>(
content: string,
schemas: Schemas<T>
) {
try {
// if the user passes a key "message" in schemas, throw an error
if (schemas.message) throw new Error("Cannot use key 'message' in schemas");

const completion = await openai.chat.completions.create({
messages: [
{
role: "user",
content,
},
],
tools: Object.entries(schemas).map(([key, schema]) => ({
type: "function",
function: {
name: key,
parameters: zodToJsonSchema(schema),
},
})),
model: "gpt-3.5-turbo-1106",
// model: "gpt-4-1106-preview",
});

const choice = completion.choices[0];

if (!choice) throw new Error("No choices returned");

// Must return the full thing, message and multiple tool calls
return simplifyChoice(choice) as SimplifiedChoice<T>;
} catch (error) {
console.error(error);
const message = (error as Error)?.message || "Error with prompt";
throw new Error(message);
}
}

type SimplifiedChoice<T extends Record<string, ZodObject<any>>> = {
message: string;
toolCalls: Array<
{
[K in keyof T]: {
name: K;
args: z.infer<T[K]>;
};
}[keyof T]
>;
};

function simplifyChoice(choice: OpenAI.Chat.Completions.ChatCompletion.Choice) {
return {
message: choice.message.content || "",
toolCalls:
choice.message.tool_calls?.map((toolCall) => ({
name: toolCall.function.name,
// Wish this were type-safe!
args: JSON.parse(toolCall.function.arguments ?? "{}"),
})) || [],
};
}
8 changes: 5 additions & 3 deletions api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
"ajv": "^8.12.0",
"csv-parse": "^5.3.6",
"date-fns": "^2.29.3",
"graph-selector": "^0.9.11",
"graph-selector": "^0.10.0",
"highlight.js": "^11.8.0",
"marked": "^4.1.1",
"moniker": "^0.1.2",
"notion-to-md": "^2.5.5",
"openai": "^4.10.0",
"openai": "^4.24.2",
"shared": "workspace:*",
"stripe": "^11.11.0"
"stripe": "^11.11.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
"devDependencies": {
"@swc/jest": "^0.2.24",
Expand Down
42 changes: 42 additions & 0 deletions api/prompt/edit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { VercelApiHandler } from "@vercel/node";
import { llmMany } from "../_lib/_llm";
import { z } from "zod";

const nodeSchema = z.object({
// id: z.string(),
// classes: z.string(),
label: z.string(),
});

const edgeSchema = z.object({
from: z.string(),
to: z.string(),
label: z.string().optional().default(""),
});

const graphSchema = z.object({
nodes: z.array(nodeSchema),
edges: z.array(edgeSchema),
});

const handler: VercelApiHandler = async (req, res) => {
const { graph, prompt } = req.body;
if (!graph || !prompt) {
throw new Error("Missing graph or prompt");
}

const result = await llmMany(
`You are an AI flowchart assistant. Help the create a flowchart or diagram. Here is the current state of the flowchart:
${JSON.stringify(graph, null, 2)}
Here is the user's message:
${prompt}`,
{
updateGraph: graphSchema,
}
);

res.json(result);
};

export default handler;
2 changes: 1 addition & 1 deletion app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
"file-saver": "^2.0.5",
"formulaic": "workspace:*",
"framer-motion": "^10.13.1",
"graph-selector": "^0.9.12",
"graph-selector": "^0.10.0",
"gray-matter": "^4.0.2",
"highlight.js": "^11.7.0",
"immer": "^9.0.16",
Expand Down
200 changes: 200 additions & 0 deletions app/src/components/EditWithAI.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import { MagicWand, Microphone } from "phosphor-react";
import { Button2, IconButton2 } from "../ui/Shared";
import * as Popover from "@radix-ui/react-popover";
import { Trans, t } from "@lingui/macro";
import { useCallback, useRef, useState } from "react";
import { useDoc } from "../lib/useDoc";
import { parse, stringify, Graph as GSGraph } from "graph-selector";
import { useMutation } from "react-query";

// The Graph type we send to AI is slightly different from internal representation
type GraphForAI = {
nodes: {
label: string;
id?: string;
}[];
edges: {
label: string;
from: string;
to: string;
}[];
};

export function EditWithAI() {
const [isOpen, setIsOpen] = useState(false);
const { mutate: edit, isLoading } = useMutation({
mutationFn: async (body: { prompt: string; graph: GraphForAI }) => {
// /api/prompt/edit
const response = await fetch("/api/prompt/edit", {
method: "POST",
body: JSON.stringify(body),
headers: {
"Content-Type": "application/json",
},
});
const data = await response.json();
return data as {
message: string;
toolCalls: {
name: "updateGraph";
args: GraphForAI;
}[];
};
},
onMutate: () => setIsOpen(false),
onSuccess(data) {
if (data.message) {
window.alert(data.message);
}

for (const { name, args } of data.toolCalls) {
switch (name) {
case "updateGraph": {
const newText = toGraphSelector(args);
useDoc.setState({ text: newText }, false, "EditWithAI");
break;
}
}
}
},
});
const handleSubmit = useCallback(
(e: React.FormEvent<HTMLFormElement>) => {
e.preventDefault();

const formData = new FormData(e.currentTarget);
const prompt = formData.get("prompt") as string;
if (!prompt) return;

const text = useDoc.getState().text;
const _graph = parse(text);

const graph: GraphForAI = {
nodes: _graph.nodes.map((node) => {
if (isCustomID(node.data.id)) {
return {
label: node.data.label,
id: node.data.id,
};
}

return {
label: node.data.label,
};
}),
edges: _graph.edges.map((edge) => {
// Because generated edges internally use a custom ID,
// we need to find the label, unless the user is using a custom ID

let from = edge.source;
if (!isCustomID(from)) {
// find the from node
const fromNode = _graph.nodes.find((node) => node.data.id === from);
if (!fromNode) throw new Error("from node not found");
from = fromNode.data.label;
}

let to = edge.target;
if (!isCustomID(to)) {
// find the to node
const toNode = _graph.nodes.find((node) => node.data.id === to);
if (!toNode) throw new Error("to node not found");
to = toNode.data.label;
}

return {
label: edge.data.label,
from,
to,
};
}),
};

edit({ prompt, graph });
},
[edit]
);

const formRef = useRef<HTMLFormElement>(null);

return (
<Popover.Root open={isOpen} onOpenChange={setIsOpen}>
<Popover.Trigger asChild>
<Button2
leftIcon={
<MagicWand className="group-hover-tilt-shaking" size={18} />
}
color="zinc"
size="sm"
rounded
className="aria-[expanded=true]:bg-zinc-700"
isLoading={isLoading}
>
<Trans>Edit with AI</Trans>
</Button2>
</Popover.Trigger>
<Popover.Portal>
<Popover.Content
side="top"
sideOffset={10}
align="center"
className="w-[300px] bg-white rounded shadow border p-2"
>
<form className="grid gap-2" onSubmit={handleSubmit} ref={formRef}>
<div className="relative">
<textarea
placeholder={t`Write your prompt here or press and hold the button to speak...`}
className="text-xs w-full resize-none h-24 p-2 leading-normal"
name="prompt"
required
onKeyDown={(e) => {
if (!formRef.current) return;

// submit form on Enter
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
formRef.current.requestSubmit();
}
}}
/>
<IconButton2 size="xs" className="!absolute bottom-0 right-0">
<Microphone size={16} />
</IconButton2>
</div>
<Button2 size="sm" color="purple">
<Trans>Submit</Trans>
</Button2>
</form>
</Popover.Content>
</Popover.Portal>
</Popover.Root>
);
}

// Match any string like "n1", "n23", "n902834"
export function isCustomID(id: string) {
return !id.match(/^n\d+$/);
}

function toGraphSelector(graph: GraphForAI) {
const g: GSGraph = {
nodes: graph.nodes.map((node) => ({
data: {
id: node.label,
label: node.label,
classes: "",
},
})),
edges: graph.edges.map((edge) => ({
source: edge.from,
target: edge.to,
data: {
id: "",
label: edge.label ?? "",
classes: "",
},
})),
};

return stringify(g, { compact: true });
}
45 changes: 0 additions & 45 deletions app/src/components/EditWithAIButton.tsx

This file was deleted.

2 changes: 1 addition & 1 deletion app/src/lib/graphOptions.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { t } from "@lingui/macro";
import { CSSProperties } from "react";

export const DEFAULT_GRAPH_PADDING = 6;
export const DEFAULT_GRAPH_PADDING = 20;

export interface SelectOption {
value: string;
Expand Down
Loading

0 comments on commit 3b06087

Please sign in to comment.