[Stable Diffusion v3] Fix Wrapper Class for SD3 FX Notebook (#2645)
Rework the wrapper classes for the SD3 models so the SD3 FX notebook helper is more readable and efficient: a single ModelWrapper that subclasses each model's own class replaces the two separate wrapper classes.
anzr299 authored Jan 15, 2025
1 parent 3b22ba0 commit 7fe78d2
Showing 1 changed file with 11 additions and 43 deletions.
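
In essence, the commit collapses two ad-hoc wrapper classes (WrappedModel and WrappedTransformer) into one ModelWrapper that subclasses each model's own class directly. SD3Transformer2DModel in diffusers already inherits from both ModelMixin and ConfigMixin, which is why a single concrete base class suffices. Paraphrasing the call-site change from the diff below:

    # Before: the transformer was wrapped via a tuple of mixins, dispatched to a
    # dedicated WrappedTransformer class with @register_to_config boilerplate.
    wrap_model(models_dict["transformer"], (ModelMixin, ConfigMixin), configs_dict["transformer"])

    # After: every model, transformer included, goes through one ModelWrapper
    # that subclasses the concrete model class.
    wrap_model(models_dict["transformer"], SD3Transformer2DModel, configs_dict["transformer"])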
54 changes: 11 additions & 43 deletions notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py
@@ -19,69 +19,37 @@ def init_pipeline(models_dict, configs_dict, model_id="stabilityai/stable-diffus
     wrapped_models = {}
 
     def wrap_model(pipe_model, base_class, config):
-        base_class = (base_class,) if not isinstance(base_class, tuple) else base_class
-
-        class WrappedModel(*base_class):
+        class ModelWrapper(base_class):
             def __init__(self, model, config):
-                cls_name = base_class[0].__name__
+                cls_name = base_class.__name__
                 if isinstance(config, dict):
                     super().__init__(**config)
                 else:
                     super().__init__(config)
 
                 modules_to_delete = [name for name in self._modules.keys()]
                 for name in modules_to_delete:
                     del self._modules[name]
 
                 if cls_name == "AutoencoderKL":
                     self.encoder = model.encoder
                     self.decoder = model.decoder
                 else:
                     self.model = model
 
             def forward(self, *args, **kwargs):
                 kwargs.pop("joint_attention_kwargs", None)
                 kwargs.pop("return_dict", None)
                 return self.model(*args, **kwargs)
 
-        class WrappedTransformer(*base_class):
-            @register_to_config
-            def __init__(
-                self,
-                model,
-                sample_size,
-                patch_size,
-                in_channels,
-                num_layers,
-                attention_head_dim,
-                num_attention_heads,
-                joint_attention_dim,
-                caption_projection_dim,
-                pooled_projection_dim,
-                out_channels,
-                pos_embed_max_size,
-                dual_attention_layers,
-                qk_norm,
-            ):
-                super().__init__()
-                self.model = model
-
-            def forward(self, *args, **kwargs):
-                del kwargs["joint_attention_kwargs"]
-                del kwargs["return_dict"]
-                return self.model(*args, **kwargs)
-
-        if len(base_class) > 1:
-            return WrappedTransformer(pipe_model, **config)
-        return WrappedModel(pipe_model, config)
+        return ModelWrapper(pipe_model, config)
 
-    wrapped_models["transformer"] = wrap_model(
-        models_dict["transformer"],
-        (
-            ModelMixin,
-            ConfigMixin,
-        ),
-        configs_dict["transformer"],
-    )
+    wrapped_models["transformer"] = wrap_model(models_dict["transformer"], SD3Transformer2DModel, configs_dict["transformer"])
     wrapped_models["vae"] = wrap_model(models_dict["vae"], AutoencoderKL, configs_dict["vae"])
     wrapped_models["text_encoder"] = wrap_model(models_dict["text_encoder"], CLIPTextModelWithProjection, configs_dict["text_encoder"])
     wrapped_models["text_encoder_2"] = wrap_model(models_dict["text_encoder_2"], CLIPTextModelWithProjection, configs_dict["text_encoder_2"])
 
     pipe = StableDiffusion3Pipeline.from_pretrained(model_id, text_encoder_3=None, tokenizer_3=None, **wrapped_models)
 
     return pipe
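
For readers skimming the diff, here is a minimal, self-contained sketch of the dynamic-subclass pattern the new wrap_model relies on. DummyModel and its toy config are hypothetical stand-ins, not part of the notebook; the point is that the wrapper is an isinstance-compatible subclass of the real model class while forward() delegates to the wrapped (e.g. FX-converted) model:

    import torch

    class DummyModel(torch.nn.Module):
        # Hypothetical stand-in for a diffusers/transformers model class.
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    def wrap_model(pipe_model, base_class, config):
        class ModelWrapper(base_class):
            def __init__(self, model, config):
                super().__init__(config)
                # Drop the submodules created by the base __init__; only the
                # wrapped model's parameters should stay registered.
                for name in list(self._modules):
                    del self._modules[name]
                self.model = model

            def forward(self, *args, **kwargs):
                # Strip pipeline-only kwargs before delegating, as in the notebook.
                kwargs.pop("return_dict", None)
                return self.model(*args, **kwargs)

        return ModelWrapper(pipe_model, config)

    original = DummyModel(config={"dim": 4})
    wrapped = wrap_model(original, DummyModel, config={"dim": 4})
    assert isinstance(wrapped, DummyModel)   # pipeline type checks still pass
    print(wrapped(torch.zeros(1, 4)).shape)  # torch.Size([1, 4])

In the notebook itself, init_pipeline then hands the wrapped models to StableDiffusion3Pipeline.from_pretrained(..., **wrapped_models), so the pipeline receives objects of the classes it expects while running the substituted model internals.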

