From 5fb373b212e688e65ef39fab51f511089ce0c157 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 17 Dec 2023 16:52:03 +0200 Subject: [PATCH] Refactor AI handler instantiation to use lazy initialization in PR tools --- pr_agent/agent/pr_agent.py | 6 ++++-- pr_agent/tools/pr_add_docs.py | 5 +++-- pr_agent/tools/pr_code_suggestions.py | 5 +++-- pr_agent/tools/pr_description.py | 5 +++-- pr_agent/tools/pr_generate_labels.py | 5 +++-- pr_agent/tools/pr_information_from_user.py | 5 +++-- pr_agent/tools/pr_questions.py | 5 +++-- pr_agent/tools/pr_reviewer.py | 5 +++-- pr_agent/tools/pr_update_changelog.py | 5 +++-- 9 files changed, 28 insertions(+), 18 deletions(-) diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index a6c7cf5ec..3eb26841c 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,4 +1,6 @@ import shlex +from functools import partial + from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler @@ -41,8 +43,8 @@ commands = list(command2class.keys()) class PRAgent: - def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()): - self.ai_handler = ai_handler + def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + self.ai_handler = ai_handler # will be initialized in run_action async def handle_request(self, pr_url, request, notify=None) -> bool: # First, apply repo specific settings if exists diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index a729233d3..d13a829d4 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -1,5 +1,6 @@ import copy import textwrap +from functools import partial from typing import Dict from jinja2 import Environment, StrictUndefined @@ -17,14 +18,14 @@ class PRAddDocs: def __init__(self, pr_url: str, cli_mode=False, args: list = None, - ai_handler: BaseAiHandler = LiteLLMAIHandler()): + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = ai_handler + self.ai_handler = ai_handler() self.patches_diff = None self.prediction = None self.cli_mode = cli_mode diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 81e1ceabe..6b30a8a84 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -1,5 +1,6 @@ import copy import textwrap +from functools import partial from typing import Dict, List from jinja2 import Environment, StrictUndefined @@ -16,7 +17,7 @@ class PRCodeSuggestions: def __init__(self, pr_url: str, cli_mode=False, args: list = None, - ai_handler: BaseAiHandler = LiteLLMAIHandler()): + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( @@ -33,7 +34,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None, else: num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions - self.ai_handler = ai_handler + self.ai_handler = ai_handler() self.patches_diff = None self.prediction = None self.cli_mode = cli_mode diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 4915c5b68..95a5fc16d 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -1,5 +1,6 @@ import copy import re +from functools import partial from typing import List, Tuple from jinja2 import Environment, StrictUndefined @@ -17,7 +18,7 @@ class PRDescription: def __init__(self, pr_url: str, args: list = None, - ai_handler: BaseAiHandler = LiteLLMAIHandler()): + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. @@ -38,7 +39,7 @@ def __init__(self, pr_url: str, args: list = None, get_settings().pr_description.enable_semantic_files_types = False # Initialize the AI handler - self.ai_handler = ai_handler + self.ai_handler = ai_handler() # Initialize the variables dictionary self.vars = { diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index 25e80a55b..213ddee45 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -1,5 +1,6 @@ import copy import re +from functools import partial from typing import List, Tuple from jinja2 import Environment, StrictUndefined @@ -17,7 +18,7 @@ class PRGenerateLabels: def __init__(self, pr_url: str, args: list = None, - ai_handler: BaseAiHandler = LiteLLMAIHandler()): + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): """ Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels corresponding to the PR using an AI model. @@ -33,7 +34,7 @@ def __init__(self, pr_url: str, args: list = None, self.pr_id = self.git_provider.get_pr_id() # Initialize the AI handler - self.ai_handler = ai_handler + self.ai_handler = ai_handler() # Initialize the variables dictionary self.vars = { diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index a47d511be..1523d7373 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -1,4 +1,5 @@ import copy +from functools import partial from jinja2 import Environment, StrictUndefined @@ -14,12 +15,12 @@ class PRInformationFromUser: def __init__(self, pr_url: str, args: list = None, - ai_handler: BaseAiHandler = LiteLLMAIHandler()): + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = ai_handler + self.ai_handler = ai_handler() self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 5de3d7762..4d4995988 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -1,4 +1,5 @@ import copy +from functools import partial from jinja2 import Environment, StrictUndefined @@ -13,13 +14,13 @@ class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): + def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = ai_handler + self.ai_handler = ai_handler() self.question_str = question_str self.vars = { "title": self.git_provider.pr.title, diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 24a40af31..6543496f7 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -1,6 +1,7 @@ import copy import datetime from collections import OrderedDict +from functools import partial from typing import List, Tuple import yaml @@ -24,7 +25,7 @@ class PRReviewer: The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, - ai_handler: BaseAiHandler = LiteLLMAIHandler()): + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. @@ -47,7 +48,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") - self.ai_handler = ai_handler + self.ai_handler = ai_handler() self.patches_diff = None self.prediction = None diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index b8c6187f1..c7ffa6d83 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -1,5 +1,6 @@ import copy from datetime import date +from functools import partial from time import sleep from typing import Tuple @@ -18,7 +19,7 @@ class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( @@ -26,7 +27,7 @@ def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHan ) self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes self._get_changlog_file() # self.changelog_file_str - self.ai_handler = ai_handler + self.ai_handler = ai_handler() self.patches_diff = None self.prediction = None self.cli_mode = cli_mode