From fec00922343876cf03970fbda63cd280ff3db0de Mon Sep 17 00:00:00 2001
From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date: Thu, 10 Oct 2024 18:40:49 +0200
Subject: [PATCH] Include custom textual inversion in diffusers pipelines
 (#938)

* added textual inversion

* added tests

* fix textual inversion loader and test it

* fix

* slow test

* fix

* mark as run_slow to test with CI
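
Example usage (adapted from the docstring examples removed by this patch; the
model and concept ids below are the same ones exercised in the new slow tests):

```py
from optimum.intel import OVStableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipe = OVStableDiffusionPipeline.from_pretrained(model_id, compile=False)

# load a textual inversion embedding in Diffusers format, then compile
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
pipe.compile()

image = pipe("A <cat-toy> backpack", num_inference_steps=50).images[0]
image.save("cat-backpack.png")
```

Automatic1111-format vectors load the same way once downloaded locally, e.g.
`pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")`.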
---
 optimum/intel/openvino/loaders.py            | 385 +++----------------
 optimum/intel/openvino/modeling_diffusion.py |  29 +-
 optimum/intel/openvino/utils.py              |   4 +-
 tests/openvino/test_diffusion.py             |  76 +++-
 4 files changed, 151 insertions(+), 343 deletions(-)

diff --git a/optimum/intel/openvino/loaders.py b/optimum/intel/openvino/loaders.py
index fc5ae97495..5da2877002 100644
--- a/optimum/intel/openvino/loaders.py
+++ b/optimum/intel/openvino/loaders.py
@@ -13,26 +13,18 @@
 #  limitations under the License.
 
 import logging
-import warnings
 from typing import Dict, List, Optional, Union
 
-import torch
-from diffusers.utils import _get_model_file
-
-from ..utils.import_utils import is_safetensors_available
-
-
-if is_safetensors_available():
-    import safetensors
-
 import openvino
-from huggingface_hub.constants import HF_HUB_OFFLINE, HUGGINGFACE_HUB_CACHE
+import torch
+from diffusers.loaders.textual_inversion import TextualInversionLoaderMixin, load_textual_inversion_state_dicts
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from openvino.runtime import Type
 from openvino.runtime import opset11 as ops
 from openvino.runtime.passes import Manager, Matcher, MatcherPass, WrapType
 from transformers import PreTrainedTokenizer
 
-from .utils import TEXTUAL_INVERSION_EMBEDDING_KEY, TEXTUAL_INVERSION_NAME, TEXTUAL_INVERSION_NAME_SAFE
+from .utils import TEXTUAL_INVERSION_EMBEDDING_KEY
 
 
 try:
@@ -49,17 +41,17 @@ class InsertTextEmbedding(MatcherPass):
     OpenVINO ngraph transformation for inserting pre-trained textual inversion embeddings into the text encoder
     """
 
-    def __init__(self, token_ids_and_embeddings):
+    def __init__(self, tokens_ids, embeddings):
         MatcherPass.__init__(self)
-        self.model_changed = False
+
         param = WrapType("opset1.Constant")
 
         def callback(matcher: Matcher) -> bool:
             root = matcher.get_match_root()
-            if root.get_friendly_name() == TEXTUAL_INVERSION_EMBEDDING_KEY:
+            if root.get_friendly_name() == TEXTUAL_INVERSION_EMBEDDING_KEY:  # there should be a better way to do this
                 add_ti = root
                 consumers = matcher.get_match_value().get_target_inputs()
-                for token_id, embedding in token_ids_and_embeddings:
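+                # each textual inversion embedding is appended as a new row of the
+                # token embedding matrix, in the order the token ids were added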
+                for token_id, embedding in zip(tokens_ids, embeddings):
                     ti_weights = ops.constant(embedding, Type.f32, name=str(token_id))
                     ti_weights_unsqueeze = ops.unsqueeze(ti_weights, axes=0)
                     add_ti = ops.concat(
@@ -81,341 +73,74 @@ def callback(matcher: Matcher) -> bool:
 
 
 # Adapted from diffusers.loaders.TextualInversionLoaderMixin
-class OVTextualInversionLoaderMixin:
-    r"""
-    Load textual inversion tokens and embeddings to the tokenizer and text encoder.
-    """
-
-    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
-        r"""
-        Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
-        be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
-        inversion token or if the textual inversion token is a single vector, the input prompt is returned.
-
-        Parameters:
-            prompt (`str` or list of `str`):
-                The prompt or prompts to guide the image generation.
-            tokenizer (`PreTrainedTokenizer`):
-                The tokenizer responsible for encoding the prompt into input tokens.
-
-        Returns:
-            `str` or list of `str`: The converted prompt
-        """
-        if not isinstance(prompt, List):
-            prompts = [prompt]
-        else:
-            prompts = prompt
-
-        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
-
-        if not isinstance(prompt, List):
-            return prompts[0]
-
-        return prompts
-
-    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
-        r"""
-        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
-        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
-        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
-        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
-
-        Parameters:
-            prompt (`str`):
-                The prompt to guide the image generation.
-            tokenizer (`PreTrainedTokenizer`):
-                The tokenizer responsible for encoding the prompt into input tokens.
-
-        Returns:
-            `str`: The converted prompt
-        """
-        tokens = tokenizer.tokenize(prompt)
-        unique_tokens = set(tokens)
-        for token in unique_tokens:
-            if token in tokenizer.added_tokens_encoder:
-                replacement = token
-                i = 1
-                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                    replacement += f" {token}_{i}"
-                    i += 1
-
-                prompt = prompt.replace(token, replacement)
-
-        return prompt
-
+class OVTextualInversionLoaderMixin(TextualInversionLoaderMixin):
     def load_textual_inversion(
         self,
         pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
+        tokenizer: Optional["PreTrainedTokenizer"] = None,  # noqa: F821
+        text_encoder: Optional["openvino.runtime.Model"] = None,  # noqa: F821
         **kwargs,
     ):
-        r"""
-        Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
-        Automatic1111 formats are supported).
-
-        Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
-                Can be either one of the following or a list of them:
-
-                    - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
-                      pretrained model hosted on the Hub.
-                    - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
-                      inversion weights.
-                    - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
-                    - A [torch state
-                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
-            token (`str` or `List[str]`, *optional*):
-                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
-                list, then `token` must also be a list of equal length.
-            weight_name (`str`, *optional*):
-                Name of a custom weight file. This should be used when:
-
-                    - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
-                      name such as `text_inv.bin`.
-                    - The saved textual inversion file is in the Automatic1111 format.
-            cache_dir (`Union[str, os.PathLike]`, *optional*):
-                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
-                is not used.
-            force_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
-                cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
-            local_files_only (`bool`, *optional*, defaults to `False`):
-                Whether to only load local model weights and configuration files or not. If set to `True`, the model
-                won't be downloaded from the Hub.
-            use_auth_token (Optional[Union[bool, str]], defaults to `None`):
-                Deprecated. Please use `token` instead.
-            token (Optional[Union[bool, str]], defaults to `None`):
-                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `huggingface-cli login` (stored in `~/.huggingface`).
-            revision (`str`, *optional*, defaults to `"main"`):
-                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
-                allowed by Git.
-            subfolder (`str`, *optional*, defaults to `""`):
-                The subfolder location of a model file within a larger model repository on the Hub or locally.
-            mirror (`str`, *optional*):
-                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
-                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
-                information.
-
-        Example:
-
-        To load a textual inversion embedding vector in 🤗 Diffusers format:
-
-        ```py
-        from optimum.intel import OVStableDiffusionPipeline
-
-        model_id = "runwayml/stable-diffusion-v1-5"
-        pipe = OVStableDiffusionPipeline.from_pretrained(model_id, compile=False)
-
-        pipe.load_textual_inversion("sd-concepts-library/cat-toy")
-        pipe.compile()
-
-        prompt = "A <cat-toy> backpack"
-
-        image = pipe(prompt, num_inference_steps=50).images[0]
-        image.save("cat-backpack.png")
-        ```
-
-        To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
-        (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
-        locally:
-
-        ```py
-        from optimum.intel import OVStableDiffusionPipeline
-
-        model_id = "runwayml/stable-diffusion-v1-5"
-        pipe = StableDiffusionPipeline.from_pretrained(model_id, compile=False)
-
-        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
-        pipe.compile()
-
-        prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
-
-        image = pipe(prompt, num_inference_steps=50).images[0]
-        image.save("character.png")
-        ```
-        """
-
-        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
+        if not hasattr(self, "tokenizer"):
             raise ValueError(
-                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
+                f"{self.__class__.__name__} requires `self.tokenizer` for calling `{self.load_textual_inversion.__name__}`"
             )
-
-        if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder.model, openvino.runtime.Model):
+        elif not isinstance(self.tokenizer, PreTrainedTokenizer):
             raise ValueError(
-                f"{self.__class__.__name__} requires `self.text_encoder.model` of type `openvino.runtime.Model` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
-            )
-
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        token = kwargs.pop("token", None)
-        revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", None)
-        use_safetensors = kwargs.pop("use_safetensors", None)
-
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
+                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling `{self.load_textual_inversion.__name__}`"
             )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
 
-        if use_safetensors and not is_safetensors_available():
+        if not hasattr(self, "text_encoder"):
             raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
+                f"{self.__class__.__name__} requires `self.text_encoder` for calling `{self.load_textual_inversion.__name__}`"
             )
-
-        allow_pickle = False
-        if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
-            allow_pickle = True
-
-        user_agent = {
-            "file_type": "text_inversion",
-            "framework": "pytorch",
-        }
-
-        if not isinstance(pretrained_model_name_or_path, list):
-            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
-        else:
-            pretrained_model_name_or_paths = pretrained_model_name_or_path
-
-        if isinstance(token, str):
-            tokens = [token]
-        elif token is None:
-            tokens = [None] * len(pretrained_model_name_or_paths)
-        else:
-            tokens = token
-
-        if len(pretrained_model_name_or_paths) != len(tokens):
+        elif not isinstance(self.text_encoder.model, openvino.runtime.Model):
             raise ValueError(
-                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
-                f"Make sure both lists have the same length."
+                f"{self.__class__.__name__} requires `self.text_encoder` of type `openvino.runtime.Model` for calling `{self.load_textual_inversion.__name__}`"
             )
 
-        valid_tokens = [t for t in tokens if t is not None]
-        if len(set(valid_tokens)) < len(valid_tokens):
-            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
-
-        token_ids_and_embeddings = []
-
-        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            if not isinstance(pretrained_model_name_or_path, dict):
-                # 1. Load textual inversion file
-                model_file = None
-                # Let's first try to load .safetensors weights
-                if (use_safetensors and weight_name is None) or (
-                    weight_name is not None and weight_name.endswith(".safetensors")
-                ):
-                    try:
-                        model_file = _get_model_file(
-                            pretrained_model_name_or_path,
-                            weights_name=weight_name or TEXTUAL_INVERSION_NAME_SAFE,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            resume_download=resume_download,
-                            proxies=proxies,
-                            local_files_only=local_files_only,
-                            use_auth_token=token,  # still uses use_auth_token
-                            revision=revision,
-                            subfolder=subfolder,
-                            user_agent=user_agent,
-                        )
-                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                    except Exception as e:
-                        if not allow_pickle:
-                            raise e
-
-                        model_file = None
-
-                if model_file is None:
-                    model_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXTUAL_INVERSION_NAME,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        use_auth_token=token,  # still uses use_auth_token
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                    )
-                    state_dict = torch.load(model_file, map_location="cpu")
-            else:
-                state_dict = pretrained_model_name_or_path
-
-            # 2. Load token and embedding correcly from file
-            loaded_token = None
-            if isinstance(state_dict, torch.Tensor):
-                if token is None:
+        # 1. Set correct tokenizer and text encoder
+        tokenizer = tokenizer or getattr(self, "tokenizer", None)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+
+        # 2. Normalize inputs
+        pretrained_model_name_or_paths = (
+            [pretrained_model_name_or_path]
+            if not isinstance(pretrained_model_name_or_path, list)
+            else pretrained_model_name_or_path
+        )
+        tokens = [token] if not isinstance(token, list) else token
+        if tokens[0] is None:
+            tokens = tokens * len(pretrained_model_name_or_paths)
+
+        # 3. Check inputs
+        self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
+
+        # 4. Load state dicts of textual embeddings
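+        # (hub download and safetensors/pickle handling are delegated to diffusers)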
+        state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
+
+        # 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
+        if len(tokens) > 1 and len(state_dicts) == 1:
+            if isinstance(state_dicts[0], torch.Tensor):
+                state_dicts = list(state_dicts[0])
+                if len(tokens) != len(state_dicts):
                     raise ValueError(
-                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                        f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} "
+                        f"Make sure both have the same length."
                     )
-                embedding = state_dict
-            elif len(state_dict) == 1:
-                # diffusers
-                loaded_token, embedding = next(iter(state_dict.items()))
-            elif "string_to_param" in state_dict:
-                # A1111
-                loaded_token = state_dict["name"]
-                embedding = state_dict["string_to_param"]["*"]
-
-            if token is not None and loaded_token != token:
-                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-            else:
-                token = loaded_token
-
-            embedding = embedding.detach().cpu().numpy()
 
-            # 3. Make sure we don't mess up the tokenizer or text encoder
-            vocab = self.tokenizer.get_vocab()
-            if token in vocab:
-                raise ValueError(
-                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-                )
-            elif f"{token}_1" in vocab:
-                multi_vector_tokens = [token]
-                i = 1
-                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
-                    multi_vector_tokens.append(f"{token}_{i}")
-                    i += 1
+        # 5. Retrieve tokens and embeddings
+        tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
 
-                raise ValueError(
-                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-                )
-            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
-            if is_multi_vector:
-                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-                embeddings = [e for e in embedding]  # noqa: C416
-            else:
-                tokens = [token]
-                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
-            # add tokens and get ids
-            self.tokenizer.add_tokens(tokens)
-            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-            token_ids_and_embeddings += zip(token_ids, embeddings)
+        # 6. Extend tokens and embeddings for multi-vector embeddings
+        tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
 
-            logger.info(f"Loaded textual inversion embedding for {token}.")
+        # 7. Add tokens to the tokenizer and get their ids (modified from the diffusers implementation)
+        tokenizer.add_tokens(tokens)
+        token_ids = tokenizer.convert_tokens_to_ids(tokens)
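+        # new tokens are appended at the end of the vocabulary, so their ids line
+        # up with the embedding rows concatenated by InsertTextEmbedding below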
 
         # Insert textual inversion embeddings to text encoder with OpenVINO ngraph transformation
         manager = Manager()
-        manager.register_pass(InsertTextEmbedding(token_ids_and_embeddings))
-        manager.run_passes(self.text_encoder.model)
+        manager.register_pass(InsertTextEmbedding(token_ids, embeddings))
+        manager.run_passes(text_encoder.model)
diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index 81dc085df9..d5ee6ee22e 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -64,6 +64,7 @@
 
 from ...exporters.openvino import main_export
 from .configuration import OVConfig, OVQuantizationMethod, OVWeightQuantizationConfig
+from .loaders import OVTextualInversionLoaderMixin
 from .modeling_base import OVBaseModel
 from .utils import (
     ONNX_WEIGHTS_NAME,
@@ -1010,7 +1011,7 @@ def to(self, *args, **kwargs):
             self.encoder.to(*args, **kwargs)
 
 
-class OVStableDiffusionPipeline(OVDiffusionPipeline, StableDiffusionPipeline):
+class OVStableDiffusionPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionPipeline):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion#diffusers.StableDiffusionPipeline).
     """
@@ -1020,7 +1021,9 @@ class OVStableDiffusionPipeline(OVDiffusionPipeline, StableDiffusionPipeline):
     auto_model_class = StableDiffusionPipeline
 
 
-class OVStableDiffusionImg2ImgPipeline(OVDiffusionPipeline, StableDiffusionImg2ImgPipeline):
+class OVStableDiffusionImg2ImgPipeline(
+    OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionImg2ImgPipeline
+):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_img2img#diffusers.StableDiffusionImg2ImgPipeline).
     """
@@ -1030,7 +1033,9 @@ class OVStableDiffusionImg2ImgPipeline(OVDiffusionPipeline, StableDiffusionImg2I
     auto_model_class = StableDiffusionImg2ImgPipeline
 
 
-class OVStableDiffusionInpaintPipeline(OVDiffusionPipeline, StableDiffusionInpaintPipeline):
+class OVStableDiffusionInpaintPipeline(
+    OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionInpaintPipeline
+):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_inpaint#diffusers.StableDiffusionInpaintPipeline).
     """
@@ -1040,7 +1045,7 @@ class OVStableDiffusionInpaintPipeline(OVDiffusionPipeline, StableDiffusionInpai
     auto_model_class = StableDiffusionInpaintPipeline
 
 
-class OVStableDiffusionXLPipeline(OVDiffusionPipeline, StableDiffusionXLPipeline):
+class OVStableDiffusionXLPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionXLPipeline):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline).
     """
@@ -1063,7 +1068,9 @@ def _get_add_time_ids(
         return add_time_ids
 
 
-class OVStableDiffusionXLImg2ImgPipeline(OVDiffusionPipeline, StableDiffusionXLImg2ImgPipeline):
+class OVStableDiffusionXLImg2ImgPipeline(
+    OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionXLImg2ImgPipeline
+):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline).
     """
@@ -1100,7 +1107,9 @@ def _get_add_time_ids(
         return add_time_ids, add_neg_time_ids
 
 
-class OVStableDiffusionXLInpaintPipeline(OVDiffusionPipeline, StableDiffusionXLInpaintPipeline):
+class OVStableDiffusionXLInpaintPipeline(
+    OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionXLInpaintPipeline
+):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline).
     """
@@ -1137,7 +1146,9 @@ def _get_add_time_ids(
         return add_time_ids, add_neg_time_ids
 
 
-class OVLatentConsistencyModelPipeline(OVDiffusionPipeline, LatentConsistencyModelPipeline):
+class OVLatentConsistencyModelPipeline(
+    OVDiffusionPipeline, OVTextualInversionLoaderMixin, LatentConsistencyModelPipeline
+):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline).
     """
@@ -1147,7 +1158,9 @@ class OVLatentConsistencyModelPipeline(OVDiffusionPipeline, LatentConsistencyMod
     auto_model_class = LatentConsistencyModelPipeline
 
 
-class OVLatentConsistencyModelImg2ImgPipeline(OVDiffusionPipeline, LatentConsistencyModelImg2ImgPipeline):
+class OVLatentConsistencyModelImg2ImgPipeline(
+    OVDiffusionPipeline, OVTextualInversionLoaderMixin, LatentConsistencyModelImg2ImgPipeline
+):
     """
     OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency_img2img#diffusers.LatentConsistencyModelImg2ImgPipeline).
     """
diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py
index 4e8033880b..279a24818e 100644
--- a/optimum/intel/openvino/utils.py
+++ b/optimum/intel/openvino/utils.py
@@ -53,9 +53,7 @@
 
 EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
 
-TEXTUAL_INVERSION_NAME = "learned_embeds.bin"
-TEXTUAL_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
-TEXTUAL_INVERSION_EMBEDDING_KEY = "text_model.embeddings.token_embedding.weight"
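+# friendly name of the token embedding Constant in the exported OpenVINO text encoder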
+TEXTUAL_INVERSION_EMBEDDING_KEY = "self.text_model.embeddings.token_embedding.weight"
 
 OV_TO_NP_TYPE = {
     "boolean": np.bool_,
diff --git a/tests/openvino/test_diffusion.py b/tests/openvino/test_diffusion.py
index 6271ff3e4e..687c1f5c02 100644
--- a/tests/openvino/test_diffusion.py
+++ b/tests/openvino/test_diffusion.py
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+import pytest
 import torch
 from diffusers import (
     AutoPipelineForImage2Image,
@@ -25,6 +26,7 @@
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.utils import load_image
 from parameterized import parameterized
+from transformers.testing_utils import slow
 from utils_tests import MODEL_NAMES, SEED
 
 from optimum.intel.openvino import (
@@ -295,6 +297,30 @@ def test_height_width_properties(self, model_arch: str):
         self.assertEqual(ov_pipeline.height, height)
         self.assertEqual(ov_pipeline.width, width)
 
+    @pytest.mark.run_slow
+    @slow
+    @require_diffusers
+    def test_textual_inversion(self):
+        # for now we only test stable-diffusion, as this test is very slow and costly to run
+
+        model_id = "runwayml/stable-diffusion-v1-5"
+        ti_id = "sd-concepts-library/cat-toy"
+
+        inputs = self.generate_inputs()
+        inputs["prompt"] = "A <cat-toy> backpack"
+
+        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
+        diffusers_pipeline.load_textual_inversion(ti_id)
+
+        ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id, compile=False, safety_checker=None)
+        ov_pipeline.load_textual_inversion(ti_id)
+
+        diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
+        ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
+
+        np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
+
 
 class OVPipelineForImage2ImageTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]
@@ -348,7 +374,6 @@ def test_num_images_per_prompt(self, model_arch: str):
     def test_callback(self, model_arch: str):
         height, width, batch_size = 32, 64, 1
         inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
-        inputs["num_inference_steps"] = 3
 
         class Callback:
             def __init__(self):
@@ -484,6 +509,30 @@ def test_height_width_properties(self, model_arch: str):
         self.assertEqual(ov_pipeline.height, height)
         self.assertEqual(ov_pipeline.width, width)
 
+    @pytest.mark.run_slow
+    @slow
+    @require_diffusers
+    def test_textual_inversion(self):
+        # for now we only test stable-diffusion, as this test is very slow and costly to run
+
+        model_id = "runwayml/stable-diffusion-v1-5"
+        ti_id = "sd-concepts-library/cat-toy"
+
+        inputs = self.generate_inputs()
+        inputs["prompt"] = "A <cat-toy> backpack"
+
+        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
+        diffusers_pipeline.load_textual_inversion(ti_id)
+
+        ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id, compile=False, safety_checker=None)
+        ov_pipeline.load_textual_inversion(ti_id)
+
+        diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
+        ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
+
+        np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
+
 
 class OVPipelineForInpaintingTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"]
@@ -542,7 +591,6 @@ def test_num_images_per_prompt(self, model_arch: str):
     def test_callback(self, model_arch: str):
         height, width, batch_size = 32, 64, 1
         inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
-        inputs["num_inference_steps"] = 3
 
         class Callback:
             def __init__(self):
@@ -677,3 +725,27 @@ def test_height_width_properties(self, model_arch: str):
         )
         self.assertEqual(ov_pipeline.height, height)
         self.assertEqual(ov_pipeline.width, width)
+
+    @pytest.mark.run_slow
+    @slow
+    @require_diffusers
+    def test_textual_inversion(self):
+        # for now we only test stable-diffusion, as this test is very slow and costly to run
+
+        model_id = "runwayml/stable-diffusion-v1-5"
+        ti_id = "sd-concepts-library/cat-toy"
+
+        inputs = self.generate_inputs()
+        inputs["prompt"] = "A <cat-toy> backpack"
+
+        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
+        diffusers_pipeline.load_textual_inversion(ti_id)
+
+        ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id, compile=False, safety_checker=None)
+        ov_pipeline.load_textual_inversion(ti_id)
+
+        diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
+        ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
+
+        np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)