diff --git a/torch_geometric/nn/nlp/txt2kg.py b/torch_geometric/nn/nlp/txt2kg.py index 0e5227392925..1096ffc4bdad 100644 --- a/torch_geometric/nn/nlp/txt2kg.py +++ b/torch_geometric/nn/nlp/txt2kg.py @@ -9,7 +9,7 @@ CLIENT = None GLOBAL_NIM_KEY = "" -SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Seperate each with a new line. Do not output anything else. Try to focus on key triples that form a connected graph.”" +SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Seperate each with a new line. Do not output anything else. Try to focus on key triples that form a connected graph.”" #noqa class TXT2KG(): @@ -30,13 +30,16 @@ class TXT2KG(): Args: NVIDIA_NIM_MODEL : str, optional - The name of the NVIDIA NIM model to use (default: "nvidia/llama-3.1-nemotron-70b-instruct"). + The name of the NVIDIA NIM model to use. + (default: "nvidia/llama-3.1-nemotron-70b-instruct"). NVIDIA_API_KEY : str, optional The API key for accessing NVIDIA's NIM models (default: ""). local_LM : bool, optional - A flag indicating whether a local Language Model (LM) should be used (default: False). + A flag indicating whether a local Language Model (LM) + should be used (default: False). chunk_size : int, optional - The size of the chunks in which the text data is processed (default: 512). + The size of the chunks in which the text data is processed + (default: 512). """ def __init__( self, @@ -89,7 +92,7 @@ def _chunk_to_triples_str_local(self, txt: str) -> str: # for debug self.total_chars_parsed += len(txt) self.time_to_parse += round(time.time() - chunk_start_time, 2) - self.avg_chars_parsed_per_sec = self.total_chars_parsed / self.time_to_parse + self.avg_chars_parsed_per_sec = self.total_chars_parsed / self.time_to_parse #noqa return out_str def add_doc_2_KG( @@ -97,7 +100,7 @@ def add_doc_2_KG( txt: str, QA_pair: Optional[Tuple[str, str]] = None, ) -> None: - """Add a document to the Knowledge Graph (KG) by extracting triples from the text. + """Add a document to the Knowledge Graph (KG). Args: txt (str): The text to extract triples from. @@ -109,8 +112,8 @@ def add_doc_2_KG( - None """ # Ensure NVIDIA_API_KEY is set before proceeding - assert self.NVIDIA_API_KEY != '', "Please init TXT2KG w/ NVIDIA_API_KEY or set local_lm flag to True" - + assert self.NVIDIA_API_KEY != '', \ + "Please init TXT2KG w/ NVIDIA_API_KEY or set local_lm=True" if QA_pair: # QA_pairs should be unique keys, so check if it already exists in the KG assert QA_pair not in self.relevant_triples.keys()