-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathcve_langchain_agent_node.py
184 lines (140 loc) · 7.35 KB
/
cve_langchain_agent_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import typing
from langchain_core.exceptions import OutputParserException
from morpheus_llm.llm import LLMContext
from morpheus_llm.llm import LLMNodeBase
from ..data_models.output import AgentIntermediateStep
from ..utils.data_utils import to_json
logger = logging.getLogger(__name__)

# Imported only for static type checking (the annotations below refer to it as a string), so langchain.agents is
# not imported at runtime.
if typing.TYPE_CHECKING:
    from langchain.agents import AgentExecutor
class CVELangChainAgentNode(LLMNodeBase):
    """
    LLM node that runs a LangChain agent executor over the node's inputs.

    When every input value is a list, one agent invocation is made per element. This is applied recursively, so nested
    lists (e.g. one list of questions per row) are supported, and the invocations are awaited concurrently. A failed
    invocation is either replaced with a placeholder or re-raised, depending on `replace_exceptions`.

    Parameters
    ----------
    create_agent_executor_fn : typing.Callable[[LLMContext], AgentExecutor]
        Factory that builds the agent executor from the current `LLMContext`. Called once per `execute` call.
    replace_exceptions : bool, optional
        If True, a failed agent run is replaced by `replace_exceptions_value` (with `None` intermediate steps) and a
        warning is logged. If False (default), the exception is re-raised from `execute`.
    replace_exceptions_value : str, optional
        Placeholder output used for failed agent runs when `replace_exceptions` is True. Defaults to None.
    """

    def __init__(self,
                 *,
                 create_agent_executor_fn: "typing.Callable[[LLMContext], AgentExecutor]",
                 replace_exceptions: bool = False,
                 replace_exceptions_value: typing.Optional[str] = None):
        super().__init__()

        self._create_agent_executor_fn = create_agent_executor_fn
        self._replace_exceptions = replace_exceptions
        self._replace_exceptions_value = replace_exceptions_value

        self._input_names = ["input"]

    def get_input_names(self):
        """Return the names of the inputs this node reads from the context."""
        return self._input_names

    @staticmethod
    def _is_all_lists(data: typing.Optional[dict[str, typing.Any]]) -> bool:
        """
        Return True if `data` is a non-empty mapping whose values are all lists.

        `None` and empty dicts return False, so optional metadata can be passed in directly.
        """
        if not data:
            return False

        return all(isinstance(v, list) for v in data.values())

    @staticmethod
    def _transform_dict_of_lists(data: dict[str, typing.Any]) -> list[dict[str, typing.Any]]:
        """
        Transpose a dict of lists into a list of dicts.

        For example ``{"a": [1, 2], "b": [3, 4]}`` becomes ``[{"a": 1, "b": 3}, {"a": 2, "b": 4}]``. As with `zip`,
        the result is truncated to the length of the shortest list.
        """
        return [dict(zip(data, values)) for values in zip(*data.values())]

    @staticmethod
    def _parse_intermediate_step(step: tuple[typing.Any, typing.Any]) -> dict[str, typing.Any]:
        """
        Parse an agent intermediate step into an AgentIntermediateStep object. Return the dictionary representation for
        compatibility with cudf.

        Raises
        ------
        ValueError
            If `step` does not contain exactly two values (the action and its output).
        """
        if len(step) != 2:
            raise ValueError(f"Expected 2 values in each intermediate step but got {len(step)}.")

        action, output = step

        # Flatten tool output to a JSON string for compatibility with cudf
        output_json = to_json(output)

        return AgentIntermediateStep(tool_name=action.tool,
                                     action_log=action.log,
                                     tool_input=action.tool_input,
                                     tool_output=output_json).model_dump()

    def _postprocess_results(self, results: list[list[typing.Any]]) -> tuple[list[list[str]], list[list[list]]]:
        """
        Post-process results into lists of outputs and intermediate steps.

        Each entry of `results` is the list of agent results for one input row. Each result is either the agent's
        output dict or the exception caught while running it. The returned lists mirror that row/item shape.

        Exceptions are replaced with placeholder values if ``self._replace_exceptions`` is True. Otherwise they are
        re-raised, so a failed run can never shift the remaining outputs out of alignment with their inputs.

        Raises
        ------
        Exception
            The first exception found in `results` when ``self._replace_exceptions`` is False.
        """
        outputs: list[list] = []
        intermediate_steps: list[list] = []

        for i, answer_list in enumerate(results):
            row_outputs = []
            row_steps = []

            for j, answer in enumerate(answer_list):
                # Handle exceptions returned by the agent. OutputParserException is listed explicitly for
                # readability; it is expected to already derive from Exception.
                if isinstance(answer, (OutputParserException, Exception)):
                    if not self._replace_exceptions:
                        # Skipping the entry would misalign every later output with its input, so surface the error.
                        raise answer

                    # If the agent encounters a parsing error or a server error after retries, replace the error
                    # with default values to prevent the pipeline from crashing
                    row_outputs.append(self._replace_exceptions_value)
                    row_steps.append(None)
                    logger.warning(
                        "Exception encountered in result[%d][%d]: %s. "
                        "Replacing with default output: \"%s\" "
                        "and intermediate_steps: None",
                        i,
                        j,
                        answer,
                        self._replace_exceptions_value)
                    continue

                # For successful agent responses, extract the output, and intermediate steps if available
                row_outputs.append(answer["output"])

                # intermediate_steps availability depends on run_config.engine.agent.return_intermediate_steps
                if "intermediate_steps" in answer:
                    row_steps.append([self._parse_intermediate_step(step) for step in answer["intermediate_steps"]])
                else:
                    row_steps.append(None)

            outputs.append(row_outputs)
            intermediate_steps.append(row_steps)

        return outputs, intermediate_steps

    async def _run_single(self,
                          agent: "AgentExecutor",
                          metadata: typing.Optional[dict[str, typing.Any]] = None,
                          **kwargs) -> typing.Any:
        """
        Run the agent on `kwargs`, fanning out recursively when every value in `kwargs` is a list.

        Parameters
        ----------
        agent : AgentExecutor
            The agent executor to invoke.
        metadata : dict[str, typing.Any], optional
            Metadata placed in the run config. If every value is a list with the same length as the inputs, it is
            split element-wise alongside them. Otherwise the same metadata is passed to every invocation.
        **kwargs
            Agent inputs. The ``"input"`` value is sent as the agent input. Any remaining keyword arguments are
            forwarded to ``agent.ainvoke``.

        Returns
        -------
        typing.Any
            For scalar inputs, the agent's result or the exception caught while running it. For list inputs, a list
            of such results, nested to the same depth as the inputs.
        """
        if self._is_all_lists(kwargs):
            # Transform from dict[str, list[Any]] to list[dict[str, Any]]
            input_list = self._transform_dict_of_lists(kwargs)

            # If all metadata values are lists of the same length as the input list, split them the same way;
            # otherwise share the metadata with every element. `_is_all_lists` is False for None metadata, which
            # short-circuits the length check below.
            if self._is_all_lists(metadata) and all(len(v) == len(input_list) for v in metadata.values()):
                metadata_list = self._transform_dict_of_lists(metadata)
            else:
                metadata_list = [metadata] * len(input_list)

            # Recurse on each element and await them concurrently; failures are returned in place as exceptions
            results_async = [
                self._run_single(agent=agent, metadata=item_metadata, **item_inputs)
                for item_metadata, item_inputs in zip(metadata_list, input_list)
            ]
            return await asyncio.gather(*results_async, return_exceptions=True)

        # We are not dealing with a list, so run single
        try:
            input_single = {"input": kwargs.pop("input")}
            config = {"callbacks": agent.callbacks, "tags": agent.tags, "metadata": metadata}
            return await agent.ainvoke(input=input_single, config=config, **kwargs)
        except Exception as e:
            # Return (rather than raise) so one failure does not cancel sibling runs; handled in _postprocess_results
            logger.exception("Error running agent: %s", e)
            return e

    async def execute(self, context: LLMContext) -> LLMContext:
        """
        Build the agent for `context`, run it on the context inputs and store the results on the context.

        The context output is set to a dict with ``"outputs"`` and ``"intermediate_steps"``, each shaped as returned
        by `_postprocess_results`.

        Raises
        ------
        Exception
            An agent failure, when ``replace_exceptions`` is False.
        """
        input_dict: dict = context.get_inputs()  # type: ignore

        metadata = {}

        agent = self._create_agent_executor_fn(context)

        results = await self._run_single(agent=agent, metadata=metadata, **input_dict)

        outputs, intermediate_steps = self._postprocess_results(results)

        context.set_output({"outputs": outputs, "intermediate_steps": intermediate_steps})

        return context