Skip to content

Commit

Permalink
langgraph: add support for BaseModel updates to Command (#2747)
Browse files Browse the repository at this point in the history
Simple update that adds support for the `update` attribute of the
`Command` class to support Pydantic `BaseModel` type.

LangGraph already supports [Pydantic models for graph
states](https://langchain-ai.github.io/langgraph/how-tos/state-model/).

Extending support to the `update` attribute allows users to pass custom
BaseModel instances. Additionally, updates defined as `BaseModel` types
are type-validated when created.

#2804

---------

Co-authored-by: vbarda <[email protected]>
  • Loading branch information
larsenweigle and vbarda authored Jan 23, 2025
1 parent 38bbe67 commit 3955225
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
3 changes: 3 additions & 0 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TypeVar,
Union,
cast,
get_type_hints,
)

from langchain_core.runnables import Runnable, RunnableConfig
Expand Down Expand Up @@ -289,6 +290,8 @@ def _update_as_tuples(self) -> Sequence[tuple[str, Any]]:
for t in self.update
):
return self.update
elif hints := get_type_hints(type(self.update)):
return [(k, getattr(self.update, k)) for k in hints]
elif self.update is not None:
return [("__root__", self.update)]
else:
Expand Down
30 changes: 30 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import Counter, deque
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from random import randrange
from typing import (
Annotated,
Expand Down Expand Up @@ -5134,6 +5135,35 @@ def my_node(state: State):
assert graph.invoke({"foo": ""}) == {"foo": "ab"}


def test_command_pydantic_dataclass() -> None:
from pydantic import BaseModel

class PydanticState(BaseModel):
foo: str

@dataclass
class DataclassState:
foo: str

for State in (PydanticState, DataclassState):

def node_a(state) -> Command[Literal["node_b"]]:
return Command(
update=State(foo="foo"),
goto="node_b",
)

def node_b(state):
return {"foo": state.foo + "bar"}

builder = StateGraph(State)
builder.add_edge(START, "node_a")
builder.add_node(node_a)
builder.add_node(node_b)
graph = builder.compile()
assert graph.invoke(State(foo="")) == {"foo": "foobar"}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_command_with_static_breakpoints(
request: pytest.FixtureRequest, checkpointer_name: str
Expand Down

0 comments on commit 3955225

Please sign in to comment.