Skip to content

Commit

Permalink
feat: split
Browse files Browse the repository at this point in the history
  • Loading branch information
KidsXH committed Feb 2, 2024
1 parent 2c05100 commit 73d3ac6
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 36 deletions.
18 changes: 12 additions & 6 deletions src/components/VisView/outline.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ import {
setFocusChatID,
setMainChannelID,
} from "@/store/chatSlice";
import { clickNode } from "@/store/selectionSlice";
import {
clickNode,
updateSelectedCodeRangeOnTree,
} from "@/store/selectionSlice";
import React from "react";
import MenuItem from "@mui/material/MenuItem";
import { StyledMenu } from "./styleMenu";
import { setCommand } from "@/store/modelSlice";

export type TreeNode = {
requestID: number[];
Expand Down Expand Up @@ -64,11 +68,13 @@ const OutlineView = () => {
if (parentID === undefined) {
return;
}
dispatch(
setFocusChatID(
treeNodes[parentID].requestID[treeNodes[parentID].requestID.length - 1],
),
);
const parentRequestID =
treeNodes[parentID].requestID[treeNodes[parentID].requestID.length - 1];
const codeRange = node.label.split("-").map((n) => Number(n));
dispatch(setMainChannelID(requestPool[parentRequestID].channelID));
dispatch(setFocusChatID(parentRequestID));
dispatch(updateSelectedCodeRangeOnTree([codeRange[0], codeRange[1]]));
dispatch(setCommand("next-split"));
handleClose();
};
const handleTrim = (treeNodeID: number) => {};
Expand Down
46 changes: 32 additions & 14 deletions src/hooks/useChatHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { useAppDispatch, useAppSelector } from "@/hooks/redux";
import {
addChat,
changeChannelStatus,
ChannelStatus,
ChannelStatus, selectFocusChatID,
selectMainChannelChats,
selectMainChannelID,
selectNumChats,
Expand Down Expand Up @@ -41,6 +41,7 @@ export const useChatHistory = () => {
const numNodes = useAppSelector(selectNumNodes);
const numHandledRequests = useAppSelector(selectNumHandledRequests);
const mainChannelID = useAppSelector(selectMainChannelID);
const focusChatID = useAppSelector(selectFocusChatID);

useEffect(() => {
if (requestPool.length > numChats) {
Expand Down Expand Up @@ -90,20 +91,37 @@ export const useChatHistory = () => {
dispatch(updateCodeRange(sourceCode));
}, [sourceCode, nodePool, numNodes, numRequests, dispatch]);

const numRuns = useAppSelector(selectNumRuns);
useEffect(() => {
if (numRuns > 0 || nodePool.length === 0) return;
let latestNode = nodePool[0];
// const numRuns = useAppSelector(selectNumRuns);
// useEffect(() => {
// if (numRuns > 0 || nodePool.length === 0) return;
//
//
// let latestNode = nodePool[0];
//
// nodePool.forEach((node) => {
// const requestID = node.id;
// if (requestPool[requestID].channelID === mainChannelID) {
// latestNode = node;
// }
// });
//
// console.log("[useChatHis] Update FocusChatID", mainChannelID, nodePool, latestNode)
//
// if (latestNode) dispatch(setFocusChatID(latestNode.id));
// }, [numNodes, nodePool, numRuns, dispatch, requestPool, mainChannelID]);

nodePool.forEach((node) => {
const requestID = node.id;
if (requestPool[requestID].channelID === mainChannelID) {
latestNode = node;
useEffect(() => {
// When the main channel is changed, update the focus chat ID
if (requestPool[focusChatID]?.channelID !== mainChannelID) {
const nodes = nodePool.filter(
(node) => requestPool[node.id].channelID === mainChannelID,
);
if (nodes.length > 0) {
const lastNode = nodes[nodes.length - 1];
dispatch(setFocusChatID(lastNode.id));
}
});

if (latestNode) dispatch(setFocusChatID(latestNode.id));
}, [numNodes, nodePool, numRuns, dispatch, requestPool, mainChannelID]);
}
}, [mainChannelID]);
};

export const saveRequestMessages = (
Expand Down Expand Up @@ -162,7 +180,7 @@ const chat2node = (
assistant.tool_calls ? assistant.tool_calls[0].function.arguments : "{}",
);

console.log("[useChatHis]functionArgs", functionArgs);
// console.log("[useChatHis]functionArgs", functionArgs);

const text =
functionName === "writeExplanation"
Expand Down
87 changes: 73 additions & 14 deletions src/hooks/usePlannerCommands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import { useEffect, useMemo } from "react";
import { usePlannerContext } from "@/providers/Planner";
import {
selectNodePool,
selectNumNodes,
selectNumRequests,
selectRequestPool,
} from "@/store/nodeSlice";
import { saveRequestMessages } from "@/hooks/useChatHistory";
Expand All @@ -27,10 +25,13 @@ import {
selectFocusChatID,
selectMainChannelChats,
selectMainChannelID,
selectNumChannels,
selectNumChannels, setFocusChatID,
setMainChannelID,
} from "@/store/chatSlice";
import { selectSelectedCodeRange } from "@/store/selectionSlice";
import {
selectSelectedCodeRange,
selectSelectedCodeRangeOnTree,
} from "@/store/selectionSlice";

const usePlannerCommands = () => {
const dispatch = useAppDispatch();
Expand All @@ -46,14 +47,15 @@ const usePlannerCommands = () => {
const focusChatID = useAppSelector(selectFocusChatID);
const mainChannelChats = useAppSelector(selectMainChannelChats);
const selectedCodeRange = useAppSelector(selectSelectedCodeRange);
const selectedCodeRangeOnTree = useAppSelector(selectSelectedCodeRangeOnTree);

const requestMemory = useMemo(() => {
let requests = requestPool;
if (focusChatID !== -1) {
requests = requests.slice(0, focusChatID + 2);
}
return requests.filter((request) => request.channelID === mainChannelID);
}, [focusChatID]);
}, [focusChatID, mainChannelID, requestPool]);

const tutorialMemory = useMemo(() => {
let nodes = nodePool;
Expand All @@ -66,25 +68,25 @@ const usePlannerCommands = () => {
return nodes.filter(
(node) => requestPool[node.id].channelID === mainChannelID,
);
}, [focusChatID]);
}, [focusChatID, mainChannelID, nodePool, requestPool]);

const [planners] = usePlannerContext();

const continue2next = (numThoughts?: number) => {
console.log('[Memory]', requestMemory, tutorialMemory);
numThoughts = numThoughts || 3;
dispatch(setNumRuns(numThoughts));
dispatch(clearActiveChannels());
const isLastChatNode =
focusChatID === -1 ||
focusChatID === mainChannelChats[mainChannelChats.length - 2];

let newChannelID = numChannels;

for (let i = 0; i < numThoughts; i++) {
const planner = planners[i];
let channel = numChannels + i - 1;
let channel = mainChannelID;

if (i === 0) {
// channel = mainChannelID;
if (isLastChatNode) channel = mainChannelID;
else channel = numChannels + numThoughts - 1;
if (i > 0) {
channel = newChannelID;
newChannelID += 1;
}

planner.initialize(sourceCode, channel);
Expand Down Expand Up @@ -171,6 +173,48 @@ const usePlannerCommands = () => {
}
};

const nextSplit = () => {
console.log('[Memory]', requestMemory, tutorialMemory);
const numThoughts = 1;
dispatch(setNumRuns(numThoughts));
dispatch(clearActiveChannels());
const planner = planners[0];
const channel = numChannels;
planner.initialize(sourceCode, channel);
planner.setMemory(
requestMemory.map((request) => request.request),
tutorialMemory.map((node) => node.action),
);
const planPrompt = planner.planPrompt4Split(
sourceCode,
selectedCodeRangeOnTree,
);
dispatch(
activateChannel({
channelID: channel,
isActive: true,
isDone: false,
lastChatNodeID: -1,
}),
);
planner
.nextWithPlan(planPrompt)
.then((res) => {
const { hasNext, id } = res;
if (hasNext) {
saveRequestMessages(planners[id], requestPool, dispatch);
} else {
dispatch(deactivateChannel(planners[id].channel));
}
dispatch(decreaseNumRuns());
})
.catch((err) => {
console.log("[Planner Error]", err);
dispatch(decreaseNumRuns());
dispatch(setCommand("pause"));
});
};

const vote = (channels: ChannelStatus[]) => {
const voteMap: Map<string, number[]> = new Map();
let maxVotes = 0;
Expand Down Expand Up @@ -213,10 +257,17 @@ const usePlannerCommands = () => {
const bestResult = vote(
activeChannels.filter((channel) => channel.isDone),
);
const bestIdx = activeChannels.findIndex(
(channel) => channel.channelID === bestResult,
);
if (bestResult === undefined) {
dispatch(setCommand("pause"));
} else {
const newChatNode = nodePool.find(
(node) => node.id === activeChannels[bestIdx].lastChatNodeID,
);
dispatch(setMainChannelID(bestResult));
dispatch(setFocusChatID(newChatNode?.id || -1));
dispatch(setRunningState("waited"));
dispatch(setCommand("continue-next"));
}
Expand Down Expand Up @@ -258,6 +309,14 @@ const usePlannerCommands = () => {
}
}

if (command === "next-split") {
if (runningState === "paused") {
dispatch(setRunningState("running"));
nextSplit();
dispatch(setCommand("none"));
}
}

if (command === "continue-next") {
if (runningState === "waited") {
dispatch(setRunningState("running"));
Expand Down
9 changes: 8 additions & 1 deletion src/models/agents/planner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ export class Planner {
}

this.llm.chatMessages.push(responseMessage);
console.log("[Planner] responseMessage", responseMessage);
// console.log("[Planner] responseMessage", responseMessage);

if (responseMessage.tool_calls) {
const functionName = responseMessage.tool_calls[0].function.name || "";
Expand Down Expand Up @@ -141,6 +141,13 @@ export class Planner {
const codeSnippet = codeLines.slice(start - 1, end).join("\n");
return `You are supposed to explain the code \`\`\`${codeSnippet}\`\`\` in the next step. Please write the observation, thought, and action for the next step.`;
}

planPrompt4Split(code: string, lineNumber: [number, number]) {
const [start, end] = lineNumber;
const codeLines = code.split("\n");
const codeSnippet = codeLines.slice(start - 1, end).join("\n");
return `You are supposed to explain the code \`\`\`${codeSnippet}\`\`\` in the next multiple steps. Please write the observation, thought, and action for the next steps.`;
}
}

export const parseMessage = (
Expand Down
1 change: 1 addition & 0 deletions src/store/modelSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ interface modelState {
| "continue"
| "continue-next"
| "next-plan"
| "next-split"
| "finish"
| "none";
modelName: string;
Expand Down
13 changes: 12 additions & 1 deletion src/store/selectionSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ interface SelectionState {
chainNodes: number[]; // The id of the nodes in the chain
clickNodeTrigger: boolean;
selectedCodeRange: [number, number];
selectedCodeRangeOnTree: [number, number];
}

const initialState: SelectionState = {
chainNodes: [],
clickNodeTrigger: false,
selectedCodeRange: [0, 0],
selectedCodeRangeOnTree: [0, 0],
};

export const selectionSlice = createSlice({
Expand All @@ -30,10 +32,16 @@ export const selectionSlice = createSlice({
) => {
state.selectedCodeRange = [...action.payload];
},
updateSelectedCodeRangeOnTree: (
state,
action: PayloadAction<[number, number]>,
) => {
state.selectedCodeRangeOnTree = [...action.payload];
},
},
});

export const { pickChain, clickNode, updateSelectedCodeRange } =
export const { pickChain, clickNode, updateSelectedCodeRange, updateSelectedCodeRangeOnTree } =
selectionSlice.actions;

export const selectChainNodes = (state: RootState) =>
Expand All @@ -45,4 +53,7 @@ export const selectClickNodeTrigger = (state: RootState) =>
export const selectSelectedCodeRange = (state: RootState) =>
state.selection.selectedCodeRange;

export const selectSelectedCodeRangeOnTree = (state: RootState) =>
state.selection.selectedCodeRangeOnTree;

export default selectionSlice.reducer;

0 comments on commit 73d3ac6

Please sign in to comment.