From 28266fa8a58119ff77dbb616439576a2324b303d Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 10 Jan 2025 17:16:51 +0100 Subject: [PATCH] Fix aconfig_with_context asyncio.Event is not thread-safe so it must be created in the asyncio thread --- .../langchain_core/beta/runnables/context.py | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 398ad488b8942..c6e26df2e3d8d 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -65,20 +65,11 @@ def _key_from_id(id_: str) -> str: def _config_with_context( config: RunnableConfig, - steps: list[Runnable], + context_specs: list[tuple[ConfigurableFieldSpec, int]], setter: Callable, getter: Callable, event_cls: Union[type[threading.Event], type[asyncio.Event]], ) -> RunnableConfig: - if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): - return config - - context_specs = [ - (spec, i) - for i, step in enumerate(steps) - for spec in step.config_specs - if spec.id.startswith(CONTEXT_CONFIG_PREFIX) - ] grouped_by_key = { key: list(group) for key, group in groupby( @@ -134,8 +125,17 @@ async def aconfig_with_context( Returns: The patched runnable config. """ - return await asyncio.to_thread( - _config_with_context, config, steps, _asetter, _agetter, asyncio.Event + if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): + return config + + context_specs = [ + (spec, i) + for i, step in enumerate(steps) + for spec in await asyncio.to_thread(getattr, step, "config_specs") + if spec.id.startswith(CONTEXT_CONFIG_PREFIX) + ] + return _config_with_context( + config, context_specs, _asetter, _agetter, asyncio.Event ) @@ -152,7 +152,18 @@ def config_with_context( Returns: The patched runnable config. """ - return _config_with_context(config, steps, _setter, _getter, threading.Event) + if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): + return config + + context_specs = [ + (spec, i) + for i, step in enumerate(steps) + for spec in step.config_specs + if spec.id.startswith(CONTEXT_CONFIG_PREFIX) + ] + return _config_with_context( + config, context_specs, _setter, _getter, threading.Event + ) @beta()