diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index d608f94f1..369e2c2f1 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -16,6 +16,7 @@ TypeVar, Union, cast, + get_type_hints, ) from langchain_core.runnables import Runnable, RunnableConfig @@ -289,6 +290,8 @@ def _update_as_tuples(self) -> Sequence[tuple[str, Any]]: for t in self.update ): return self.update + elif hints := get_type_hints(type(self.update)): + return [(k, getattr(self.update, k)) for k in hints] elif self.update is not None: return [("__root__", self.update)] else: diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index a45e92e4e..eff8ab25e 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -9,6 +9,7 @@ from collections import Counter, deque from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager +from dataclasses import dataclass from random import randrange from typing import ( Annotated, @@ -5134,6 +5135,35 @@ def my_node(state: State): assert graph.invoke({"foo": ""}) == {"foo": "ab"} +def test_command_pydantic_dataclass() -> None: + from pydantic import BaseModel + + class PydanticState(BaseModel): + foo: str + + @dataclass + class DataclassState: + foo: str + + for State in (PydanticState, DataclassState): + + def node_a(state) -> Command[Literal["node_b"]]: + return Command( + update=State(foo="foo"), + goto="node_b", + ) + + def node_b(state): + return {"foo": state.foo + "bar"} + + builder = StateGraph(State) + builder.add_edge(START, "node_a") + builder.add_node(node_a) + builder.add_node(node_b) + graph = builder.compile() + assert graph.invoke(State(foo="")) == {"foo": "foobar"} + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_command_with_static_breakpoints( request: pytest.FixtureRequest, checkpointer_name: str