diff --git a/demo/inngest.ts b/demo/inngest.ts index 7e5940e..3e3da7e 100644 --- a/demo/inngest.ts +++ b/demo/inngest.ts @@ -22,7 +22,7 @@ export const inngest = new Inngest({ }); export const fn = inngest.createFunction( - { id: "agent" }, + { id: "agent", retries: 0, }, { event: "agent/run" }, async ({ event, step }) => { const model = openai({ model: "gpt-4", step }); @@ -30,9 +30,9 @@ export const fn = inngest.createFunction( // 1. Single agent // Run a single agent as a prompt without a network. - await codeWritingAgent.run(event.data.input, { - model, - }); + // await codeWritingAgent.run(event.data.input, { + // model, + // }); // 2. A network of agents that works together const network = createNetwork({ diff --git a/eslint.config.mjs b/eslint.config.mjs index 67ee505..fa104b4 100644 --- a/eslint.config.mjs +++ b/eslint.config.mjs @@ -6,7 +6,7 @@ import tseslint from "typescript-eslint"; export default tseslint.config( { - ignores: ["dist", "eslint.config.mjs", "demo"], + ignores: ["dist", "eslint.config.mjs", "demo", "examples"], }, eslint.configs.recommended, tseslint.configs.recommendedTypeChecked, diff --git a/examples/swebench/.gitignore b/examples/swebench/.gitignore new file mode 100644 index 0000000..138856a --- /dev/null +++ b/examples/swebench/.gitignore @@ -0,0 +1,2 @@ +node_modules +opt/ diff --git a/examples/swebench/Makefile b/examples/swebench/Makefile new file mode 100644 index 0000000..ecf4e73 --- /dev/null +++ b/examples/swebench/Makefile @@ -0,0 +1,4 @@ +.PHONY: init +init: + mkdir ./opt/ + wget -O ./opt/dev.parquet https://huggingface.co/datasets/princeton-nlp/SWE-bench_Lite/resolve/main/data/dev-00000-of-00001.parquet?download=true diff --git a/examples/swebench/agents/setup.ts b/examples/swebench/agents/setup.ts new file mode 100644 index 0000000..81e38b7 --- /dev/null +++ b/examples/swebench/agents/setup.ts @@ -0,0 +1,6 @@ +import { createAgent } from "../../../src"; + +createAgent({ + name: "setup", + system: "This is a system prompt", +}); diff --git a/examples/swebench/index.ts b/examples/swebench/index.ts new file mode 100644 index 0000000..8c632b7 --- /dev/null +++ b/examples/swebench/index.ts @@ -0,0 +1,23 @@ +import express from "express"; +import { serve } from "inngest/express"; +import { fn, inngest } from "./inngest"; + +const app = express(); +const port = 3001; + +// Important: ensure you add JSON middleware to process incoming JSON POST payloads. +app.use(express.json({limit: '50mb'})); + +app.use( + // Expose the middleware on our recommended path at `/api/inngest`. 
+ "/api/inngest", + // eslint-disable-next-line @typescript-eslint/no-unsafe-argument + serve({ + client: inngest, + functions: [fn], + }), +); + +app.listen(port, () => { + console.log(`App listening on port ${port}`); +}); diff --git a/examples/swebench/inngest.ts b/examples/swebench/inngest.ts new file mode 100644 index 0000000..02cefc0 --- /dev/null +++ b/examples/swebench/inngest.ts @@ -0,0 +1,190 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +import fs from "fs"; +import { execSync } from 'child_process'; +import { z } from "zod"; +import { + createAgent, + createNetwork, + createTool, + anthropic, + State, +} from "../../src/index"; +import { extractClassAndFns, listFilesTool, readFileTool, replaceClassMethodTool } from "./tools/tools"; +import { Inngest, EventSchemas } from "inngest"; + +export const inngest = new Inngest({ + id: "agents", + schemas: new EventSchemas().fromZod({ + "swebench/run": { + data: z.object({ + repo: z.string(), + base_commit: z.string(), + environment_setup_commit: z.string(), + problem_statement: z.string(), + }) + }, + }), +}); + +export const fn = inngest.createFunction( + { id: "agent", retries: 2, }, + { event: "swebench/run" }, + async ({ event, step }) => { + + // This is some basic stuff to initialize and set up the repos + // for the swebench test. + // + // First, we clone the repo, then we ensure we're on the correct base commit. + const dir = `./opt/${event.data.repo}`; + await step.run("clone repo", async () => { + // Check if the dir already exists. + if (fs.existsSync(dir)) { + return + } + console.log("creating repo"); + fs.mkdirSync(dir, { recursive: true }); + execSync(`cd ${dir} && git init`); + execSync(`cd ${dir} && git remote add origin git@github.com:${event.data.repo}.git`); + }); + + await step.run("check out commit", async () => { + console.log("checking out commit"); + execSync(`cd ${dir} && git fetch origin ${event.data.base_commit} --depth=1`); + execSync(`cd ${dir} && git reset --hard FETCH_HEAD`); + }); + + + const model = anthropic({ + model: "claude-3-5-haiku-latest", + max_tokens: 1000, + step: step as any, + }); + + const state = new State(); + state.kv.set("repo", event.data.repo); + + const network = createNetwork({ + agents: [planningAgent.withModel(model), editingAgent.withModel(model)], + defaultModel: model, + state, + }); + await network.run(event.data.problem_statement, (opts) => { + if (opts.network.state.kv.get("done")) { + // We're done editing. + return; + } + + if (opts.network.state.kv.get("plan") !== undefined) { + return editingAgent.withModel(model); + } + return planningAgent.withModel(model); + }); + }, +); + +// Now that the setup has been completed, we can run the agent properly within that repo. +const planningAgent = createAgent({ + name: "Planner", + description: "Plans the code to write and which files should be edited", + tools: [ + listFilesTool, + readFileTool, + extractClassAndFns, + + createTool({ + name: "create_plan", + description: "Describe a formal plan for how to fix the issue, including which files to edit and reasoning.", + parameters: z.object({ + thoughts: z.string(), + plan_details: z.string(), + edits: z.array(z.object({ + filename: z.string(), + idea: z.string(), + reasoning: z.string(), + })) + }), + + handler: async (plan, opts) => { + // Store this in the function state for introspection in tracing. 
+ await opts.step.run("plan created", () => plan); + opts.network?.state.kv.set("plan", plan); + }, + }), + ], + + system: (network) => ` + You are an expert Python programmer working on a specific project: ${network?.state.kv.get("repo")}. + + You are given an issue reported within the project. You are planning how to fix the issue by investigating the report, + the current code, then devising a "plan" - a spec - to modify code to fix the issue. + + Your plan will be worked on and implemented after you create it. You MUST create a plan to + fix the issue. Be thorough. Think step-by-step using available tools. + + Techniques you may use to create a plan: + - Read entire files + - Find specific classes and functions within a file + `, +}) + +/** + * the editingAgent is enabled once a plan has been written. It disregards all conversation history + * and uses the plan from the current network state to construct a system prompt to edit the given + * files to resolve the input. + */ +const editingAgent = createAgent({ + name: "Editor", + description: "Edits code by replacing contents in files, or creating new files with new code.", + tools: [ + extractClassAndFns, + replaceClassMethodTool, + readFileTool, + + createTool({ + name: "done", + description: "Saves the current project and finishes editing", + handler: (_input, opts) => { + opts.network?.state.kv.delete("plan"); + opts.network?.state.kv.set("done", true); + return "Done editing"; + }, + }), + ], + lifecycle: { + + // The editing agent is only enabled once we have a plan. + enabled: (opts) => { + return opts.network?.state.kv.get("plan") !== undefined; + }, + + // onStart is called when we start inference. We want to update the history here to remove + // things from the planning agent. We update the system prompt to include details from the + // plan via network state. + onStart: ({ agent, prompt, network }) => { + + const history = (network?.state.results || []). + filter(i => i.agent === agent). // Return the current history from this agent only. + map(i => i.output.concat(i.toolCalls)). // Only add the output and tool calls to the conversation history + flat(); + + return { prompt, history, stop: false }; + }, + }, + + system: (network) => ` + You are an expert Python programmer working on a specific project: ${network?.state.kv.get("repo")}. You have been + given a plan to fix the given issue supplied by the user. + + The current plan is: + + ${JSON.stringify(network?.state.kv.get("plan"))} + + + You MUST: + - Understand the user's request + - Understand the given plan + - Write code using the tools available to fix the issue + + Once the files have been edited and you are confident in the updated code, you MUST finish your editing via calling the "done" tool. 
+ `, +}) diff --git a/examples/swebench/package.json b/examples/swebench/package.json new file mode 100644 index 0000000..97bbbb8 --- /dev/null +++ b/examples/swebench/package.json @@ -0,0 +1,17 @@ +{ + "name": "swebench", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "dependencies": { + "inngest": "^3.27.4", + "tree-sitter": "^0.22.1", + "tree-sitter-python": "^0.23.5" + } +} diff --git a/examples/swebench/pnpm-lock.yaml b/examples/swebench/pnpm-lock.yaml new file mode 100644 index 0000000..dd8408c --- /dev/null +++ b/examples/swebench/pnpm-lock.yaml @@ -0,0 +1,277 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +importers: + + .: + dependencies: + inngest: + specifier: ^3.27.4 + version: 3.27.4 + tree-sitter: + specifier: ^0.22.1 + version: 0.22.1 + tree-sitter-python: + specifier: ^0.23.5 + version: 0.23.5(tree-sitter@0.22.1) + +packages: + + '@types/debug@4.1.12': + resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==} + + '@types/ms@0.7.34': + resolution: {integrity: sha512-nG96G3Wp6acyAgJqGasjODb+acrI7KltPiRxzHPXnP3NgI28bpQDRv53olbqGXbfcgF5aiiHmO3xpwEpS5Ld9g==} + + ansi-regex@4.1.1: + resolution: {integrity: sha512-ILlv4k/3f6vfQ4OoP2AGvirOktlQ98ZEL1k9FaQjxa3L1abBgbuTDAdPOpvbGncC0BTVQrl+OM8xZGK6tWXt7g==} + engines: {node: '>=6'} + + ansi-styles@4.3.0: + resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} + engines: {node: '>=8'} + + canonicalize@1.0.8: + resolution: {integrity: sha512-0CNTVCLZggSh7bc5VkX5WWPWO+cyZbNd07IHIsSXLia/eAq+r836hgk+8BKoEh7949Mda87VUOitx5OddVj64A==} + + chalk@4.1.2: + resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} + engines: {node: '>=10'} + + color-convert@2.0.1: + resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} + engines: {node: '>=7.0.0'} + + color-name@1.1.4: + resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} + + cross-fetch@4.0.0: + resolution: {integrity: sha512-e4a5N8lVvuLgAWgnCrLr2PP0YyDOTHa9H/Rj54dirp61qXnNq46m82bRhNqIA5VccJtWBvPTFRV3TtvHUKPB1g==} + + debug@4.4.0: + resolution: {integrity: sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + + has-flag@4.0.0: + resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} + engines: {node: '>=8'} + + hash.js@1.1.7: + resolution: {integrity: sha512-taOaskGt4z4SOANNseOviYDvjEJinIkRgmp7LbKP2YTTmVxWBl87s/uzK9r+44BclBSp2X7K1hqeNfz9JbBeXA==} + + inherits@2.0.4: + resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} + + inngest@3.27.4: + resolution: {integrity: sha512-S71jJNxmfA9d4jmFKSxgi/u5d+tSmpsThAmTMHhjGieBTSrfGguwhHHskkPpFCtrMWI+Qdy8ek/rtQNxFkE9eQ==} + engines: {node: '>=14'} + peerDependencies: + '@sveltejs/kit': '>=1.27.3' + '@vercel/node': '>=2.15.9' + aws-lambda: '>=1.0.7' + express: '>=4.19.2' + fastify: '>=4.21.0' + h3: '>=1.8.1' + hono: '>=4.2.7' + koa: 
'>=2.14.2' + next: '>=12.0.0' + typescript: '>=4.7.2' + peerDependenciesMeta: + '@sveltejs/kit': + optional: true + '@vercel/node': + optional: true + aws-lambda: + optional: true + express: + optional: true + fastify: + optional: true + h3: + optional: true + hono: + optional: true + koa: + optional: true + next: + optional: true + typescript: + optional: true + + json-stringify-safe@5.0.1: + resolution: {integrity: sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==} + + minimalistic-assert@1.0.1: + resolution: {integrity: sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==} + + ms@2.1.3: + resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} + + node-addon-api@8.3.0: + resolution: {integrity: sha512-8VOpLHFrOQlAH+qA0ZzuGRlALRA6/LVh8QJldbrC4DY0hXoMP0l4Acq8TzFC018HztWiRqyCEj2aTWY2UvnJUg==} + engines: {node: ^18 || ^20 || >= 21} + + node-fetch@2.7.0: + resolution: {integrity: sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==} + engines: {node: 4.x || >=6.0.0} + peerDependencies: + encoding: ^0.1.0 + peerDependenciesMeta: + encoding: + optional: true + + node-gyp-build@4.8.4: + resolution: {integrity: sha512-LA4ZjwlnUblHVgq0oBF3Jl/6h/Nvs5fzBLwdEF4nuxnFdsfajde4WfxtJr3CaiH+F6ewcIB/q4jQ4UzPyid+CQ==} + hasBin: true + + serialize-error-cjs@0.1.3: + resolution: {integrity: sha512-GXwbHkufrNZ87O7DUEvWhR8eBnOqiXtHsOXakkJliG7eLDmjh6gDlbJbMZFFbUx0J5sXKgwq4NFCs41dF5MhiA==} + + strip-ansi@5.2.0: + resolution: {integrity: sha512-DuRs1gKbBqsMKIZlrffwlug8MHkcnpjs5VPmL1PAh+mA30U0DTotfDZ0d2UUsXpPmPmMMJ6W773MaA3J+lbiWA==} + engines: {node: '>=6'} + + supports-color@7.2.0: + resolution: {integrity: sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==} + engines: {node: '>=8'} + + tr46@0.0.3: + resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==} + + tree-sitter-python@0.23.5: + resolution: {integrity: sha512-4BJo/NG9btDqdXydjY9jZuqMeqjfji3XTLnn3qr9kOmoLbon8Nc2BbAi1UbsRfU2LLjY36iRGgYqsl0CIuBwcQ==} + peerDependencies: + tree-sitter: ^0.22.1 + peerDependenciesMeta: + tree-sitter: + optional: true + + tree-sitter@0.22.1: + resolution: {integrity: sha512-gRO+jk2ljxZlIn20QRskIvpLCMtzuLl5T0BY6L9uvPYD17uUrxlxWkvYCiVqED2q2q7CVtY52Uex4WcYo2FEXw==} + + webidl-conversions@3.0.1: + resolution: {integrity: sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==} + + whatwg-url@5.0.0: + resolution: {integrity: sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==} + + zod@3.22.5: + resolution: {integrity: sha512-HqnGsCdVZ2xc0qWPLdO25WnseXThh0kEYKIdV5F/hTHO75hNZFp8thxSeHhiPrHZKrFTo1SOgkAj9po5bexZlw==} + +snapshots: + + '@types/debug@4.1.12': + dependencies: + '@types/ms': 0.7.34 + + '@types/ms@0.7.34': {} + + ansi-regex@4.1.1: {} + + ansi-styles@4.3.0: + dependencies: + color-convert: 2.0.1 + + canonicalize@1.0.8: {} + + chalk@4.1.2: + dependencies: + ansi-styles: 4.3.0 + supports-color: 7.2.0 + + color-convert@2.0.1: + dependencies: + color-name: 1.1.4 + + color-name@1.1.4: {} + + cross-fetch@4.0.0: + dependencies: + node-fetch: 2.7.0 + transitivePeerDependencies: + - encoding + + debug@4.4.0: + dependencies: + ms: 2.1.3 + + has-flag@4.0.0: {} + + hash.js@1.1.7: + dependencies: + inherits: 2.0.4 + minimalistic-assert: 1.0.1 + + 
inherits@2.0.4: {} + + inngest@3.27.4: + dependencies: + '@types/debug': 4.1.12 + canonicalize: 1.0.8 + chalk: 4.1.2 + cross-fetch: 4.0.0 + debug: 4.4.0 + hash.js: 1.1.7 + json-stringify-safe: 5.0.1 + ms: 2.1.3 + serialize-error-cjs: 0.1.3 + strip-ansi: 5.2.0 + zod: 3.22.5 + transitivePeerDependencies: + - encoding + - supports-color + + json-stringify-safe@5.0.1: {} + + minimalistic-assert@1.0.1: {} + + ms@2.1.3: {} + + node-addon-api@8.3.0: {} + + node-fetch@2.7.0: + dependencies: + whatwg-url: 5.0.0 + + node-gyp-build@4.8.4: {} + + serialize-error-cjs@0.1.3: {} + + strip-ansi@5.2.0: + dependencies: + ansi-regex: 4.1.1 + + supports-color@7.2.0: + dependencies: + has-flag: 4.0.0 + + tr46@0.0.3: {} + + tree-sitter-python@0.23.5(tree-sitter@0.22.1): + dependencies: + node-addon-api: 8.3.0 + node-gyp-build: 4.8.4 + optionalDependencies: + tree-sitter: 0.22.1 + + tree-sitter@0.22.1: + dependencies: + node-addon-api: 8.3.0 + node-gyp-build: 4.8.4 + + webidl-conversions@3.0.1: {} + + whatwg-url@5.0.0: + dependencies: + tr46: 0.0.3 + webidl-conversions: 3.0.1 + + zod@3.22.5: {} diff --git a/examples/swebench/tools/tools.ts b/examples/swebench/tools/tools.ts new file mode 100644 index 0000000..6b09c66 --- /dev/null +++ b/examples/swebench/tools/tools.ts @@ -0,0 +1,206 @@ +import fs from "fs"; +import { z } from "zod"; +import Parser from "tree-sitter"; +import Py from "tree-sitter-python"; +import { createTool } from "../../../src/index"; + +// PyClass represents a class parsed from a python file. +interface PyClass { + name: string; + startLine: number; + endLine: number; + methods: PyFn[] +} + +// PyFN represents a function parsed from a python file. This may belong to a class or +// it may be a top-level function. +interface PyFn { + name: string; + parameters: string; + startLine: number; + endLine: number; + body: string; +} + +export const listFilesTool = createTool({ + name: "list_files", + description: "Lists all files within the project, returned as a JSON string containign the path to each file", + handler: async (_input, opts) => { + // NOTE: In this repo, all files are stored in "./opt/" as the prefix. + const path = "./opt/" + opts.network?.state.kv.get("repo") + + const files = await opts.step.run("list files", () => { + return fs.readdirSync(path, { recursive: true }).filter(name => name.indexOf(".git") !== 0) + }); + + // Store all files within state. Note that this happens outside of steps + // so that this is not memoized. + opts.network && opts.network.state.kv.set("files", files); + return files; + }, +}); + +export const readFileTool = createTool({ + name: "read_file", + description: "Reads a single file given its filename, returning its contents", + parameters: z.object({ + filename: z.string(), + }), + handler: async ({ filename }, opts) => { + const content = await opts.step.run(`read file: ${filename}`, () => { + return readFile(opts.network?.state.kv.get("repo") || "", filename); + }) + + // Set state for the filename. Note that this happens outside of steps + // so that this is not memoized. + opts.network?.state.kv.set("file:" + filename, content); + return content; + }, +}); + +/** + * extractFnTool extracts all top level functions and classes from a Python file. It also + * parses all method definitions of a class. 
+ * + */ +export const extractClassAndFns = createTool({ + name: "extract_classes_and_functions", + description: "Return all classes names and their functions, including top level functions", + parameters: z.object({ + filename: z.string(), + }), + handler: async (input, opts) => { + return await opts.step.run("parse file", () => { + const contents = readFile(opts.network?.state.kv.get("repo") || "", input.filename); + return parseClassAndFns(contents); + }); + }, +}); + +export const replaceClassMethodTool = createTool({ + name: "replace_class_method", + description: "Replaces a method within a specific class entirely.", + parameters: z.object({ + filename: z.string(), + class_name: z.string(), + function_name: z.string(), + new_contents: z.string(), + }), + handler: async ({ filename, class_name, function_name, new_contents }, opts) => { + const updated = await opts?.step.run(`update class method in '${filename}': ${class_name}.${function_name}`, () => { + // Re-parse the contents to find the correct start and end offsets. + const contents = readFile(opts.network?.state.kv.get("repo") || "", filename); + const parsed = parseClassAndFns(contents); + + const c = parsed.classes.find(c => class_name === c.name); + const fn = c?.methods.find(f => f.name === function_name); + if (!c || !fn) { + // TODO: Redo the planning as this wasn't found. + throw new Error("TODO: redo plan"); + } + + return contents.split("\n").reduce((updated, line, idx) => { + const beforeRange = (idx + 1) < fn.startLine; + const isRange = (idx + 1) === fn.startLine; + const afterRange = (idx + 1) >= fn.endLine; + + if (beforeRange || afterRange) { + return [...updated, line]; + } + + return isRange ? [...updated, new_contents] : updated; + }, [] as string[]).join("\n"); + }); + + const path = "./opt/" + opts.network?.state.kv.get("repo") + fs.writeFileSync(path + "/" + filename, updated); + + return new_contents; + }, +}) + + +// +// Utility functions +// + +export const readFile = (repo: string, filename: string) => { + // NOTE: In this repo, all files are stored in "./opt/" as the prefix. 
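
To make the parser output concrete, here is a usage sketch for parseClassAndFns (defined just below). The shape follows the PyClass and PyFn interfaces above; line numbers are 1-indexed, and the exact values depend on tree-sitter's parse.

  import { parseClassAndFns } from "./tools";

  const sample = [
    "def top_level(x):",
    "    return x + 1",
    "",
    "class Greeter:",
    "    def greet(self, name):",
    "        return name",
  ].join("\n");

  const parsed = parseClassAndFns(sample);
  // Roughly:
  // parsed.functions -> [{ name: "top_level", parameters: "(x)", startLine: 1, endLine: 2, body: "" }]
  // parsed.classes   -> [{ name: "Greeter", startLine: 4, endLine: 6, methods: [/* greet */] }]
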
+ const path = "./opt/" + repo + return fs.readFileSync(path + "/" + filename).toString(); +} + +export const parseClassAndFns = (contents: string) => { + const parser = new Parser(); + parser.setLanguage(Py); + + const tree = parser.parse(contents); + const cursor = tree.walk() + + const results = { + classes: [] as PyClass[], + functions: [] as PyFn[], + }; + + // Helper to get the full function name and parameters + const getFunctionDetails = (node: Parser.SyntaxNode): PyFn => { + const nameNode = node.childForFieldName('name'); + const parametersNode = node.childForFieldName('parameters'); + return { + name: nameNode?.text || "", + parameters: parametersNode?.text || "", + startLine: node.startPosition.row + 1, + endLine: node.endPosition.row + 1, + body: "", //node.text + }; + } + + const getClassMethods = (classNode: Parser.SyntaxNode) => { + const methods: PyFn[] = []; + + const body = classNode.childForFieldName("body"); + if (!body) { + return methods; + } + + const cursor = body.walk(); + cursor.gotoFirstChild(); + + do { + if (cursor.nodeType === 'function_definition') { + methods.push(getFunctionDetails(cursor.currentNode)); + } + } while(cursor.gotoNextSibling()); + + return methods; + } + + cursor.gotoFirstChild(); + do { + const node = cursor.currentNode; + if (!node) { + continue + } + + switch (node.type) { + case 'function_definition': + // Only process top-level functions + if (node.parent === tree.rootNode) { + results.functions.push(getFunctionDetails(node)); + } + break; + + case 'class_definition': + const classInfo: PyClass = { + name: node.childForFieldName('name')?.text || "", + startLine: node.startPosition.row + 1, + endLine: node.endPosition.row + 1, + methods: getClassMethods(node) + }; + results.classes.push(classInfo); + break; + } + } while(cursor.gotoNextSibling()); + + return results; +} diff --git a/src/adapters/anthropic.ts b/src/adapters/anthropic.ts index 640bd80..16e1ea1 100644 --- a/src/adapters/anthropic.ts +++ b/src/adapters/anthropic.ts @@ -9,8 +9,10 @@ import { type Anthropic, } from "inngest"; import { zodToJsonSchema } from "openai-zod-to-json-schema"; +import { type Tool } from "../types"; +import { z } from "zod"; import { type AgenticModel } from "../model"; -import { type InternalNetworkMessage } from "../state"; +import { type TextMessage, type Message } from "../state"; /** * Parse a request from internal network messages to an Anthropic input. @@ -19,25 +21,72 @@ export const requestParser: AgenticModel.RequestParser = ( model, messages, tools, + tool_choice = "auto", ) => { // Note that Anthropic has a top-level system prompt, then a series of prompts // for assistants and users. - const systemMessage = messages.find((m) => m.role === "system"); + const systemMessage = messages.find( + (m) => m.role === "system" && m.type === "text", + ) as TextMessage; const system = typeof systemMessage?.content === "string" ? systemMessage.content : ""; + const anthropicMessages: AiAdapter.Input["messages"] = + messages + .filter((m) => m.role !== "system") + .reduce( + (acc, m) => { + switch (m.type) { + case "text": + return [ + ...acc, + { + role: m.role, + content: Array.isArray(m.content) + ? 
m.content.map((text) => ({ type: "text", text })) + : m.content, + }, + ] as AiAdapter.Input["messages"]; + case "tool_call": + return [ + ...acc, + { + role: m.role, + content: m.tools.map((tool) => ({ + type: "tool_use", + id: tool.id, + input: tool.input, + name: tool.name, + })), + }, + ]; + case "tool_result": + return [ + ...acc, + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: m.tool.id, + content: + typeof m.content === "string" + ? m.content + : JSON.stringify(m.content), + }, + ], + }, + ]; + } + }, + [] as AiAdapter.Input["messages"], + ); + const request: AiAdapter.Input = { system, model: model.options.model, max_tokens: model.options.max_tokens, - messages: messages - .filter((m) => m.role !== "system") - .map((m) => { - return { - role: m.role, - content: m.content, - }; - }) as AiAdapter.Input["messages"], + messages: anthropicMessages, }; if (tools?.length) { @@ -45,11 +94,14 @@ export const requestParser: AgenticModel.RequestParser = ( return { name: t.name, description: t.description, - input_schema: zodToJsonSchema( - t.parameters, - ) as AnthropicAiAdapter.Tool.InputSchema, + input_schema: (t.parameters + ? zodToJsonSchema(t.parameters) + : zodToJsonSchema( + z.object({}), + )) as AnthropicAiAdapter.Tool.InputSchema, }; }); + request.tool_choice = toolChoice(tool_choice); } return request; @@ -61,54 +113,71 @@ export const requestParser: AgenticModel.RequestParser = ( export const responseParser: AgenticModel.ResponseParser = ( input, ) => { - return (input?.content ?? []).reduce( - (acc, item) => { - if (!item.type) { - return acc; - } - - switch (item.type) { - case "text": - return [ - ...acc, - { - role: input.role, - content: item.text, - }, - ]; - case "tool_use": { - let args; - try { - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - args = - typeof item.input === "string" - ? JSON.parse(item.input) - : item.input; - } catch { - args = item.input; - } + return (input?.content ?? []).reduce((acc, item) => { + if (!item.type) { + return acc; + } - return [ - ...acc, - { - role: input.role, - content: "", - tools: [ - { - type: "tool", - id: item.id, - name: item.name, - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - input: args, - }, - ], - }, - ]; + switch (item.type) { + case "text": + return [ + ...acc, + { + type: "text", + role: input.role, + content: item.text, + // XXX: Better stop reason parsing + stop_reason: "stop", + }, + ]; + case "tool_use": { + let args; + try { + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + args = + typeof item.input === "string" + ? 
JSON.parse(item.input) + : item.input; + } catch { + args = item.input; } + + return [ + ...acc, + { + type: "tool_call", + role: input.role, + stop_reason: "tool", + tools: [ + { + type: "tool", + id: item.id, + name: item.name, + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + input: args, + }, + ], + }, + ]; } + } + }, []); +}; - return acc; - }, - [], - ); +const toolChoice = ( + choice: Tool.Choice, +): AiAdapter.Input["tool_choice"] => { + switch (choice) { + case "auto": + return { type: "auto" }; + case "any": + return { type: "any" }; + default: + if (typeof choice === "string") { + return { + type: "tool", + name: choice as string, + }; + } + } }; diff --git a/src/adapters/openai.ts b/src/adapters/openai.ts index b3f1228..b4d4145 100644 --- a/src/adapters/openai.ts +++ b/src/adapters/openai.ts @@ -8,79 +8,13 @@ import { type AiAdapter, type OpenAi } from "inngest"; import { zodToJsonSchema } from "openai-zod-to-json-schema"; import { type AgenticModel } from "../model"; import { stringifyError } from "../util"; -import { type InternalNetworkMessage, type ToolMessage } from "../state"; - -/** - * Parse the given `str` `string` as JSON, also handling backticks, a common - * OpenAI quirk. - * - * @example Input - * ``` - * "{\n \"files\": [\n {\n \"filename\": \"fibo.ts\",\n \"content\": `\nfunction fibonacci(n: number): number {\n if (n < 2) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n\nexport default fibonacci;\n`\n }\n ]\n}" - * ``` - */ -const safeParseOpenAIJson = (str: string): unknown => { - // Remove any leading/trailing quotes if present - const trimmed = str.replace(/^["']|["']$/g, ""); - - try { - // First try direct JSON parse - return JSON.parse(trimmed); - } catch { - try { - // Replace backtick strings with regular JSON strings - // Match content between backticks, preserving newlines - const withQuotes = trimmed.replace(/`([\s\S]*?)`/g, (_, content) => - JSON.stringify(content), - ); - return JSON.parse(withQuotes); - } catch (e) { - throw new Error( - `Failed to parse JSON with backticks: ${stringifyError(e)}`, - ); - } - } -}; - -const StateStopReasonToOpenAiStopReason: Record = { - tool: "tool_calls", - stop: "stop", -}; - -const OpenAiStopReasonToStateStopReason: Record = { - tool_calls: "tool", - stop: "stop", - length: "stop", - content_filter: "stop", - function_call: "tool", -}; - -const reqMsgRoleHandlers: Record< - InternalNetworkMessage["role"], - ( - internalMessage: InternalNetworkMessage, - ) => Partial["messages"][number]> -> = { - system: () => ({ role: "system" }), - user: () => ({ role: "user" }), - assistant: (m) => ({ - role: "assistant", - tool_calls: m.tools - ? m.tools?.map((tool) => ({ - id: tool.id, - type: "function", - function: { - name: tool.name, - arguments: JSON.stringify(tool.input), - }, - })) - : undefined, - }), - tool_result: (m) => ({ - role: "tool", - tool_call_id: m.tools?.[0]?.id, - }), -}; +import { type Tool } from "../types"; +import { + type TextMessage, + type ToolCallMessage, + type Message, + type ToolMessage, +} from "../state"; /** * Parse a request from internal network messages to an OpenAI input. @@ -89,25 +23,42 @@ export const requestParser: AgenticModel.RequestParser = ( model, messages, tools, + tool_choice = "auto", ) => { const request: AiAdapter.Input = { messages: messages.map((m) => { - const baseMsg = { - ...(m.stop_reason - ? 
{ finish_reason: StateStopReasonToOpenAiStopReason[m.stop_reason] } - : {}), - content: m.content, - }; - - return { - ...baseMsg, - ...reqMsgRoleHandlers[m.role](m), - }; + switch (m.type) { + case "text": + return { + role: m.role, + content: m.content, + }; + case "tool_call": + return { + role: "assistant", + content: null, + tool_calls: m.tools + ? m.tools?.map((tool) => ({ + id: tool.id, + type: "function", + function: { + name: tool.name, + arguments: JSON.stringify(tool.input), + }, + })) + : undefined, + }; + case "tool_result": + return { + role: "tool", + content: m.content, + }; + } }) as AiAdapter.Input["messages"], }; if (tools?.length) { - request.tool_choice = "auto"; + request.tool_choice = toolChoice(tool_choice); // it is recommended to disable parallel tool calls with structured output // https://platform.openai.com/docs/guides/function-calling#parallel-function-calling-and-structured-outputs request.parallel_tool_calls = false; @@ -117,7 +68,7 @@ export const requestParser: AgenticModel.RequestParser = ( function: { name: t.name, description: t.description, - parameters: zodToJsonSchema(t.parameters), + parameters: t.parameters && zodToJsonSchema(t.parameters), strict: true, }, }; @@ -133,33 +84,100 @@ export const requestParser: AgenticModel.RequestParser = ( export const responseParser: AgenticModel.ResponseParser = ( input, ) => { - return (input?.choices ?? []).reduce( - (acc, choice) => { - if (!choice.message) { - return acc; - } + return (input?.choices ?? []).reduce((acc, choice) => { + const { message, finish_reason } = choice; + if (!message) { + return acc; + } - const stopReason = - OpenAiStopReasonToStateStopReason[choice.finish_reason ?? ""]; + const base = { + role: choice.message.role, + stop_reason: + openAiStopReasonToStateStopReason[finish_reason ?? ""] || "stop", + }; + if (message.content) { return [ ...acc, { - role: choice.message.role, - content: choice.message.content, - ...(stopReason ? { stop_reason: stopReason } : {}), - tools: (choice.message.tool_calls ?? []).map((tool) => { + ...base, + type: "text", + content: message.content, + } as TextMessage, + ]; + } + if (message.tool_calls.length > 0) { + return [ + ...acc, + { + ...base, + type: "tool_call", + tools: message.tool_calls.map((tool) => { return { - type: "function", + type: "tool", id: tool.id, name: tool.function.name, function: tool.function.name, input: safeParseOpenAIJson(tool.function.arguments || "{}"), - } as unknown as ToolMessage; // :( + } as ToolMessage; }), - } as InternalNetworkMessage, + } as ToolCallMessage, ]; - }, - [], - ); + } + return acc; + }, []); +}; + +/** + * Parse the given `str` `string` as JSON, also handling backticks, a common + * OpenAI quirk. 
+ * + * @example Input + * ``` + * "{\n \"files\": [\n {\n \"filename\": \"fibo.ts\",\n \"content\": `\nfunction fibonacci(n: number): number {\n if (n < 2) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n\nexport default fibonacci;\n`\n }\n ]\n}" + * ``` + */ +const safeParseOpenAIJson = (str: string): unknown => { + // Remove any leading/trailing quotes if present + const trimmed = str.replace(/^["']|["']$/g, ""); + + try { + // First try direct JSON parse + return JSON.parse(trimmed); + } catch { + try { + // Replace backtick strings with regular JSON strings + // Match content between backticks, preserving newlines + const withQuotes = trimmed.replace(/`([\s\S]*?)`/g, (_, content) => + JSON.stringify(content), + ); + return JSON.parse(withQuotes); + } catch (e) { + throw new Error( + `Failed to parse JSON with backticks: ${stringifyError(e)}`, + ); + } + } +}; + +const openAiStopReasonToStateStopReason: Record = { + tool_calls: "tool", + stop: "stop", + length: "stop", + content_filter: "stop", + function_call: "tool", +}; + +const toolChoice = (choice: Tool.Choice) => { + switch (choice) { + case "auto": + return "auto"; + case "any": + return "required"; + default: + return { + type: "function" as const, + function: { name: choice as string }, + }; + } }; diff --git a/src/agent.ts b/src/agent.ts index b984386..769648f 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -3,14 +3,10 @@ import { type Network } from "./network"; import { type State, InferenceResult, - type InternalNetworkMessage, + type Message, + type ToolResultMessage, } from "./state"; -import { - type BaseLifecycleArgs, - type BeforeLifecycleArgs, - type ResultLifecycleArgs, - type Tool, -} from "./types"; +import { type Tool } from "./types"; import { type AnyZodType, type MaybePromise } from "./util"; /** @@ -24,6 +20,9 @@ export const createTool = (t: Tool): Tool => t; */ export const createAgent = (opts: Agent.Constructor) => new Agent(opts); +export const createRoutingAgent = (opts: Agent.RoutingConstructor) => + new RoutingAgent(opts); + /** * Agent represents a single agent, responsible for a set of tasks. */ @@ -53,10 +52,20 @@ export class Agent { */ tools: Map; + /** + * tool_choice allows you to specify whether tools are automatically. this defaults + * to "auto", allowing the model to detect when to call tools automatically. Choices are: + * + * - "auto": allow the model to choose tools automatically + * - "any": force the use of any tool in the tools map + * - string: force the name of a particular tool + */ + tool_choice?: Tool.Choice; + /** * lifecycles are programmatic hooks used to manage the agent. */ - lifecycles: Agent.Lifecycle | undefined; + lifecycles: Agent.Lifecycle | Agent.RoutingLifecycle | undefined; /** * model is the step caller to use for this agent. 
This allows the agent @@ -65,12 +74,13 @@ export class Agent { */ model: AgenticModel.Any | undefined; - constructor(opts: Agent.Constructor) { + constructor(opts: Agent.Constructor | Agent.RoutingConstructor) { this.name = opts.name; this.description = opts.description || ""; this.system = opts.system; this.assistant = opts.assistant || ""; this.tools = new Map(); + this.tool_choice = opts.tool_choice; this.lifecycles = opts.lifecycle; this.model = opts.model; @@ -80,8 +90,15 @@ export class Agent { } withModel(model: AgenticModel.Any): Agent { - this.model = model; - return this; // for chaining + return new Agent({ + name: this.name, + description: this.description, + system: this.system, + assistant: this.assistant, + tools: Array.from(this.tools.values()), + lifecycle: this.lifecycles, + model, + }); } /** @@ -90,12 +107,7 @@ export class Agent { */ async run( input: string, - { - model, - network, - state: inputState, - maxIter = 0, - }: Agent.RunOptions | undefined = {}, + { model, network, state, maxIter = 0 }: Agent.RunOptions | undefined = {}, ): Promise { const p = model || this.model || network?.defaultModel; if (!p) { @@ -103,9 +115,9 @@ export class Agent { } // input state always overrides the network state. - const state = inputState || network?.state; + const s = state || network?.state; - let history = state ? state.format() : []; + let history = s ? s.format() : []; let prompt = await this.agentPrompt(input, network); let result = new InferenceResult(this, input, prompt, history, [], [], ""); let hasMoreActions = true; @@ -152,20 +164,24 @@ export class Agent { result = await this.lifecycles.onFinish({ agent: this, network, result }); } + // Note that the routing lifecycles aren't called by the agent. They're called + // by the network. + return result; } private async performInference( input: string, p: AgenticModel.Any, - prompt: InternalNetworkMessage[], - history: InternalNetworkMessage[], + prompt: Message[], + history: Message[], network?: Network, ): Promise { const { output, raw } = await p.infer( this.name, prompt.concat(history), Array.from(this.tools.values()), + this.tool_choice || "auto", ); // Now that we've made the call, we instantiate a new InferenceResult for @@ -197,13 +213,17 @@ export class Agent { } private async invokeTools( - msgs: InternalNetworkMessage[], + msgs: Message[], p: AgenticModel.Any, network?: Network, - ): Promise { - const output: InternalNetworkMessage[] = []; + ): Promise { + const output: ToolResultMessage[] = []; for (const msg of msgs) { + if (msg.type !== "tool_call") { + continue; + } + if (!Array.isArray(msg.tools)) { continue; } @@ -233,16 +253,16 @@ export class Agent { output.push({ role: "tool_result", - tools: [ - { - type: "tool", - id: tool.id, - name: tool.name, - input: tool.input.arguments as Record, - }, - ], - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + type: "tool_result", + tool: { + type: "tool", + id: tool.id, + name: tool.name, + input: tool.input.arguments as Record, + }, + content: result ? result : `${tool.name} successfully executed`, + stop_reason: "tool", }); } } @@ -253,13 +273,14 @@ export class Agent { private async agentPrompt( input: string, network?: Network, - ): Promise { + ): Promise { // Prompt returns the full prompt for the current agent. This does NOT // include the existing network's state as part of the prompt. // // Note that the agent's system message always comes first. 
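
As a usage note, the new tool_choice option accepts "auto", "any", or the name of a specific tool. A minimal sketch of forcing an agent to reply through one tool; the tool, prompt, and import path here are illustrative.

  import { z } from "zod";
  import { createAgent, createTool } from "./index";

  const classify = createTool({
    name: "classify",
    description: "Classifies the user's request into a single label",
    parameters: z.object({ label: z.string() }),
    handler: ({ label }) => label,
  });

  const classifier = createAgent({
    name: "Classifier",
    system: "Classify the request into a single label.",
    tools: [classify],
    // Force the model to answer via the "classify" tool instead of free text.
    tool_choice: "classify",
  });
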
- const messages: InternalNetworkMessage[] = [ + const messages: Message[] = [ { + type: "text", role: "system", content: typeof this.system === "string" @@ -269,17 +290,42 @@ export class Agent { ]; if (input.length > 0) { - messages.push({ role: "user", content: input }); + messages.push({ type: "text", role: "user", content: input }); } if (this.assistant.length > 0) { - messages.push({ role: "assistant", content: this.assistant }); + messages.push({ + type: "text", + role: "assistant", + content: this.assistant, + }); } return messages; } } +export class RoutingAgent extends Agent { + type = "routing"; + override lifecycles: Agent.RoutingLifecycle; + constructor(opts: Agent.RoutingConstructor) { + super(opts); + this.lifecycles = opts.lifecycle; + } + + override withModel(model: AgenticModel.Any): RoutingAgent { + return new RoutingAgent({ + name: this.name, + description: this.description, + system: this.system, + assistant: this.assistant, + tools: Array.from(this.tools.values()), + lifecycle: this.lifecycles, + model, + }); + } +} + export namespace Agent { export interface Constructor { name: string; @@ -287,10 +333,15 @@ export namespace Agent { system: string | ((network?: Network) => MaybePromise); assistant?: string; tools?: Tool.Any[]; + tool_choice?: Tool.Choice; lifecycle?: Lifecycle; model?: AgenticModel.Any; } + export interface RoutingConstructor extends Omit { + lifecycle: RoutingLifecycle; + } + export interface RunOptions { model?: AgenticModel.Any; network?: Network; @@ -308,7 +359,7 @@ export namespace Agent { * enabled selectively enables or disables this agent based off of network * state. If this function is not provided, the agent is always enabled. */ - enabled?: (args: BaseLifecycleArgs) => MaybePromise; + enabled?: (args: Agent.LifecycleArgs.Base) => MaybePromise; /** * onStart is called just before an agent starts an inference call. @@ -321,9 +372,9 @@ export namespace Agent { * the agent from making the call altogether. * */ - onStart?: (args: BeforeLifecycleArgs) => MaybePromise<{ - prompt: InternalNetworkMessage[]; - history: InternalNetworkMessage[]; + onStart?: (args: Agent.LifecycleArgs.Before) => MaybePromise<{ + prompt: Message[]; + history: Message[]; // stop, if true, will prevent calling the agent stop: boolean; }>; @@ -333,7 +384,9 @@ export namespace Agent { * have been invoked. This allows you to moderate the response prior to * running tools. */ - onResponse?: (args: ResultLifecycleArgs) => MaybePromise; + onResponse?: ( + args: Agent.LifecycleArgs.Result, + ) => MaybePromise; /** * onFinish is called with a finalized InferenceResult, including any tool @@ -341,6 +394,45 @@ export namespace Agent { * history, if the agent is part of the network. * */ - onFinish?: (args: ResultLifecycleArgs) => MaybePromise; + onFinish?: ( + args: Agent.LifecycleArgs.Result, + ) => MaybePromise; + } + + export namespace LifecycleArgs { + export interface Base { + // Agent is the agent that made the call. + agent: Agent; + // Network represents the network that this agent or lifecycle belongs to. + network?: Network; + } + + export interface Result extends Base { + result: InferenceResult; + } + + export interface Before extends Base { + // input is the user request for the entire agentic operation. + input?: string; + + // prompt is the system, user, and any assistant prompt as generated + // by the Agent. This does not include any past history. + prompt: Message[]; + + // history is the past history as generated via State. 
Ths will be added + // after the prompt to form a single conversation log. + history?: Message[]; + } + } + + export interface RoutingLifecycle extends Lifecycle { + onRoute: RouterFn; } + + export type RouterFn = (args: Agent.RouterArgs) => string[] | undefined; + + /** + * Router args are the arguments passed to the onRoute lifecycle hook. + */ + export type RouterArgs = Agent.LifecycleArgs.Result; } diff --git a/src/model.ts b/src/model.ts index dd5198b..2cac595 100644 --- a/src/model.ts +++ b/src/model.ts @@ -1,5 +1,5 @@ import { type AiAdapter, type GetStepTools, type Inngest } from "inngest"; -import { type InternalNetworkMessage } from "./state"; +import { type Message } from "./state"; import { type Tool } from "./types"; export class AgenticModel { @@ -23,12 +23,13 @@ export class AgenticModel { async infer( stepID: string, - input: InternalNetworkMessage[], + input: Message[], tools: Tool.Any[], + tool_choice: Tool.Choice, ): Promise { const result = (await this.step.ai.infer(stepID, { model: this.#model, - body: this.requestParser(this.#model, input, tools), + body: this.requestParser(this.#model, input, tools, tool_choice), })) as AiAdapter.Input; return { output: this.responseParser(result), raw: result }; @@ -44,7 +45,7 @@ export namespace AgenticModel { * result depending on the model's API repsonse. */ export type InferenceResponse = { - output: InternalNetworkMessage[]; + output: Message[]; raw: T; }; @@ -57,11 +58,12 @@ export namespace AgenticModel { export type RequestParser = ( model: TAiAdapter, - state: InternalNetworkMessage[], + state: Message[], tools: Tool.Any[], + tool_choice: Tool.Choice, ) => AiAdapter.Input; export type ResponseParser = ( output: AiAdapter.Output, - ) => InternalNetworkMessage[]; + ) => Message[]; } diff --git a/src/models/anthropic.ts b/src/models/anthropic.ts index db57042..33131c5 100644 --- a/src/models/anthropic.ts +++ b/src/models/anthropic.ts @@ -1,11 +1,7 @@ -import { - anthropic as ianthropic, - type GetStepTools, - type Inngest, - type Anthropic, -} from "inngest"; +import { anthropic as ianthropic, type Anthropic } from "inngest"; import { requestParser, responseParser } from "../adapters/anthropic"; import { AgenticModel } from "../model"; +import { type AnyStepTools } from "../types"; export namespace AnthropicModel { export interface Options @@ -18,7 +14,7 @@ export namespace AnthropicModel { /** * The step tools to use internally within this model. */ - step: GetStepTools; + step: AnyStepTools; } } diff --git a/src/network.ts b/src/network.ts index 19e1e46..9870b93 100644 --- a/src/network.ts +++ b/src/network.ts @@ -1,5 +1,10 @@ import { z } from "zod"; -import { Agent, createTool } from "./agent"; +import { + type Agent, + RoutingAgent, + createRoutingAgent, + createTool, +} from "./agent"; import { type AgenticModel } from "./model"; import { type InferenceResult, State } from "./state"; import { type MaybePromise } from "./util"; @@ -98,25 +103,27 @@ export class Network { * run handles a given request using the network of agents. It is not * concurrency-safe; you can only call run on a network once, as networks are * stateful. 
+ * */ async run(input: string, router?: Network.Router): Promise { - const agents = await this.availableAgents(); - - if (agents.length === 0) { + const available = await this.availableAgents(); + if (available.length === 0) { throw new Error("no agents enabled in network"); } // If there's no default agent used to run the request, use our internal // routing agent which attempts to figure out the best agent to choose based // off of the network. - const agent = await this.getNextAgent(router); - if (!agent) { + const next = await this.getNextAgents(input, router); + if (!next) { // TODO: If call count is 0, error. return this; } // Schedule the agent to run on our stack, then start popping off the stack. - this.schedule(agent.name); + for (const agent of next) { + this.schedule(agent.name); + } while ( this._stack.length > 0 && @@ -153,31 +160,39 @@ export class Network { // By default, this is an agentic router which takes the current state, // agents, then figures out next steps. This can, and often should, be // custom code. - const next = await this.getNextAgent(router); - if (next) { - this.schedule(next.name); + const next = await this.getNextAgents(input, router); + for (const a of next || []) { + this.schedule(a.name); } } return this; } - private async getNextAgent( + private async getNextAgents( + input: string, router?: Network.Router, - ): Promise { - const defaultModel = this.defaultModel; + ): Promise { + // A router may do one of two things: + // + // 1. Return one or more Agents to run + // 2. Return undefined, meaning we're done. + // + // It can do this by using code, or by calling routing agents directly. + if (!router && !this.defaultModel) { + throw new Error( + "No router or model defined in network. You must pass a router or a default model to use the built-in agentic router.", + ); + } if (!router) { - if (!defaultModel) { - throw new Error( - "No router or model defined in network. You must pass a router or a default model to use the built-in agentic router.", - ); - } - - return defaultRoutingAgent.withModel(defaultModel); - } else if (router instanceof Agent) { - return router; + router = defaultRoutingAgent; + } + if (router instanceof RoutingAgent) { + return await this.getNextAgentsViaRoutingAgent(router, input); } + // This is a function call which determines the next agent to call. Note that the result + // of this function call may be another RoutingAgent. const stack: Agent[] = this._stack.map((name) => { const agent = this._agents.get(name); if (!agent) { @@ -187,42 +202,74 @@ export class Network { }); const agent = await router({ + input, network: this, stack, lastResult: this.state.results.pop(), callCount: this._counter, }); - if (!agent) { return; } + if (agent instanceof RoutingAgent) { + // Functions may also return routing agents. + return await this.getNextAgentsViaRoutingAgent(agent, input); + } - // Ensure this agent is part of the network. If not, we're going to - // automatically add it. - if (!this._agents.has(agent.name)) { - // XXX: Add a warning here. - this._agents.set(agent.name, agent); + for (const a of Array.isArray(agent) ? agent : [agent]) { + // Ensure this agent is part of the network. If not, we're going to + // automatically add it. + if (!this._agents.has(a.name)) { + this._agents.set(a.name, a); + } } - return agent; + return Array.isArray(agent) ? 
agent : [agent]; + } + + private async getNextAgentsViaRoutingAgent( + routingAgent: RoutingAgent, + input: string, + ): Promise { + const result = await routingAgent.run(input, { + network: this, + model: routingAgent.model || this.defaultModel, + }); + const agentNames = routingAgent.lifecycles.onRoute({ + result, + agent: routingAgent, + network: this, + }); + + return (agentNames || []) + .map((name) => this.agents.get(name)) + .filter(Boolean) as Agent[]; } } /** - * RoutingAgent is an AI agent that selects the appropriate agent from the - * network to handle the incoming request. + * defaultRoutingAgent is an AI agent that selects the appropriate agent from + * the network to handle the incoming request. + * + * It is no set model and so relies on the presence of a default model in the + * network or being explicitly given one. */ -export const defaultRoutingAgent = new Agent({ +export const defaultRoutingAgent = createRoutingAgent({ name: "Default routing agent", + description: "Selects which agents to work on based off of the current prompt and input.", lifecycle: { - onFinish: ({ result }) => { - // We never want to store this call's instructions in history. - result.withFormatter(() => []); - - return result; + onRoute: ({ result }) => { + const tool = result.toolCalls[0]; + if (!tool) { + return; + } + if (typeof tool.content === "string") { + return [tool.content]; + } + return; }, }, @@ -258,14 +305,15 @@ export const defaultRoutingAgent = new Agent({ ); } - // Schedule another agent. - network.schedule(agent.name); - + // This returns the agent name to call. The default routing functon + // schedules this agent by inpsecting this name via the tool call output. return agent.name; }, }), ], + tool_choice: "select_agent", + system: async (network?: Network): Promise => { if (!network) { throw new Error( @@ -294,11 +342,9 @@ The following agents are available: Follow the set of instructions: - Think about the current history and status. Determine which agent to use to handle the user's request. Respond with the agent's name within a tag as content, and select the appropriate tool. + Think about the current history and status. Determine which agent to use to handle the user's request, based off of the current agents and their tools. Your aim is to thoroughly complete the request, thinking step by step, choosing the right agent based off of the context. - - If the request has been solved, respond with one single tag, with the answer inside: $answer `; }, @@ -316,22 +362,34 @@ export namespace Network { /** * Router defines how a network coordinates between many agents. A router is - * a single function that gets given the network, current state, future - * agentic calls, and the last inference result from the network. + * either a RoutingAgent which uses inference calls to choose the next Agent, + * or a function which chooses the next Agent to call. * - * You can choose to create semi-autonomous networks by writing standard - * deterministic code to call agents based off of the current state. + * The function gets given the network, current state, future + * agentic calls, and the last inference result from the network. * - * You can also choose to create fully autonomous agentic networks by calling - * a "routing agent", which determines the best agent to call based off of - * current state. 
*/ - export type Router = - | Agent - | ((args: Router.Args) => MaybePromise); + export type Router = RoutingAgent | Router.FnRouter; export namespace Router { + /** + * FnRouter defines a function router which returns an Agent, an AgentRouter, or + * undefined if the network should stop. + * + * If the FnRouter returns an AgentRouter (an agent with the .route function), + * the agent will first be ran, then the `.route` function will be called. + * + */ + export type FnRouter = ( + args: Args, + ) => MaybePromise; + export interface Args { + /** + * input is the input called to the network + */ + input: string; + /** * Network is the network that this router is coordinating. Network state * is accessible via `network.state`. diff --git a/src/state.ts b/src/state.ts index d0ab2bd..4838665 100644 --- a/src/state.ts +++ b/src/state.ts @@ -1,9 +1,15 @@ import { type Agent } from "./agent"; -export interface InternalNetworkMessage { - role: "system" | "user" | "assistant" | "tool_result"; - content: string | Array | ToolResult; - tools?: ToolMessage[]; +export type Message = TextMessage | ToolCallMessage | ToolResultMessage; + +/** + * TextMessage represents plain text messages in the chat history, eg. the user's prompt or + * an assistant's reply. + */ +export interface TextMessage { + type: "text"; + role: "system" | "user" | "assistant"; + content: string | Array; // Anthropic: // stop_reason: "end_turn" | "max_tokens" | "stop_sequence" | "tool_use" | null; // OpenAI: @@ -11,22 +17,41 @@ export interface InternalNetworkMessage { stop_reason?: "tool" | "stop"; } -export interface TextMessage { +/** + * ToolCallMessage represents a message for a tool call. + */ +export interface ToolCallMessage { + type: "tool_call"; + role: "user" | "assistant"; + tools: ToolMessage[]; + stop_reason: "tool"; +} + +/** + * ToolResultMessage represents the output of a tool call. + */ +export interface ToolResultMessage { + type: "tool_result"; + role: "tool_result"; + // tool contains the tool call request for this result. + tool: ToolMessage; + content: unknown; + stop_reason: "tool"; +} + +// Message content. + +export interface TextContent { type: "text"; text: string; } + export interface ToolMessage { type: "tool"; id: string; name: string; input: Record; } -export interface ToolResult { - type: "tool_result"; - id: string; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - content: any; -} // TODO: Content types. /** * State stores state (history) for a given network of agents. The state @@ -88,7 +113,7 @@ export class State { * calling an individual agent. * */ - format(): InternalNetworkMessage[] { + format(): Message[] { return this._history.map((call) => call.format()).flat(); } @@ -109,9 +134,7 @@ export class InferenceResult { // You can set a custom history adapter by calling .withFormatter() within // lifecycles. This allows you to change how future agentic calls interpret // past agentic calls. - private _historyFormatter: - | ((a: InferenceResult) => InternalNetworkMessage[]) - | undefined; + private _historyFormatter: ((a: InferenceResult) => Message[]) | undefined; constructor( // agent represents the agent for this inference call. @@ -123,30 +146,30 @@ export class InferenceResult { // prompt represents the input instructions - without any additional history // - as created by the agent. This includes the system prompt, the user input, // and any initial agent assistant message. 
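
For concreteness, a short conversation expressed with the new Message union defined above might look like the following sketch; ids and contents are illustrative.

  import { type Message } from "./state";

  const history: Message[] = [
    { type: "text", role: "user", content: "List the project files" },
    {
      type: "tool_call",
      role: "assistant",
      stop_reason: "tool",
      tools: [{ type: "tool", id: "call_1", name: "list_files", input: {} }],
    },
    {
      type: "tool_result",
      role: "tool_result",
      stop_reason: "tool",
      tool: { type: "tool", id: "call_1", name: "list_files", input: {} },
      content: ["README.md", "setup.py"],
    },
  ];
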
- public prompt: InternalNetworkMessage[], + public prompt: Message[], // history represents the history sent to the inference call, appended to the // prompt to form a complete conversation log - public history: InternalNetworkMessage[], + public history: Message[], // output represents the parsed output from the inference call. This may be blank // if the agent responds with tool calls only. - public output: InternalNetworkMessage[], + public output: Message[], // toolCalls represents output from any tools called by the agent. - public toolCalls: InternalNetworkMessage[], + public toolCalls: ToolResultMessage[], // raw represents the raw API response from the call. This is a JSON // string, and the format depends on the agent's model. public raw: string, ) {} - withFormatter(f: (a: InferenceResult) => InternalNetworkMessage[]) { + withFormatter(f: (a: InferenceResult) => Message[]) { this._historyFormatter = f; } // format - format(): InternalNetworkMessage[] { + format(): Message[] { if (this._historyFormatter) { return this._historyFormatter(this); } @@ -160,25 +183,29 @@ export class InferenceResult { // prompts. const agent = this.agent; - const messages: InternalNetworkMessage[] = this.prompt.map(function (msg) { - let content: string; - if (typeof msg.content === "string") { - content = msg.content; - } else if (Array.isArray(msg.content)) { - content = msg.content.map((m) => m.text).join("\n"); - } else { - // XXX: Type checking here. - content = msg.content.content as string; - } - - // Ensure that system prompts are always as an assistant in history - return { - ...msg, - role: "assistant", - content: `${agent.name}\n${content}`, - }; - }); - - return messages.concat(this.output).concat(this.toolCalls); + const messages = this.prompt + .map((msg) => { + if (msg.type !== "text") { + return; + } + + let content: string = ""; + if (typeof msg.content === "string") { + content = msg.content; + } else if (Array.isArray(msg.content)) { + content = msg.content.map((m) => m.text).join("\n"); + } + + // Ensure that system prompts are always as an assistant in history + return { + ...msg, + type: "text", + role: "assistant", + content: `${agent.name}\n${content}`, + }; + }) + .filter(Boolean); + + return (messages as Message[]).concat(this.output).concat(this.toolCalls); } } diff --git a/src/types.ts b/src/types.ts index 4b310e5..79db5fd 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2,7 +2,6 @@ import { type GetStepTools, type Inngest } from "inngest"; import { type output as ZodOutput } from "zod"; import { type Agent } from "./agent"; import { type Network } from "./network"; -import { type InferenceResult, type InternalNetworkMessage } from "./state"; import { type GenericizeFunctionsInObject, type AnyZodType, @@ -13,7 +12,7 @@ import { export type Tool = { name: string; description?: string; - parameters: T; + parameters?: T; // TODO: Handler input types based off of JSON above. // @@ -26,6 +25,8 @@ export type Tool = { export namespace Tool { export type Any = Tool; + + export type Choice = "auto" | "any" | (string & {}); } export type ToolHandlerArgs = { @@ -34,30 +35,6 @@ export type ToolHandlerArgs = { step: GetStepTools; }; -export interface BaseLifecycleArgs { - // Agent is the agent that made the call. - agent: Agent; - // Network represents the network that this agent or lifecycle belongs to. 
- network?: Network; -} - -export interface ResultLifecycleArgs extends BaseLifecycleArgs { - result: InferenceResult; -} - -export interface BeforeLifecycleArgs extends BaseLifecycleArgs { - // input is the user request for the entire agentic operation. - input?: string; - - // prompt is the system, user, and any assistant prompt as generated - // by the Agent. This does not include any past history. - prompt: InternalNetworkMessage[]; - - // history is the past history as generated via State. Ths will be added - // after the prompt to form a single conversation log. - history?: InternalNetworkMessage[]; -} - /** * Represents step tooling from an Inngest client, purposefully genericized to * allow for more flexible usage.
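
Taken together, the reworked Router type lets orchestration be written as plain code. A minimal sketch, assuming planner, editor, and model were created as in the swebench example above:

  import { createNetwork } from "./index";

  const network = createNetwork({
    agents: [planner, editor],
    defaultModel: model,
  });

  // Run the planner until a plan exists in state, then hand off to the
  // editor, and stop once the "done" flag is set.
  await network.run("Fix the reported bug", ({ network }) => {
    if (network.state.kv.get("done")) {
      return undefined; // stop the network
    }
    return network.state.kv.get("plan") === undefined ? planner : editor;
  });
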