Skip to content

Commit

Permalink
core: fix Image prompt template hardcoded template format (#27495)
Browse files Browse the repository at this point in the history
Fixes #27411 

**Description:** Adds `template_format` to the `ImagePromptTemplate`
class and updates passing in the `template_format` parameter from
ChatPromptTemplate instead of the hardcoded "f-string".
Also updated docs and typing related to `template_format` to be more
up-to-date and specific.

**Dependencies:** None

**Add tests and docs**: Added unit tests to validate fix. Needed to
update `test_chat` snapshot due to adding new attribute
`template_format` in `ImagePromptTemplate`.

---------

Co-authored-by: Vadym Barda <[email protected]>
  • Loading branch information
chunkanglu and vbarda authored Oct 21, 2024
1 parent 403c0ea commit 380449a
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 25 deletions.
1 change: 1 addition & 0 deletions libs/core/extended_testing_deps.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
jinja2>=3,<4
mustache>=0.1.4,<1
34 changes: 21 additions & 13 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import (
Annotated,
Any,
Literal,
Optional,
TypedDict,
TypeVar,
Expand Down Expand Up @@ -40,7 +39,11 @@
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.prompts.string import (
PromptTemplateFormat,
StringPromptTemplate,
get_template_variables,
)
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env

Expand Down Expand Up @@ -296,7 +299,7 @@ def get_lc_namespace(cls) -> list[str]:
def from_template(
cls: type[MessagePromptTemplateT],
template: str,
template_format: str = "f-string",
template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> MessagePromptTemplateT:
Expand Down Expand Up @@ -486,7 +489,7 @@ def get_lc_namespace(cls) -> list[str]:
def from_template(
cls: type[_StringImageMessagePromptTemplateT],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string",
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
Expand All @@ -495,7 +498,8 @@ def from_template(
Args:
template: a template.
template_format: format of the template. Defaults to "f-string".
template_format: format of the template.
Options are: 'f-string', 'mustache', 'jinja2'. Defaults to "f-string".
partial_variables: A dictionary of variables that can be used too partially.
Defaults to None.
**kwargs: keyword arguments to pass to the constructor.
Expand Down Expand Up @@ -533,7 +537,7 @@ def from_template(
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
input_variables = []
if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string")
vars = get_template_variables(img_template, template_format)
if vars:
if len(vars) > 1:
msg = (
Expand All @@ -545,19 +549,23 @@ def from_template(
input_variables = [vars[0]]
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
input_variables=input_variables,
template=img_template,
template_format=template_format,
)
elif isinstance(img_template, dict):
img_template = dict(img_template)
for key in ["url", "path", "detail"]:
if key in img_template:
input_variables.extend(
get_template_variables(
img_template[key], "f-string"
img_template[key], template_format
)
)
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
input_variables=input_variables,
template=img_template,
template_format=template_format,
)
else:
msg = f"Invalid image template: {tmpl}"
Expand Down Expand Up @@ -943,7 +951,7 @@ def __init__(
self,
messages: Sequence[MessageLikeRepresentation],
*,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
**kwargs: Any,
) -> None:
"""Create a chat prompt template from a variety of message formats.
Expand Down Expand Up @@ -1160,7 +1168,7 @@ def from_strings(
def from_messages(
cls,
messages: Sequence[MessageLikeRepresentation],
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
Expand Down Expand Up @@ -1354,7 +1362,7 @@ def pretty_repr(self, html: bool = False) -> str:
def _create_template_from_message_type(
message_type: str,
template: Union[str, list],
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
Expand Down Expand Up @@ -1426,7 +1434,7 @@ def _create_template_from_message_type(

def _convert_to_message(
message: MessageLikeRepresentation,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.
Expand Down
6 changes: 4 additions & 2 deletions libs/core/langchain_core/prompts/few_shot_with_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
StringPromptTemplate,
)

Expand Down Expand Up @@ -36,8 +37,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
prefix: Optional[StringPromptTemplate] = None
"""A PromptTemplate to put before the examples."""

template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'jinja2', 'mustache'."""

validate_template: bool = False
"""Whether or not to try validating the template."""
Expand Down
11 changes: 10 additions & 1 deletion libs/core/langchain_core/prompts/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
)
from langchain_core.runnables import run_in_executor
from langchain_core.utils import image as image_utils

Expand All @@ -13,6 +17,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):

template: dict = Field(default_factory=dict)
"""Template for the prompt."""
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""

def __init__(self, **kwargs: Any) -> None:
if "input_variables" not in kwargs:
Expand Down Expand Up @@ -85,7 +92,9 @@ def format(
formatted = {}
for k, v in self.template.items():
if isinstance(v, str):
formatted[k] = v.format(**kwargs)
formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format](
v, **kwargs
)
else:
formatted[k] = v
url = kwargs.get("url") or formatted.get("url")
Expand Down
12 changes: 7 additions & 5 deletions libs/core/langchain_core/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import warnings
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, Union

from pydantic import BaseModel, model_validator

from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
StringPromptTemplate,
check_valid_template,
get_template_variables,
Expand All @@ -24,7 +25,8 @@ class PromptTemplate(StringPromptTemplate):
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.
The template can be formatted using either f-strings (default) or jinja2 syntax.
The template can be formatted using either f-strings (default), jinja2,
or mustache syntax.
*Security warning*:
Prefer using `template_format="f-string"` instead of
Expand Down Expand Up @@ -67,7 +69,7 @@ def get_lc_namespace(cls) -> list[str]:
template: str
"""The prompt template."""

template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""

Expand Down Expand Up @@ -248,7 +250,7 @@ def from_template(
cls,
template: str,
*,
template_format: str = "f-string",
template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
Expand All @@ -270,7 +272,7 @@ def from_template(
Args:
template: The template to load.
template_format: The format of the template. Use `jinja2` for jinja2,
and `f-string` or None for f-strings.
`mustache` for mustache, and `f-string` for f-strings.
Defaults to `f-string`.
partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is
Expand Down
4 changes: 3 additions & 1 deletion libs/core/langchain_core/prompts/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable
from typing import Any, Callable, Literal

from pydantic import BaseModel, create_model

Expand All @@ -16,6 +16,8 @@
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env

PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]


def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
"""Format a template using jinja2.
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/prompts/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
)
Expand All @@ -15,6 +14,7 @@
ChatPromptTemplate,
MessageLikeRepresentation,
)
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.runnables.base import (
Other,
Runnable,
Expand All @@ -38,7 +38,7 @@ def __init__(
schema_: Optional[Union[dict, type[BaseModel]]] = None,
*,
structured_output_kwargs: Optional[dict[str, Any]] = None,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
**kwargs: Any,
) -> None:
schema_ = schema_ or kwargs.pop("schema")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3119,6 +3119,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,{my_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
Expand All @@ -3138,6 +3139,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,{my_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
Expand All @@ -3157,6 +3159,7 @@
'template': dict({
'url': '{my_other_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
Expand All @@ -3177,6 +3180,7 @@
'detail': 'medium',
'url': '{my_other_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
Expand All @@ -3195,6 +3199,7 @@
'template': dict({
'url': 'https://www.langchain.com/image.png',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
Expand All @@ -3213,6 +3218,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,foobar',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
Expand All @@ -3231,6 +3237,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,foobar',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
Expand Down
Loading

0 comments on commit 380449a

Please sign in to comment.