Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(messages): reuse the first valid message ID for subsequent chunks #798

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions libs/langgraph/src/pregel/messages.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { v4 } from "uuid";
import {
BaseCallbackHandler,
HandleLLMNewTokenCallbackFields,
Expand Down Expand Up @@ -44,28 +43,42 @@ export class StreamMessagesHandler extends BaseCallbackHandler {

emittedChatModelRunIds: Record<string, boolean> = {};

stableMessageIdMap: Record<string, string> = {};

lc_prefer_streaming = true;

constructor(streamFn: (streamChunk: StreamChunk) => void) {
super();
this.streamFn = streamFn;
}

_emit(meta: Meta, message: BaseMessage, dedupe = false) {
_emit(meta: Meta, message: BaseMessage, runId: string, dedupe = false) {
if (
dedupe &&
message.id !== undefined &&
this.seen[message.id] !== undefined
) {
return;
}
if (message.id === undefined) {
const id = v4();

// For instance in ChatAnthropic, the first chunk has an message ID
// but the subsequent chunks do not. To avoid clients seeing two messages
// we rename the message ID if it's being auto-set to `run-${runId}`
// (see https://github.com/langchain-ai/langchainjs/pull/6646).
let messageId = message.id;
if (messageId == null || messageId === `run-${runId}`) {
messageId = this.stableMessageIdMap[runId] ?? messageId ?? `run-${runId}`;
}
this.stableMessageIdMap[runId] ??= messageId;

if (messageId !== message.id) {
// eslint-disable-next-line no-param-reassign
message.id = id;
message.id = messageId;

// eslint-disable-next-line no-param-reassign
message.lc_kwargs.id = id;
message.lc_kwargs.id = messageId;
}

this.seen[message.id!] = message;
this.streamFn([meta[0], "messages", [message, meta[1]]]);
}
Expand Down Expand Up @@ -104,13 +117,12 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
this.emittedChatModelRunIds[runId] = true;
if (this.metadatas[runId] !== undefined) {
if (isChatGenerationChunk(chunk)) {
this._emit(this.metadatas[runId], chunk.message);
this._emit(this.metadatas[runId], chunk.message, runId);
} else {
this._emit(
this.metadatas[runId],
new AIMessageChunk({
content: token,
})
new AIMessageChunk({ content: token }),
runId
);
}
}
Expand All @@ -121,11 +133,12 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
if (!this.emittedChatModelRunIds[runId]) {
const chatGeneration = output.generations?.[0]?.[0] as ChatGeneration;
if (isBaseMessage(chatGeneration?.message)) {
this._emit(this.metadatas[runId], chatGeneration?.message, true);
this._emit(this.metadatas[runId], chatGeneration?.message, runId, true);
}
delete this.emittedChatModelRunIds[runId];
}
delete this.metadatas[runId];
delete this.stableMessageIdMap[runId];
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down Expand Up @@ -160,21 +173,21 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
delete this.metadatas[runId];
if (metadata !== undefined) {
if (isBaseMessage(outputs)) {
this._emit(metadata, outputs, true);
this._emit(metadata, outputs, runId, true);
} else if (Array.isArray(outputs)) {
for (const value of outputs) {
if (isBaseMessage(value)) {
this._emit(metadata, value, true);
this._emit(metadata, value, runId, true);
}
}
} else if (outputs != null && typeof outputs === "object") {
for (const value of Object.values(outputs)) {
if (isBaseMessage(value)) {
this._emit(metadata, value, true);
this._emit(metadata, value, runId, true);
} else if (Array.isArray(value)) {
for (const item of value) {
if (isBaseMessage(item)) {
this._emit(metadata, item, true);
this._emit(metadata, item, runId, true);
}
}
}
Expand Down
70 changes: 31 additions & 39 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8564,24 +8564,6 @@ graph TD;
tags: ["c_two_chat_model"],
},
],
[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

were these being emitted by mistake?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so? @jacoblee93 can clarify, but it doesn't make sense to emit the message twice

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool yea def don't want to emit things twice, can you point me to the line that was causing this?

Copy link
Contributor Author

@dqbd dqbd Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new _AnyIdAIMessageChunk({
content: "2",
}),
{
langgraph_step: 2,
langgraph_node: "c_two",
langgraph_triggers: ["c_one"],
langgraph_path: [PULL, "c_two"],
langgraph_checkpoint_ns: expect.stringMatching(/^p_two:.*\|c_two:.*/),
__pregel_resuming: false,
__pregel_task_id: expect.any(String),
checkpoint_ns: expect.stringMatching(/^p_two:/),
name: "c_two",
tags: ["graph:step:2"],
ls_stop: undefined,
},
],
[
new _AnyIdAIMessageChunk({
content: "x",
Expand Down Expand Up @@ -8737,27 +8719,6 @@ graph TD;
},
],
],
[
"messages",
[
new _AnyIdAIMessageChunk({
content: "2",
}),
{
langgraph_step: 2,
langgraph_node: "c_two",
langgraph_triggers: ["c_one"],
langgraph_path: [PULL, "c_two"],
langgraph_checkpoint_ns:
expect.stringMatching(/^p_two:.*\|c_two:.*/),
__pregel_resuming: false,
__pregel_task_id: expect.any(String),
checkpoint_ns: expect.stringMatching(/^p_two:/),
tags: ["graph:step:2"],
name: "c_two",
},
],
],
[
"messages",
[
Expand Down Expand Up @@ -9470,6 +9431,37 @@ graph TD;
expect(oneCount).toEqual(1);
expect(twoCount).toEqual(0);
});

it.each(["omit", "first-only", "always"] as const)(
"`messages` inherits message ID - %p",
async (streamMessageId) => {
const checkpointer = await createCheckpointer();

const graph = new StateGraph(MessagesAnnotation)
.addNode("one", async () => {
const model = new FakeChatModel({
responses: [new AIMessage({ id: "123", content: "Output" })],
streamMessageId,
});

const invoke = await model.invoke([new HumanMessage("Input")]);
return { messages: invoke };
})
.addEdge(START, "one")
.compile({ checkpointer });

const messages = await gatherIterator(
graph.stream(
{ messages: [] },
{ configurable: { thread_id: "1" }, streamMode: "messages" }
)
);

const messageIds = [...new Set(messages.map(([m]) => m.id))];
expect(messageIds).toHaveLength(1);
if (streamMessageId !== "omit") expect(messageIds[0]).toBe("123");
}
);
}

runPregelTests(() => new MemorySaverAssertImmutable());
45 changes: 37 additions & 8 deletions libs/langgraph/src/tests/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/* eslint-disable import/no-extraneous-dependencies */
import assert from "node:assert";
import { expect } from "@jest/globals";
import { v4 as uuidv4 } from "uuid";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
BaseChatModel,
Expand Down Expand Up @@ -44,9 +45,16 @@ export class FakeChatModel extends BaseChatModel {

callCount = 0;

constructor(fields: FakeChatModelArgs) {
streamMessageId: "omit" | "first-only" | "always";

constructor(
fields: FakeChatModelArgs & {
streamMessageId?: "omit" | "first-only" | "always";
}
) {
super(fields);
this.responses = fields.responses;
this.streamMessageId = fields.streamMessageId ?? "omit";
}

_combineLLMOutput() {
Expand Down Expand Up @@ -91,14 +99,35 @@ export class FakeChatModel extends BaseChatModel {
runManager?: CallbackManagerForLLMRun
) {
const response = this.responses[this.callCount % this.responses.length];
for (const text of (response.content as string).split("")) {
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: text as string,
}),
text,

let isFirstChunk = true;
const completionId = response.id ?? uuidv4();

for (const content of (response.content as string).split("")) {
let id: string | undefined;
if (
this.streamMessageId === "always" ||
(this.streamMessageId === "first-only" && isFirstChunk)
) {
id = completionId;
}

const chunk = new ChatGenerationChunk({
message: new AIMessageChunk({ content, id }),
text: content,
});
await runManager?.handleLLMNewToken(text as string);

yield chunk;
await runManager?.handleLLMNewToken(
content,
undefined,
undefined,
undefined,
undefined,
{ chunk }
);

isFirstChunk = false;
}
this.callCount += 1;
}
Expand Down
Loading