From 7fe78d21edecdd06902f4fc5accdc17e97f21270 Mon Sep 17 00:00:00 2001
From: Aamir Nazir
Date: Wed, 15 Jan 2025 15:16:17 +0400
Subject: [PATCH] [Stable Diffusion v3] Fix Wrapper Class for SD3 FX Notebook
 (#2645)

Make the wrapper classes for the SD3 models in the SD3 FX notebook helper
more readable and efficient by replacing them with a single, simpler
wrapper class.
---
 .../sd3_torch_fx_helper.py | 54 ++++---------------
 1 file changed, 11 insertions(+), 43 deletions(-)

diff --git a/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py b/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py
index b4253e5cfd9..51705a55854 100644
--- a/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py
+++ b/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py
@@ -19,15 +19,18 @@ def init_pipeline(models_dict, configs_dict, model_id="stabilityai/stable-diffus
     wrapped_models = {}
 
     def wrap_model(pipe_model, base_class, config):
-        base_class = (base_class,) if not isinstance(base_class, tuple) else base_class
-
-        class WrappedModel(*base_class):
+        class ModelWrapper(base_class):
             def __init__(self, model, config):
-                cls_name = base_class[0].__name__
+                cls_name = base_class.__name__
                 if isinstance(config, dict):
                     super().__init__(**config)
                 else:
                     super().__init__(config)
+
+                modules_to_delete = [name for name in self._modules.keys()]
+                for name in modules_to_delete:
+                    del self._modules[name]
+
                 if cls_name == "AutoencoderKL":
                     self.encoder = model.encoder
                     self.decoder = model.decoder
@@ -35,53 +38,18 @@ def __init__(self, model, config):
                 self.model = model
 
             def forward(self, *args, **kwargs):
+                kwargs.pop("joint_attention_kwargs", None)
+                kwargs.pop("return_dict", None)
                 return self.model(*args, **kwargs)
 
-        class WrappedTransformer(*base_class):
-            @register_to_config
-            def __init__(
-                self,
-                model,
-                sample_size,
-                patch_size,
-                in_channels,
-                num_layers,
-                attention_head_dim,
-                num_attention_heads,
-                joint_attention_dim,
-                caption_projection_dim,
-                pooled_projection_dim,
-                out_channels,
-                pos_embed_max_size,
-                dual_attention_layers,
-                qk_norm,
-            ):
-                super().__init__()
-                self.model = model
-
-            def forward(self, *args, **kwargs):
-                del kwargs["joint_attention_kwargs"]
-                del kwargs["return_dict"]
-                return self.model(*args, **kwargs)
-
-        if len(base_class) > 1:
-            return WrappedTransformer(pipe_model, **config)
-        return WrappedModel(pipe_model, config)
+        return ModelWrapper(pipe_model, config)
 
-    wrapped_models["transformer"] = wrap_model(
-        models_dict["transformer"],
-        (
-            ModelMixin,
-            ConfigMixin,
-        ),
-        configs_dict["transformer"],
-    )
+    wrapped_models["transformer"] = wrap_model(models_dict["transformer"], SD3Transformer2DModel, configs_dict["transformer"])
     wrapped_models["vae"] = wrap_model(models_dict["vae"], AutoencoderKL, configs_dict["vae"])
     wrapped_models["text_encoder"] = wrap_model(models_dict["text_encoder"], CLIPTextModelWithProjection, configs_dict["text_encoder"])
     wrapped_models["text_encoder_2"] = wrap_model(models_dict["text_encoder_2"], CLIPTextModelWithProjection, configs_dict["text_encoder_2"])
 
     pipe = StableDiffusion3Pipeline.from_pretrained(model_id, text_encoder_3=None, tokenizer_3=None, **wrapped_models)
-
     return pipe
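
For readers piecing the hunks together: after this patch, the wrapping logic inside init_pipeline reduces to the single helper sketched below. This is only an annotated consolidation of the added and unchanged lines for readability, not an additional change; the base classes passed in (SD3Transformer2DModel, AutoencoderKL, CLIPTextModelWithProjection) are assumed to already be imported from diffusers/transformers at the top of sd3_torch_fx_helper.py.

    def wrap_model(pipe_model, base_class, config):
        # Subclass the original diffusers/transformers class so the wrapper
        # keeps that class's config machinery, but reuse the already-loaded
        # submodules of `pipe_model` instead of freshly constructed ones.
        class ModelWrapper(base_class):
            def __init__(self, model, config):
                cls_name = base_class.__name__
                # The base constructor accepts either a plain dict or a
                # config object, depending on the model family.
                if isinstance(config, dict):
                    super().__init__(**config)
                else:
                    super().__init__(config)

                # Drop the submodules the base constructor just created; they
                # would only duplicate the weights already held by `model`.
                modules_to_delete = [name for name in self._modules.keys()]
                for name in modules_to_delete:
                    del self._modules[name]

                if cls_name == "AutoencoderKL":
                    # The pipeline calls the VAE's encoder/decoder directly.
                    self.encoder = model.encoder
                    self.decoder = model.decoder
                self.model = model

            def forward(self, *args, **kwargs):
                # Strip pipeline-only kwargs that the wrapped callable does
                # not accept before delegating to it.
                kwargs.pop("joint_attention_kwargs", None)
                kwargs.pop("return_dict", None)
                return self.model(*args, **kwargs)

        return ModelWrapper(pipe_model, config)

Clearing self._modules right after super().__init__ presumably avoids holding two copies of every submodule in memory, while the wrapper still presents the original class and config to StableDiffusion3Pipeline.from_pretrained.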