-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathcve_checklist_node.py
158 lines (124 loc) · 6.64 KB
/
cve_checklist_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import logging
from morpheus_llm.llm import LLMLambdaNode
from morpheus_llm.llm import LLMNode
from morpheus_llm.llm.nodes.llm_generate_node import LLMGenerateNode
from morpheus_llm.llm.nodes.prompt_template_node import PromptTemplateNode
from morpheus_llm.llm.services.llm_service import LLMClient
from ..utils.prompting import MOD_FEW_SHOT
from ..utils.prompting import additional_intel_prompting
from ..utils.prompting import get_mod_examples
from ..utils.string_utils import attempt_fix_list_string
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Fallback checklist prompt: the MOD_FEW_SHOT template with its `examples` field filled in from get_mod_examples().
# CVEChecklistNode uses this when it is constructed with an empty/falsy `prompt`.
DEFAULT_CHECKLIST_PROMPT = MOD_FEW_SHOT.format(examples=get_mod_examples())
# Template for the optional second LLM pass that asks the model to rewrite a numbered checklist as a Python list
# literal. It is handed to PromptTemplateNode with template_format="jinja", so `{{template}}` is the placeholder
# that gets substituted at render time (the string itself is never passed through str.format here).
cve_prompt2 = """Parse the following numbered checklist into a python list in the format ["x", "y", "z"], a comma separated list surrounded by square braces: {{template}}"""
async def _parse_list(text: list[str]) -> list[list[str]]:
    """
    Asynchronously parse a list of strings, each representing a list, into a list of lists.

    Parameters
    ----------
    text : list of str
        A list of strings, each expected to contain a Python-style list literal. Text before the first ``[`` and
        after the last ``]`` is discarded before parsing.

    Returns
    -------
    list of lists of str
        One parsed list per input string, in input order. Any item that is itself a single-element list is
        replaced by its only element.

    Raises
    ------
    ValueError
        If a string cannot be parsed as a Python literal, or if the parsed object is not a list. The message
        contains the index of the offending input, and the underlying error is chained as ``__cause__``.

    Notes
    -----
    The bracketed portion of each string is cleaned up (newlines removed, backslashes escaped, then passed through
    `attempt_fix_list_string` to try to repair unescaped quotes) and parsed with `ast.literal_eval`, which only
    evaluates literals and never executes arbitrary code.
    """
    return_val = []

    for checklist_num, candidate in enumerate(text):
        try:
            # Remove any text not enclosed by square brackets
            candidate = candidate[candidate.find('['):candidate.rfind(']') + 1]

            # Remove newline characters that can cause incorrect string escaping in the next step
            candidate = candidate.replace("\n", "")

            # Ensure backslashes are escaped
            candidate = candidate.replace("\\", "\\\\")

            # Try to do some very basic string cleanup to fix unescaped quotes
            candidate = attempt_fix_list_string(candidate)

            # Only proceed if the input is a valid Python literal
            # This isn't really dangerous, literal_eval only evaluates a small subset of python
            current = ast.literal_eval(candidate)

            # Ensure that the parsed data is a list
            if not isinstance(current, list):
                raise ValueError(f"Input is not a list: {candidate}")

            # Unwrap items the model emitted as single-element lists, e.g. [["a"], "b"] -> ["a", "b"]
            return_val.append(
                [item[0] if isinstance(item, list) and len(item) == 1 else item for item in current])

        except (ValueError, SyntaxError, TypeError) as e:
            # literal_eval can also raise TypeError for malformed input (e.g. an unhashable dict key), so it is
            # wrapped too. Chain the original exception so the root cause stays visible in the traceback.
            raise ValueError(
                f"Failed to parse input for checklist number {checklist_num}: {candidate}. Error: {e}") from e

    return return_val
class CVEChecklistNode(LLMNode):
    """
    A node that orchestrates the process of generating a checklist for CVE (Common Vulnerabilities and Exposures) items.

    It wires together a prompt template node, an LLM generation node, optionally a second prompt/generation pair
    that asks the LLM to rewrite the checklist as a Python list, and a final output node that parses the generated
    text with `_parse_list`.
    """

    def __init__(self, *, prompt: str, llm_client: LLMClient, enable_llm_list_parsing: bool = False):
        """
        Build the internal node graph used to generate and parse a CVE checklist.

        Parameters
        ----------
        prompt : str
            Jinja template text used as the base of the checklist-generation prompt. If empty (or otherwise
            falsy), `DEFAULT_CHECKLIST_PROMPT` is used instead. Additional exploitability guidance is always
            appended to it.
        llm_client : LLMClient
            Client handed to every `LLMGenerateNode` created by this node.
        enable_llm_list_parsing : bool, optional
            When True, the first LLM response is fed through a second prompt (`cve_prompt2`) and a second LLM call
            before being parsed; when False the first response is parsed directly. By default False.
        """
        super().__init__()

        # Fall back to the module default when no prompt was supplied
        prompt = prompt or DEFAULT_CHECKLIST_PROMPT

        # Extra guidance appended to every checklist prompt, regardless of which base prompt is used
        intel = (
            additional_intel_prompting +
            "\n\nIf a vulnerable function or method is mentioned in the CVE description, ensure the first checklist item verifies whether this function or method is being called from the code or used by the code."
            "\nThe vulnerable version of the vulnerable package is already verified to be installed within the container. Check only the other factors that affect exploitability, no need to verify version again."
        )

        cve_prompt1 = prompt + intel

        # Add a node to create a prompt for CVE checklist generation; ("*", "*") forwards all of this node's
        # inputs to the template under their own names
        self.add_node("checklist_prompt",
                      inputs=[("*", "*")],
                      node=PromptTemplateNode(template=cve_prompt1, template_format="jinja"))

        # First LLM call: generate the checklist from the rendered prompt
        self.add_node("chat1", inputs=["/checklist_prompt"], node=LLMGenerateNode(llm_client=llm_client))

        # By default the parser consumes the first LLM response directly
        checklist_prompts = ["/chat1"]

        if enable_llm_list_parsing:
            # Add a node to wrap the generated response in a secondary prompt asking for a Python list
            self.add_node("parse_checklist_prompt",
                          inputs=["/chat1"],
                          node=PromptTemplateNode(template=cve_prompt2, template_format="jinja"))

            # Second LLM call: generate the reformatted checklist from the secondary prompt
            self.add_node("chat2",
                          inputs=[("/parse_checklist_prompt", "prompt")],
                          node=LLMGenerateNode(llm_client=llm_client))

            checklist_prompts = ["/chat2"]

        # Add an output parser node to process the final generated checklist into a structured list
        self.add_node("output_parser", inputs=checklist_prompts, node=LLMLambdaNode(_parse_list), is_output=True)