Skip to content

Commit

Permalink
Support stable-diffusion-webui-forge
Browse files Browse the repository at this point in the history
Hijacks the forge_sample function and calls it separately for each individual
text conditioning. The outputs are then recombined into a format that the extension
already knows how to work with.

This doesn't use forge's patcher or take advantage of all of forge's performance
optimizations brought over from ComfyUI via ldm_patched,  but it does guarantee
backwards compatibility for base A1111 users, and allows Forge users to use the
extension without maintaining a separate install of A1111.
  • Loading branch information
wbclark committed Mar 21, 2024
1 parent 9420c78 commit f830995
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Neutral prompt is an a1111 webui extension that adds alternative composable diff

## Features

- Now compatible wih [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge)!
- [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword
- saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to determine SNB maps from epsilons)
- semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247)
Expand Down
126 changes: 110 additions & 16 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,27 +212,121 @@ def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
return vector * (torch.abs(vector) >= top_k).to(vector.dtype)


sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
module=sd_samplers,
hijacker_attribute='__neutral_prompt_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)
try:
from modules_forge import forge_sampler
forge = True
except ImportError:
forge = False


@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name: str, model, original_function):
sampler = original_function(name, model)
if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
if global_state.is_enabled:
warn_unsupported_sampler()
if forge:
from ldm_patched.modules.samplers import sampling_function

return sampler
forge_sampler_hijacker = hijacker.ModuleHijacker.install_or_get(
module=forge_sampler,
hijacker_attribute='__forge_sample_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)

@forge_sampler_hijacker.hijack('forge_sample')
def forge_sample(self, denoiser_params, cond_scale, cond_composition, original_function):
if not global_state.is_enabled:
return original_function(self, denoiser_params, cond_scale, cond_composition)

model = self.inner_model.inner_model.forge_objects.unet.model
control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
x = denoiser_params.x
timestep = denoiser_params.sigma
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
seed = self.p.seeds[0]

uncond = forge_sampler.cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
conds = forge_sampler.cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
conds += uncond

denoised = []

for current_cond in conds:
cond = [current_cond]
cond[0]['strength'] = 1.0

if extra_concat_condition is not None:
image_cond_in = extra_concat_condition
else:
image_cond_in = denoiser_params.image_cond

if isinstance(image_cond_in, torch.Tensor):
if image_cond_in.shape[0] == x.shape[0] \
and image_cond_in.shape[2] == x.shape[2] \
and image_cond_in.shape[3] == x.shape[3]:
cond[0]['model_conds']['c_concat'] = CONDRegular(image_cond_in)

if control is not None:
cond[0]['control'] = control

for modifier in model_options.get('conditioning_modifiers', []):
model, x, timestep, _, cond, cond_scale, model_options, seed = modifier(model, x, timestep, None, cond, cond_scale, model_options, seed)

model_options["disable_cfg1_optimization"] = True

result = sampling_function(model, x, timestep, None, cond, 1.0, model_options, seed)
denoised.append(result)

cond_indices = cond_composition[0]
prompt = global_state.prompt_exprs[0]

sampler.model_wrap_cfg.combine_denoised = functools.partial(
combine_denoised_hijack,
original_function=sampler.model_wrap_cfg.combine_denoised
# B, C, H, W
denoised_uncond = denoised[-1]

# N, B, C, H, W
denoised_conds = torch.stack(denoised[:-1], dim=0)

# N, 1, 1, 1, 1
weights = torch.tensor([ weight for (_, weight) in cond_indices ], device=denoised_uncond.device)
weights /= weights.abs().sum()
weights = weights.view(-1, 1, 1, 1, 1)

# B, C, H, W
assert denoised_conds.shape[0] == weights.shape[0]
denoised_cond = (denoised_conds * weights).sum(dim=0)
forge_denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale

for batch_i in range(denoised_uncond.shape[0]):
args = CombineDenoiseArgs(denoised_conds.unbind(dim=1)[batch_i], denoised_uncond[batch_i], cond_indices)
cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
cfg_cond = forge_denoised[batch_i] + aux_cond_delta * cond_scale
forge_denoised[batch_i] = cfg_rescale(cfg_cond, denoised_uncond[batch_i] + cond_delta + aux_cond_delta)

return forge_denoised


else:
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
module=sd_samplers,
hijacker_attribute='__neutral_prompt_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)
return sampler


@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name: str, model, original_function):
sampler = original_function(name, model)


if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
if global_state.is_enabled:
warn_unsupported_sampler()

return sampler

sampler.model_wrap_cfg.combine_denoised = functools.partial(
combine_denoised_hijack,
original_function=sampler.model_wrap_cfg.combine_denoised
)

return sampler


def warn_unsupported_sampler():
Expand Down

0 comments on commit f830995

Please sign in to comment.