From 380449a7a924422c252c901943e924ae741b8d8a Mon Sep 17 00:00:00 2001 From: Chun Kang Lu Date: Mon, 21 Oct 2024 17:31:40 -0400 Subject: [PATCH] core: fix Image prompt template hardcoded template format (#27495) Fixes #27411 **Description:** Adds `template_format` to the `ImagePromptTemplate` class and updates passing in the `template_format` parameter from ChatPromptTemplate instead of the hardcoded "f-string". Also updated docs and typing related to `template_format` to be more up-to-date and specific. **Dependencies:** None **Add tests and docs**: Added unit tests to validate fix. Needed to update `test_chat` snapshot due to adding new attribute `template_format` in `ImagePromptTemplate`. --------- Co-authored-by: Vadym Barda --- libs/core/extended_testing_deps.txt | 1 + libs/core/langchain_core/prompts/chat.py | 34 +++++---- .../prompts/few_shot_with_templates.py | 6 +- libs/core/langchain_core/prompts/image.py | 11 ++- libs/core/langchain_core/prompts/prompt.py | 12 ++-- libs/core/langchain_core/prompts/string.py | 4 +- .../core/langchain_core/prompts/structured.py | 4 +- .../prompts/__snapshots__/test_chat.ambr | 7 ++ .../tests/unit_tests/prompts/test_chat.py | 72 +++++++++++++++++++ .../tests/unit_tests/prompts/test_prompt.py | 5 +- 10 files changed, 131 insertions(+), 25 deletions(-) diff --git a/libs/core/extended_testing_deps.txt b/libs/core/extended_testing_deps.txt index 5ad9c8930daf9..0e216a61157d8 100644 --- a/libs/core/extended_testing_deps.txt +++ b/libs/core/extended_testing_deps.txt @@ -1 +1,2 @@ jinja2>=3,<4 +mustache>=0.1.4,<1 diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 111aed89b3c97..1629962ba1333 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -8,7 +8,6 @@ from typing import ( Annotated, Any, - Literal, Optional, TypedDict, TypeVar, @@ -40,7 +39,11 @@ from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.prompts.string import StringPromptTemplate, get_template_variables +from langchain_core.prompts.string import ( + PromptTemplateFormat, + StringPromptTemplate, + get_template_variables, +) from langchain_core.utils import get_colored_text from langchain_core.utils.interactive_env import is_interactive_env @@ -296,7 +299,7 @@ def get_lc_namespace(cls) -> list[str]: def from_template( cls: type[MessagePromptTemplateT], template: str, - template_format: str = "f-string", + template_format: PromptTemplateFormat = "f-string", partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> MessagePromptTemplateT: @@ -486,7 +489,7 @@ def get_lc_namespace(cls) -> list[str]: def from_template( cls: type[_StringImageMessagePromptTemplateT], template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]], - template_format: str = "f-string", + template_format: PromptTemplateFormat = "f-string", *, partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, @@ -495,7 +498,8 @@ def from_template( Args: template: a template. - template_format: format of the template. Defaults to "f-string". + template_format: format of the template. + Options are: 'f-string', 'mustache', 'jinja2'. Defaults to "f-string". partial_variables: A dictionary of variables that can be used too partially. Defaults to None. **kwargs: keyword arguments to pass to the constructor. @@ -533,7 +537,7 @@ def from_template( img_template = cast(_ImageTemplateParam, tmpl)["image_url"] input_variables = [] if isinstance(img_template, str): - vars = get_template_variables(img_template, "f-string") + vars = get_template_variables(img_template, template_format) if vars: if len(vars) > 1: msg = ( @@ -545,7 +549,9 @@ def from_template( input_variables = [vars[0]] img_template = {"url": img_template} img_template_obj = ImagePromptTemplate( - input_variables=input_variables, template=img_template + input_variables=input_variables, + template=img_template, + template_format=template_format, ) elif isinstance(img_template, dict): img_template = dict(img_template) @@ -553,11 +559,13 @@ def from_template( if key in img_template: input_variables.extend( get_template_variables( - img_template[key], "f-string" + img_template[key], template_format ) ) img_template_obj = ImagePromptTemplate( - input_variables=input_variables, template=img_template + input_variables=input_variables, + template=img_template, + template_format=template_format, ) else: msg = f"Invalid image template: {tmpl}" @@ -943,7 +951,7 @@ def __init__( self, messages: Sequence[MessageLikeRepresentation], *, - template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + template_format: PromptTemplateFormat = "f-string", **kwargs: Any, ) -> None: """Create a chat prompt template from a variety of message formats. @@ -1160,7 +1168,7 @@ def from_strings( def from_messages( cls, messages: Sequence[MessageLikeRepresentation], - template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + template_format: PromptTemplateFormat = "f-string", ) -> ChatPromptTemplate: """Create a chat prompt template from a variety of message formats. @@ -1354,7 +1362,7 @@ def pretty_repr(self, html: bool = False) -> str: def _create_template_from_message_type( message_type: str, template: Union[str, list], - template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + template_format: PromptTemplateFormat = "f-string", ) -> BaseMessagePromptTemplate: """Create a message prompt template from a message type and template string. @@ -1426,7 +1434,7 @@ def _create_template_from_message_type( def _convert_to_message( message: MessageLikeRepresentation, - template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + template_format: PromptTemplateFormat = "f-string", ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: """Instantiate a message from a variety of message formats. diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py index f293aae1c80f8..da53d5b8c59d1 100644 --- a/libs/core/langchain_core/prompts/few_shot_with_templates.py +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -9,6 +9,7 @@ from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import ( DEFAULT_FORMATTER_MAPPING, + PromptTemplateFormat, StringPromptTemplate, ) @@ -36,8 +37,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate): prefix: Optional[StringPromptTemplate] = None """A PromptTemplate to put before the examples.""" - template_format: str = "f-string" - """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + template_format: PromptTemplateFormat = "f-string" + """The format of the prompt template. + Options are: 'f-string', 'jinja2', 'mustache'.""" validate_template: bool = False """Whether or not to try validating the template.""" diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index a75a5eece0f91..9336e20f60ac5 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -4,6 +4,10 @@ from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.string import ( + DEFAULT_FORMATTER_MAPPING, + PromptTemplateFormat, +) from langchain_core.runnables import run_in_executor from langchain_core.utils import image as image_utils @@ -13,6 +17,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]): template: dict = Field(default_factory=dict) """Template for the prompt.""" + template_format: PromptTemplateFormat = "f-string" + """The format of the prompt template. + Options are: 'f-string', 'mustache', 'jinja2'.""" def __init__(self, **kwargs: Any) -> None: if "input_variables" not in kwargs: @@ -85,7 +92,9 @@ def format( formatted = {} for k, v in self.template.items(): if isinstance(v, str): - formatted[k] = v.format(**kwargs) + formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format]( + v, **kwargs + ) else: formatted[k] = v url = kwargs.get("url") or formatted.get("url") diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 5c52ef36d076c..325ee067ecaf9 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -4,12 +4,13 @@ import warnings from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel, model_validator from langchain_core.prompts.string import ( DEFAULT_FORMATTER_MAPPING, + PromptTemplateFormat, StringPromptTemplate, check_valid_template, get_template_variables, @@ -24,7 +25,8 @@ class PromptTemplate(StringPromptTemplate): A prompt template consists of a string template. It accepts a set of parameters from the user that can be used to generate a prompt for a language model. - The template can be formatted using either f-strings (default) or jinja2 syntax. + The template can be formatted using either f-strings (default), jinja2, + or mustache syntax. *Security warning*: Prefer using `template_format="f-string"` instead of @@ -67,7 +69,7 @@ def get_lc_namespace(cls) -> list[str]: template: str """The prompt template.""" - template_format: Literal["f-string", "mustache", "jinja2"] = "f-string" + template_format: PromptTemplateFormat = "f-string" """The format of the prompt template. Options are: 'f-string', 'mustache', 'jinja2'.""" @@ -248,7 +250,7 @@ def from_template( cls, template: str, *, - template_format: str = "f-string", + template_format: PromptTemplateFormat = "f-string", partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> PromptTemplate: @@ -270,7 +272,7 @@ def from_template( Args: template: The template to load. template_format: The format of the template. Use `jinja2` for jinja2, - and `f-string` or None for f-strings. + `mustache` for mustache, and `f-string` for f-strings. Defaults to `f-string`. partial_variables: A dictionary of variables that can be used to partially fill in the template. For example, if the template is diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 01c19486b6760..0040f4b712279 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -5,7 +5,7 @@ import warnings from abc import ABC from string import Formatter -from typing import Any, Callable +from typing import Any, Callable, Literal from pydantic import BaseModel, create_model @@ -16,6 +16,8 @@ from langchain_core.utils.formatting import formatter from langchain_core.utils.interactive_env import is_interactive_env +PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"] + def jinja2_formatter(template: str, /, **kwargs: Any) -> str: """Format a template using jinja2. diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 360f14ccf2fde..543dbb57e5d90 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -2,7 +2,6 @@ from typing import ( Any, Callable, - Literal, Optional, Union, ) @@ -15,6 +14,7 @@ ChatPromptTemplate, MessageLikeRepresentation, ) +from langchain_core.prompts.string import PromptTemplateFormat from langchain_core.runnables.base import ( Other, Runnable, @@ -38,7 +38,7 @@ def __init__( schema_: Optional[Union[dict, type[BaseModel]]] = None, *, structured_output_kwargs: Optional[dict[str, Any]] = None, - template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + template_format: PromptTemplateFormat = "f-string", **kwargs: Any, ) -> None: schema_ = schema_ or kwargs.pop("schema") diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 7e35bd6c46548..8e5e5c61ef48c 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -3119,6 +3119,7 @@ 'template': dict({ 'url': 'data:image/jpeg;base64,{my_image}', }), + 'template_format': 'f-string', }), 'lc': 1, 'name': 'ImagePromptTemplate', @@ -3138,6 +3139,7 @@ 'template': dict({ 'url': 'data:image/jpeg;base64,{my_image}', }), + 'template_format': 'f-string', }), 'lc': 1, 'name': 'ImagePromptTemplate', @@ -3157,6 +3159,7 @@ 'template': dict({ 'url': '{my_other_image}', }), + 'template_format': 'f-string', }), 'lc': 1, 'name': 'ImagePromptTemplate', @@ -3177,6 +3180,7 @@ 'detail': 'medium', 'url': '{my_other_image}', }), + 'template_format': 'f-string', }), 'lc': 1, 'name': 'ImagePromptTemplate', @@ -3195,6 +3199,7 @@ 'template': dict({ 'url': 'https://www.langchain.com/image.png', }), + 'template_format': 'f-string', }), 'lc': 1, 'name': 'ImagePromptTemplate', @@ -3213,6 +3218,7 @@ 'template': dict({ 'url': '', }), + 'template_format': 'f-string', }), 'lc': 1, 'name': 'ImagePromptTemplate', @@ -3231,6 +3237,7 @@ 'template': dict({ 'url': '', }), + 'template_format': 'f-string', }), 'lc': 1, 'name': 'ImagePromptTemplate', diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 72056f6c5a22c..95cbdc5a30c68 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -31,6 +31,7 @@ SystemMessagePromptTemplate, _convert_to_message, ) +from langchain_core.prompts.string import PromptTemplateFormat from tests.unit_tests.pydantic_utils import _normalize_schema @@ -298,6 +299,77 @@ def test_chat_prompt_template_from_messages_mustache() -> None: ] +@pytest.mark.requires("jinja2") +def test_chat_prompt_template_from_messages_jinja2() -> None: + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful AI bot. Your name is {{ name }}."), + ("human", "Hello, how are you doing?"), + ("ai", "I'm doing well, thanks!"), + ("human", "{{ user_input }}"), + ], + "jinja2", + ) + + messages = template.format_messages(name="Bob", user_input="What is your name?") + + assert messages == [ + SystemMessage( + content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={} + ), + HumanMessage( + content="Hello, how are you doing?", additional_kwargs={}, example=False + ), + AIMessage( + content="I'm doing well, thanks!", additional_kwargs={}, example=False + ), + HumanMessage(content="What is your name?", additional_kwargs={}, example=False), + ] + + +@pytest.mark.requires("jinja2") +@pytest.mark.requires("mustache") +@pytest.mark.parametrize( + "template_format,image_type_placeholder,image_data_placeholder", + [ + ("f-string", "{image_type}", "{image_data}"), + ("mustache", "{{image_type}}", "{{image_data}}"), + ("jinja2", "{{ image_type }}", "{{ image_data }}"), + ], +) +def test_chat_prompt_template_image_prompt_from_message( + template_format: PromptTemplateFormat, + image_type_placeholder: str, + image_data_placeholder: str, +) -> None: + prompt = { + "type": "image_url", + "image_url": { + "url": f"data:{image_type_placeholder};base64, {image_data_placeholder}", + "detail": "low", + }, + } + + template = ChatPromptTemplate.from_messages( + [("human", [prompt])], template_format=template_format + ) + assert template.format_messages( + image_type="image/png", image_data="base64data" + ) == [ + HumanMessage( + content=[ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64, base64data", + "detail": "low", + }, + } + ] + ) + ] + + def test_chat_prompt_template_with_messages( messages: list[BaseMessagePromptTemplate], ) -> None: diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index cf256a8ae1d30..d56654d874d5b 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -8,6 +8,7 @@ from syrupy import SnapshotAssertion from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.prompts.string import PromptTemplateFormat from langchain_core.tracers.run_collector import RunCollectorCallbackHandler from tests.unit_tests.pydantic_utils import _normalize_schema @@ -610,7 +611,9 @@ async def test_prompt_ainvoke_with_metadata() -> None: ) @pytest.mark.parametrize("template_format", ["f-string", "mustache"]) def test_prompt_falsy_vars( - template_format: str, value: Any, expected: Union[str, dict[str, str]] + template_format: PromptTemplateFormat, + value: Any, + expected: Union[str, dict[str, str]], ) -> None: # each line is value, f-string, mustache if template_format == "f-string":