diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index c8c0191967d40..282ba2403b135 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -125,42 +125,31 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - if (data.bias == nullptr) { - assert(nullptr == fused_runner); - // For quantized attention, bias has been added so only need transpose here. - // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH - assert(qk_head_size == v_head_size); - int matrix_to_trans = (past_present_share_buffer ? 1 : 3); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } else { - // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) - // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) - // For unfused kernel, transpose to 3xBxNxSxH (format 1) - // For fused causal kernel, use format 1 since we need have K and V to update present state, - // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. - const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - data.qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); - - // For fused causal, we will update gemm_buffer with bias directly. - T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; - - int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); - // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v - // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) - LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.rotary_embedding, - parameters.past_sequence_length); - } + // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) + // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) + // For unfused kernel, transpose to 3xBxNxSxH (format 1) + // For fused causal kernel, use format 1 since we need have K and V to update present state, + // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. + const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); + data.qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); + + // For fused causal, we will update gemm_buffer with bias directly. + T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; + + int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); + // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v + // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) + LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, + 3, parameters.do_rotary, parameters.rotary_embedding, + parameters.past_sequence_length); return Status::OK(); } diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py index 03bcc20d9a5de..bfb19e08b4fe0 100644 --- a/onnxruntime/python/tools/transformers/compare_bert_results.py +++ b/onnxruntime/python/tools/transformers/compare_bert_results.py @@ -37,16 +37,23 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3): # Validate the output of baseline and treatment, to make sure the results are similar. diff_count = 0 max_abs_diff = 0 + max_diff_percentage = 0 + case_passed = True for test_case_id, results in enumerate(baseline_results): - case_passed = True for i in range(len(results)): treatment_output = treatment_results[test_case_id][i] - abs_diff = np.amax(np.abs(treatment_output - results[i])) + abs_diff_tensor = np.abs(treatment_output - results[i]) + abs_diff = np.amax(abs_diff_tensor) if verbose and abs_diff > atol: print("abs_diff", abs_diff) print("treatment", treatment_output) print("baseline", results[i]) + count_exceeding = np.sum(abs_diff_tensor > atol) + total_elements = abs_diff_tensor.size + percentage_exceeding = (count_exceeding / total_elements) * 100 + max_diff_percentage = max(max_diff_percentage, percentage_exceeding) + max_abs_diff = max(max_abs_diff, abs_diff) if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol): if case_passed: @@ -66,6 +73,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3): ) print(f"maximum absolute difference={max_abs_diff}") + print(f"maximum percentage of elements that exceeds atol={atol} is {max_diff_percentage:.3f}%") return max_abs_diff, case_passed diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 030708783bb61..56b5ae93e7221 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -355,45 +355,6 @@ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str): self.node_name_to_graph_name[gather_k_name] = self.this_graph_name self.node_name_to_graph_name[gather_v_name] = self.this_graph_name - def transpose_kv(self, past_k: str, past_v: str): - """Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H) - - Args: - past_k (str): name of past K value of shape (B,N,P,H) - past_v (str): name of past V value of shape (B,N,P,H) - - Returns: - past_k_transpose (str): name of past K value of shape (B,P,N,H) - past_v_transpose (str): name of past V value of shape (B,P,N,H) - """ - past_k_transpose = (past_k + "_transposed").replace(".", "_") - past_v_transpose = (past_v + "_transposed").replace(".", "_") - transpose_k_name = self.model.create_node_name("Transpose") - transpose_v_name = self.model.create_node_name("Transpose") - - transpose_k = helper.make_node( - "Transpose", - inputs=[past_k], - outputs=[past_k_transpose], - name=transpose_k_name, - perm=[0, 2, 1, 3], - ) - transpose_v = helper.make_node( - "Transpose", - inputs=[past_v], - outputs=[past_v_transpose], - name=transpose_v_name, - perm=[0, 2, 1, 3], - ) - - # Add reshape nodes to graph - self.nodes_to_add.append(transpose_k) - self.nodes_to_add.append(transpose_v) - self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name - self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name - - return past_k_transpose, past_v_transpose - def create_combined_qkv_bias( self, q_add: NodeProto, diff --git a/onnxruntime/python/tools/transformers/fusion_fastgelu.py b/onnxruntime/python/tools/transformers/fusion_fastgelu.py index a9f46585faad7..e2bb8027c8608 100644 --- a/onnxruntime/python/tools/transformers/fusion_fastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_fastgelu.py @@ -26,6 +26,9 @@ def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict): if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node): return + if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node): + return + def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]: """ Fuse Gelu with tanh into one node: @@ -358,3 +361,122 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name return True + + def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + """ + This pattern is from stable diffusion 3.5 model. + Fuse Gelu with tanh into one node: + +-----------------+------------------+ + | | | + | v v + [root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul --> + | (A=0.0447) (A=0.7978) (A=1) ^ (A=0.5) + | | + +-------------------------------------------------------------------------+ + Note that constant input for Add and Mul could be first or second input. + """ + if tanh_node.output[0] not in input_name_to_nodes: + return + + children = input_name_to_nodes[tanh_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return + add_after_tanh = children[0] + + if not self.model.has_constant_input(add_after_tanh, 1.0): + return + + if add_after_tanh.output[0] not in input_name_to_nodes: + return + children = input_name_to_nodes[add_after_tanh.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return + mul_after_tanh = children[0] + + if mul_after_tanh.output[0] not in input_name_to_nodes: + return + children = input_name_to_nodes[mul_after_tanh.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return + mul_half = children[0] + if not self.model.has_constant_input(mul_half, 0.5): + return + + root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1] + + mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node) + if mul_before_tanh is None: + return + + i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001) + if i < 0: + return + + add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node) + if add_before_tanh is None: + return + + if add_before_tanh.input[0] == root_input: + another = 1 + elif add_before_tanh.input[1] == root_input: + another = 0 + else: + return + + mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node) + if mul_after_pow is None: + return + + i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001) + if i < 0: + return + + mul = self.model.match_parent(mul_after_pow, "Mul", 0 if i == 1 else 1, output_name_to_node) + if mul is None: + return + + if mul.input[0] == root_input: + another = 1 + elif mul.input[1] == root_input: + another = 0 + else: + return + + mul2 = self.model.match_parent(mul, "Mul", another, output_name_to_node) + if mul2 is None: + return + + if mul2.input[0] != root_input or mul2.input[1] != root_input: + return + + subgraph_nodes = [ + mul2, + mul, + mul_after_pow, + add_before_tanh, + mul_before_tanh, + tanh_node, + add_after_tanh, + mul_after_tanh, + mul_half, + ] + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [mul_half.output[0]], + input_name_to_nodes, + output_name_to_node, + ): + return + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = helper.make_node( + "FastGelu", + inputs=[root_input], + outputs=mul_half.output, + name=self.model.create_node_name("FastGelu"), + ) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + self.node_name_to_graph_name[fused_node.name] = self.this_graph_name + return True diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index c718d2c27e015..c9bf52234d696 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -84,6 +84,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): instance_norm_scale = self.model.get_constant_value(instance_norm.input[1]) if instance_norm_scale is None or len(instance_norm_scale.shape) != 1: return + num_groups = int(instance_norm_scale.shape[0]) instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape: @@ -156,7 +157,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): ) new_node.attribute.extend(instance_norm.attribute) - new_node.attribute.extend([helper.make_attribute("groups", 32)]) + + new_node.attribute.extend([helper.make_attribute("groups", num_groups)]) new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)]) if not self.channels_last: diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index aac05a7f01325..277bd0799cf16 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -13,8 +13,10 @@ class FusionLayerNormalization(Fusion): - def __init__(self, model: OnnxModel): + def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True, force: bool = False): super().__init__(model, "LayerNormalization", "ReduceMean") + self.check_constant_and_dimension = check_constant_and_dimension + self.force = force def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ @@ -23,9 +25,9 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): | | | v [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add - (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ - | | - +-----------------------------------------------+ + (axis=2 or -1) | (Y=2) (axis=2 or -1) (B=E-6 or E-12) ^ + | | + +-------------------------------------------------+ It also handles cases of duplicated sub nodes exported from older version of PyTorch: +----------------------+ @@ -56,18 +58,20 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): for child in children: # Check if Sub --> Div exists div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) - - # Check if Sub --> Cast --> Div - div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[]) - if div_node_1 is not None: div_node = div_node_1 - elif div_node_2 is not None: - div_node = div_node_2[-1] + break + else: + # Check if Sub --> Cast --> Div + div_node_2 = self.model.match_child_path(child, ["Cast", "Div"]) + if div_node_2 is not None: + div_node = div_node_2[-1] + break + if div_node is None: return - path_id, parent_nodes, _ = self.model.match_parent_paths( + _path_id, parent_nodes, _ = self.model.match_parent_paths( div_node, [ (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), @@ -75,72 +79,93 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): ], output_name_to_node, ) - if path_id < 0: + if parent_nodes is None: return sub_node = parent_nodes[-1] if sub_node not in children: return - second_add_node = parent_nodes[1] - i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {add_weight}") + add_eps_node = parent_nodes[1] + i, epsilon = self.model.get_constant_input(add_eps_node) + if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4: + logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {epsilon}") return pow_node = parent_nodes[3] if self.model.find_constant_input(pow_node, 2.0) != 1: return - temp_node = input_name_to_nodes[div_node.output[0]][0] - if temp_node.op_type == "Cast": - # Div --> Cast --> Mul - subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes - mul_node = input_name_to_nodes[temp_node.output[0]][0] - else: - # Div --> Mul - mul_node = temp_node - if mul_node.op_type != "Mul": - return - - last_add_node = input_name_to_nodes[mul_node.output[0]][0] - if last_add_node.op_type != "Add": - return - - subgraph_nodes.append(node) - subgraph_nodes.extend(children) - subgraph_nodes.extend(parent_nodes[:-1]) - - subgraph_nodes.extend([last_add_node, mul_node, div_node]) - if not self.model.is_safe_to_fuse_nodes( - subgraph_nodes, - last_add_node.output, - input_name_to_nodes, - output_name_to_node, - ): - logger.debug("It is not safe to fuse LayerNormalization node. Skip") - return - - node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node - weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] - if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"): - return - - bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] - if not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"): - return - - self.nodes_to_remove.extend(subgraph_nodes) - - normalize_node = helper.make_node( - "LayerNormalization", - inputs=[node.input[0], weight_input, bias_input], - outputs=[last_add_node.output[0]], - name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"), - ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) - self.nodes_to_add.append(normalize_node) - self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + if div_node.output[0] not in input_name_to_nodes: + return + + # In MMDit model, Div might have two Mul+Add children paths. + div_children = input_name_to_nodes[div_node.output[0]] + for temp_node in div_children: + if temp_node.op_type == "Cast": + # Div --> Cast --> Mul + subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes + if temp_node.output[0] not in input_name_to_nodes: + continue + mul_node = input_name_to_nodes[temp_node.output[0]][0] + else: + # Div --> Mul + mul_node = temp_node + if mul_node.op_type != "Mul": + continue + + if mul_node.output[0] not in input_name_to_nodes: + continue + last_add_node = input_name_to_nodes[mul_node.output[0]][0] + if last_add_node.op_type != "Add": + continue + + subgraph_nodes.append(node) + subgraph_nodes.extend(children) + subgraph_nodes.extend(parent_nodes[:-1]) + + subgraph_nodes.extend([last_add_node, mul_node, div_node]) + + node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node + weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( + weight_input, 1, "layernorm weight" + ): + continue + + bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] + if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension( + bias_input, 1, "layernorm bias" + ): + continue + + layer_norm_output = last_add_node.output[0] + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + last_add_node.output, + input_name_to_nodes, + output_name_to_node, + ): + # If it is not safe to fuse, somce computation may be duplicated if we force to fuse it. + # It it unknown that force fusion might bring performance gain/loss. + # User need test performance impact to see whether forcing fusion can help. + if self.force: + self.prune_graph = True + else: + logger.debug("It is not safe to fuse LayerNormalization node. Skip") + continue + else: + self.nodes_to_remove.extend(subgraph_nodes) + + normalize_node = helper.make_node( + "LayerNormalization", + inputs=[node.input[0], weight_input, bias_input], + outputs=[layer_norm_output], + name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"), + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name class FusionLayerNormalizationNCHW(Fusion): @@ -218,9 +243,9 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if sub != sub_node: return - i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {add_weight}") + i, epsilon = self.model.get_constant_input(second_add_node) + if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4: + logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {epsilon}") return axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes") @@ -286,7 +311,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): outputs=[layernorm_node_name + "_out_nhwc"], name=layernorm_node_name, ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) self.nodes_to_add.append(transpose_input) self.nodes_to_add.append(normalize_node) diff --git a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py new file mode 100644 index 0000000000000..dcad55c13eb49 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py @@ -0,0 +1,668 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Dict, Optional + +import numpy as np +from fusion_base import Fusion +from fusion_utils import FusionUtils +from onnx import NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionMultiHeadAttentionMMDit(Fusion): + """ + Fuse MultiHeadAttention for Multimodal Diffusion Transformer (MMDiT). + """ + + def __init__(self, model: OnnxModel): + super().__init__(model, fused_op_type="MultiHeadAttention", search_op_types=["Softmax"]) + self.unsqueeze_update_map = {} + + def get_num_heads(self, start_node: NodeProto, output_name_to_node, input_index=0) -> int: + """ + Detect num_heads from Reshape & Transpose of q/k/v for both Stable Diffusion 3.x and Flux 1.x: + + MatMul .. [-1] [24] .. + | | | / / + Add Concat(axis=0) + | / + Reshape + | + Transpose(perm=0,1,3,2) + | + (start_node) + """ + nodes = self.model.match_parent_path( + start_node, ["Transpose", "Reshape", "Concat"], [input_index, 0, 1], output_name_to_node=output_name_to_node + ) + if nodes is None: + return 0 + + concat_shape = nodes[-1] + if len(concat_shape.input) != 4: + return 0 + + value = self.model.get_constant_value(concat_shape.input[2]) + if value is None: + return 0 + + if len(value.shape) != 1: + return 0 + + return int(value[0]) + + def get_num_heads_from_k(self, transpose_k: NodeProto, output_name_to_node, concat_before_transpose: bool) -> int: + """ + Detect num_heads from subgraph like the following (num_heads=24 in this example): + MatMu .. [-1] [24] .. + | | | / / + Add Concat + | / + Reshape + | + Transpose(perm=0,2,1,3) + | + SimplifiedLayerNormalization + | + Transpose(perm=0,1,3,2) + + Another variant is to an extra Concat node to join two symmetrical subgraphs: + + | | + MatMul MatMul .. [-1] [24] .. + | | | | / / + Add Concat Add Concat + | / | / + Reshape Reshape + | | + Transpose Transpose(perm=0,2,1,3) + | | + SimplifiedLayerNormalization SimplifiedLayerNormalization + | / + Concat + | + Transpose(perm=0,1,3,2) + + Both patterns are used in stable diffusion 3.5 model. + """ + if concat_before_transpose: + nodes = self.model.match_parent_path( + transpose_k, ["Concat", "SimplifiedLayerNormalization"], [0, 1], output_name_to_node=output_name_to_node + ) + if nodes: + return self.get_num_heads(nodes[1], output_name_to_node) + else: + nodes = self.model.match_parent_path( + transpose_k, ["SimplifiedLayerNormalization"], [0], output_name_to_node=output_name_to_node + ) + if nodes: + return self.get_num_heads(nodes[0], output_name_to_node) + + return 0 + + def reshape_to_3d(self, input_name: str, output_name: str) -> str: + """Add a Reshape node to convert 4D BxSxNxH to 3D BxSxD. + + Args: + input_name (str): input name for the 4D tensor of shape BxSxNxH. + output_name (str): output name for the 3D tensor of shape BxSxD, where D = N * H. + + Returns: + str: the output name + """ + + new_dims_name = "bsnh_to_bsd_reshape_dims" + new_dims = self.model.get_initializer(new_dims_name) + if new_dims is None: + new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name) + self.model.add_initializer(new_dims, self.this_graph_name) + reshape_q = helper.make_node( + "Reshape", + inputs=[input_name, new_dims_name], + outputs=[output_name], + name=self.model.create_node_name("Reshape"), + ) + self.nodes_to_add.append(reshape_q) + self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name + return reshape_q.output[0] + + def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format. + + Before: + MatMul + | + Add Concat + | / + Reshape + | + Transpose(perm=0,2,1,3) + | + SimplifiedLayerNorm + | + Mul + + After: + MatMul + | + Add Concat + | / + Reshape + | + SimplifiedLayerNorm + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["SimplifiedLayerNormalization", "Transpose"], + [0, 0], + ) + if path is None: + return None + sln_a, transpose_a = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + sln_output = sln_a.output[0] + sln_a.output[0] = sln_output + "_BSNH" + + return self.reshape_to_3d(sln_a.output[0], sln_output + "_BSD") + + def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format. + + Before: + MatMul MatMul + | | + Add Concat Add Concat + | / | / + Reshape Reshape + | | + Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3) + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=2) + | + Mul + + After: + MatMul MatMul + | | + Add Concat Add Concat + | / | / + Reshape Reshape + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=1) + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["Concat", "SimplifiedLayerNormalization", "Transpose"], + [0, 0, 0], + ) + if path is None: + return None + concat, sln_a, transpose_a = path + + if len(concat.input) != 2: + return None + + path = self.model.match_parent_path( + concat, + ["SimplifiedLayerNormalization", "Transpose"], + [1, 0], + ) + if path is None: + return None + sln_b, transpose_b = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]): + return None + + if not FusionUtils.check_node_attribute(concat, "axis", 2): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + sln_b.input[0] = transpose_b.input[0] + + new_concat_node = helper.make_node( + "Concat", + inputs=[sln_a.output[0], sln_b.output[0]], + outputs=[concat.output[0] + "_BSNH"], + name=self.model.create_node_name("Concat"), + axis=1, + ) + self.nodes_to_add.append(new_concat_node) + self.node_name_to_graph_name[new_concat_node.name] = self.this_graph_name + + return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD") + + def update_unsqueeze_axes_1_to_2(self, unsqueeze: NodeProto) -> str: + updated_unsqueeze_output = self.unsqueeze_update_map.get(unsqueeze.name) + if updated_unsqueeze_output is None: + if len(unsqueeze.input) == 1: + new_node = helper.make_node( + "Unsqueeze", + inputs=unsqueeze.input, + outputs=[unsqueeze.output[0] + "_BSNH"], + name=self.model.create_node_name("Unsqueeze"), + axes=[2], + ) + else: + initializer_name = "unsqueeze_axes_2" + if self.model.get_initializer(initializer_name) is None: + unsqueeze_axes_2 = helper.make_tensor( + name=initializer_name, + data_type=TensorProto.INT64, + dims=[1], # Shape of the tensor + vals=[2], # Tensor values + ) + self.model.add_initializer(unsqueeze_axes_2, self.this_graph_name) + + new_node = helper.make_node( + "Unsqueeze", + inputs=[unsqueeze.input[0], initializer_name], + outputs=[unsqueeze.output[0] + "_BSNH"], + name=self.model.create_node_name("Unsqueeze"), + ) + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + updated_unsqueeze_output = new_node.output[0] + self.unsqueeze_update_map[unsqueeze.name] = updated_unsqueeze_output + + return updated_unsqueeze_output + + def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: Dict[str, NodeProto]) -> bool: + """ + Update axes of Unsqueeze from [1] to [2] in the following pattern: + Unsqueeze Unsqueeze + (axes=[0]) (axes=[0]) + | | + Unsqueeze Unsqueeze + ... (axes=[1]) ... (axes=[1]) + | / | / + Mul Mul + | / + Add + Args: + add (NodeProto): the Add node + output_name_to_node (Dict[str, NodeProto]): mapping from output name to node + + Returns: + bool: True if the pattern is matched and updated successfully, False otherwise. + """ + if len(add.input) != 2: + return False + + # Check axes of Unsqueeze nodes are [0] and [1], and change to [0] and [2] respectively. + nodes_b = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [1, 1, 0], output_name_to_node) + if nodes_b is None: + return False + + fusion_utils = FusionUtils(self.model) + axes_1 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_b[1]) + if axes_1 is None or axes_1 != [1]: + return False + + axes_0 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_b[2]) + if axes_0 is None or axes_0 != [0]: + return False + + # Check axes of Unsqueeze nodes are [0] and [1], and change to [0] and [2] respectively. + nodes_a = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [0, 1, 0], output_name_to_node) + if nodes_a is None: + return False + + axes_1 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_a[1]) + if axes_1 is None or axes_1 != [1]: + return False + + axes_0 = fusion_utils.get_squeeze_or_unsqueeze_axes(nodes_a[2]) + if axes_0 is None or axes_0 != [0]: + return False + + nodes_a[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_a[1]) + nodes_b[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_b[1]) + return True + + def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + Adjust graph to change query format from BNSH to BSD for Flux model. + Note that the graph pattern is complex, and we only do a shallow match here. + + Before: + | | + Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3) + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=2) + | + Mul Mul + | / + Add + | + Mul + + After (Transpose nods are removed, and a Reshape is added): + + | | + SimplifiedLayerNorm SimplifiedLayerNorm + | / + Concat(axis=1) + | + Mul Mul + | / + Add + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["Add", "Mul", "Concat", "SimplifiedLayerNormalization", "Transpose"], + [0, 0, 0, 0, 0], + ) + if path is None: + return None + add, _mul_a, concat, sln_a, transpose_a = path + + if len(concat.input) != 2: + return None + + path = self.model.match_parent_path( + concat, + ["SimplifiedLayerNormalization", "Transpose"], + [1, 0], + ) + if path is None: + return None + sln_b, transpose_b = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]): + return None + + if not FusionUtils.check_node_attribute(concat, "axis", 2): + return None + + # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH. + if not self.update_unsqueeze_axes(add, output_name_to_node): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + sln_b.input[0] = transpose_b.input[0] + + new_concat_node = helper.make_node( + "Concat", + inputs=[sln_a.output[0], sln_b.output[0]], + outputs=[concat.output[0] + "_BSNH"], + name=self.model.create_node_name("Concat"), + axis=1, + ) + self.nodes_to_add.append(new_concat_node) + self.node_name_to_graph_name[new_concat_node.name] = self.this_graph_name + self.model.replace_input_of_all_nodes(concat.output[0], new_concat_node.output[0]) + + return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD") + + def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + """ + Adjust graph to change query format from BNSH to BSD for Flux model. + Note that the graph pattern is complex, and we only do a shallow match here. + + Before: + | + Transpose(perm=0,2,1,3) + | + SimplifiedLayerNorm + | + Mul Mul + | / + Add + | + Mul + + After (Transpose is removed, and a Reshape is added): + + | + SimplifiedLayerNorm + | + Mul Mul + | / + Add + | + Reshape (shape=[0, 0, -1]) + """ + + path = self.model.match_parent_path( + mul_q, + ["Add", "Mul", "SimplifiedLayerNormalization", "Transpose"], + [0, 0, 0, 0], + ) + if path is None: + return None + add, _mul_a, sln_a, transpose_a = path + + if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]): + return None + + # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH. + if not self.update_unsqueeze_axes(add, output_name_to_node): + return None + + # Update the graph + sln_a.input[0] = transpose_a.input[0] + add.output[0] = add.output[0] + "_BSNH" + + return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD") + + def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> Optional[str]: + transpose_q = helper.make_node( + "Transpose", + [q], + [q + "_BSNH"], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_BNSH_to_BSNH"), + perm=[0, 2, 1, 3], + ) + self.nodes_to_add.append(transpose_q) + self.node_name_to_graph_name[transpose_q.name] = self.this_graph_name + + return self.reshape_to_3d(q + "_BSNH", q + "_BSD") + + def create_multihead_attention_node( + self, + q: str, + k: str, + v: str, + output: str, + num_heads: int, + ) -> NodeProto: + """ + Create a MultiHeadAttention node. + + Args: + q (str): name of q + k (str): name of k + v (str): name of v + output (str): output name of MHA + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + + Returns: + NodeProto: the node created. + """ + + assert num_heads > 0 + + # Add inputs for MHA: Query, Key, Value (Proj_Bias, Mask, Attention_Bias, Past_K, Past_V are optional) + mha_inputs = [q, k, v] + + # Add outputs for MHA (Present_K, Present_V are optional) + mha_outputs = [output] + + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=mha_outputs, + name=self.model.create_node_name("MultiHeadAttention"), + ) + + mha_node.domain = "com.microsoft" + mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + # No mask is used in MMDit model, so we need not set the optional mask_filter_value attribute. + return mha_node + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + assert node.op_type == "Softmax" + softmax = node + + # Softmax output shall not be graph output. + if self.model.find_graph_output(softmax.output[0]): + return + + nodes = self.model.match_child_path( + softmax, ["MatMul", "Transpose", "Reshape"], [(0, 0), (0, 0), (0, 0)], input_name_to_nodes + ) + if nodes is None: + return + + matmul_s_v, transpose_out, reshape_out = nodes + if not FusionUtils.check_node_attribute(transpose_out, "perm", [0, 2, 1, 3]): + return + + q_nodes = self.model.match_parent_path( + softmax, + ["MatMul", "Mul", "Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape"], + [0, 0, 1, 0, 1, 0, 0, 0], + ) + + if q_nodes is None: + return + + matmul_qk, mul_q, sqrt_q_2, div_q, sqrt_q, _, _, shape_q = q_nodes + + q_bnsh = mul_q.input[0] + if q_bnsh != shape_q.input[0]: + return + + k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose"], [1, 0]) + if k_nodes is None: + return + + mul_k, transpose_k = k_nodes + k = transpose_k.input[0] + if not FusionUtils.check_node_attribute(transpose_k, "perm", [0, 1, 3, 2]): + return + + k_scale_nodes = self.model.match_parent_path(mul_k, ["Sqrt", "Div"], [1, 0]) + if k_scale_nodes is None: + return + if k_scale_nodes[0].input[0] != sqrt_q_2.input[0]: + return + + v = matmul_s_v.input[1] + + # Here we sanity check the v path to make sure it is in the expected BNSH format. + concat_v = self.model.match_parent(matmul_s_v, "Concat", input_index=1, output_name_to_node=output_name_to_node) + if concat_v is not None: + # Match v path like: + # -- Transpose (perm=[0,2,1,3]) ----+ + # | + # v + # -- Transpose (perm=[0,2,1,3]) -> Concat -> (v) + transpose_1 = self.model.match_parent( + concat_v, "Transpose", input_index=0, output_name_to_node=output_name_to_node + ) + if transpose_1 is None: + return + if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]): + return + + transpose_2 = self.model.match_parent( + concat_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node + ) + if transpose_2 is None: + return + if not FusionUtils.check_node_attribute(transpose_2, "perm", [0, 2, 1, 3]): + return + else: + # Match v path like: + # -- Transpose (perm=[0,2,1,3]) -> (v) + transpose_1 = self.model.match_parent( + matmul_s_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node + ) + if transpose_1 is None: + return + if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]): + return + + # Match patterns for Flux. + num_heads = ( + self.get_num_heads(concat_v, output_name_to_node) + if concat_v + else self.get_num_heads(matmul_s_v, output_name_to_node, input_index=1) + ) + + if num_heads == 0: + # Match patterns for Stable Diffusion 3.5. + num_heads = self.get_num_heads_from_k(transpose_k, output_name_to_node, concat_v is not None) + if num_heads <= 0: + return + + # Q is in BNSH format, we need to adjust it to BSD format due to limitation of MHA op. + # TODO: MHA op support BNSH format to reduce the effort in fusion. + if concat_v is not None: + query = self.adjust_query_from_bnsh_to_bsd(mul_q, output_name_to_node) + else: + query = self.adjust_query_from_bnsh_to_bsd_no_concat(mul_q, output_name_to_node) + + if query is None: + query = self.adjust_flux_query_from_bnsh_to_bsd(mul_q, output_name_to_node) + if query is None: + query = self.adjust_flux_single_query_from_bnsh_to_bsd(mul_q, output_name_to_node) + if query is None: + # fallback to use Transpose and Add to adjust query from BNSH to BSD + # This is more general approach. + # However, it might be slower if the extra Transpose node cannot be removed by ORT optimizer. + query = self.transpose_reshape_bnsh_to_bsd(q_bnsh, output_name_to_node) + + new_node = self.create_multihead_attention_node( + q=query, + k=k, + v=v, + output=reshape_out.output[0], + num_heads=num_heads, + ) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([matmul_s_v, transpose_out, reshape_out]) + + # Use prune graph to remove nodes + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py index a872b8c2075bc..ca7ff6462b9ff 100644 --- a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -18,134 +18,113 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): return sim_ln_nodes = None - # SimplifiedLayerNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary): - # DD = Pow(D, 2) - # Var = ReduceMean(DD) - # VarEps = Add(Var, epsilon) - # StdDev = Sqrt(VarEps) - # InvStdDev = Div(1, StdDev) - # Normalized = Mul(D, InvStdDev) - # NormalizedScaled = Mul(Normalized, Scale) - - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_1 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], - [1, 1, 1, 0, 0, 0, 0], - ) - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_2 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], - [1, 1, 1, 0, 0, 0, 0], - ) - - # For LLaMA from Microsoft custom export: - # sim_ln_nodes_3 uses a different start parent index than sim_ln_nodes_1 + # RMSNorm formula: + # S = Pow(X, 2) or S = Mul(X, X) + # MS = ReduceMean(S) + # MSEps = Add(MS, epsilon) + # RMS = Sqrt(MSEps) + # InvRMS = Div(1, RMS) or InvRMS = Reciprocal(RMS) + # Normalized = Mul(D, InvRMS) + # Y = Mul(Normalized, Scale) # - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_3 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], - [0, 1, 1, 0, 0, 0, 0], - ) - - # sim_ln_nodes_4 starts with a graph input instead of an Add node like sim_ln_nodes_3 + # (root_input) ----------------------------------------+ + # | | + # v v + # Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A=1) (A/B=scale) # - # SimplifiedLayerNorm - # +-----------------------------------------------+ - # | | - # graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul - # | - # node - sim_ln_nodes_4 = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"], - [0, 1, 1, 0, 0, 0], - ) - - # For Gemma from Microsoft custom export, which has a Multiply after the Gather: + # (root_input) ----------------------------------------+ + # | | | + # v v v + # Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A=1) (A/B=scale) # - # SimplifiedLayerNorm - # +-------------------------------------------------------+ - # | | - # Mul --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul - # | - # node - sim_ln_nodes_5 = self.model.match_parent_path( + return_indice = [] + sim_ln_nodes = self.model.match_parent_path( node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Mul"], - [1, 1, 1, 0, 0, 0, 0], + ["Mul", "Div", "Sqrt", "Add", "ReduceMean"], + [None, 1, 1, 0, None], + output_name_to_node=output_name_to_node, + return_indice=return_indice, ) - add_node, pow_node = None, None - if sim_ln_nodes_1 is not None: - sim_ln_nodes = sim_ln_nodes_1 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] - elif sim_ln_nodes_2 is not None: - sim_ln_nodes = sim_ln_nodes_2 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] - elif sim_ln_nodes_3 is not None: - sim_ln_nodes = sim_ln_nodes_3 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] - elif sim_ln_nodes_4 is not None: - sim_ln_nodes = sim_ln_nodes_4 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-1] - # Verify that parent input to Pow node is graph_input - if pow_node.input[0] not in self.model.get_graphs_input_names(): + if sim_ln_nodes: + mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes + if not self.model.has_constant_input(div_node, 1.0): return - elif sim_ln_nodes_5 is not None: - sim_ln_nodes = sim_ln_nodes_5 - add_node = sim_ln_nodes[3] - pow_node = sim_ln_nodes[-2] else: + # Div(1, RMS) can also be represented as Reciprocal(RMS) like + # + # (root_input) -----------------------------------------------+ + # | | + # v v + # Pow --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A/B=scale) + # + # (root_input) -----------------------------------------------+ + # | | | + # v v v + # Mul --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node) + # (B=2) (A/B=eps) (A/B=scale) + # + sim_ln_nodes = self.model.match_parent_path( + node, + ["Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean"], + [None, 1, 0, 0, None], + output_name_to_node=output_name_to_node, + return_indice=return_indice, + ) + if sim_ln_nodes is None: + return + mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes + + pow_or_mul_node = self.model.get_parent(reduce_mean_node, 0, output_name_to_node) + if pow_or_mul_node is None or pow_or_mul_node.op_type not in ["Pow", "Mul"]: return - layernorm_weight_index = 1 if sim_ln_nodes in (sim_ln_nodes_3, sim_ln_nodes_4) else 0 - starts_with_graph_input = sim_ln_nodes == sim_ln_nodes_4 + if pow_or_mul_node.op_type == "Pow": + if self.model.find_constant_input(pow_or_mul_node, 2.0) != 1: + return + else: + assert pow_or_mul_node.op_type == "Mul" + if pow_or_mul_node[0] != pow_or_mul_node[1]: + return + + root_input = pow_or_mul_node.input[0] + if root_input != mul_node.input[0]: + return - if self.model.find_constant_input(pow_node, 2.0) != 1: + _i, epsilon = self.model.get_constant_input(add_node) + if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4: + logger.warning(f"epsilon value is not expected: {epsilon}") return - root_input = pow_node.input[0] - if root_input != sim_ln_nodes[0].input[0]: + # ReduceMean must have keepdims == 1 + keepdims = self.model.get_node_attribute(reduce_mean_node, "keepdims") + if not keepdims: return - i, add_weight = self.model.get_constant_input(add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.warning(f"epsilon value is not expected: {add_weight}") + # ReduceMean axes must refer only to the last dimension. + # Axes became an input in opset 18. Before then, axes was an attribute. + axes = self.model.get_node_attribute(reduce_mean_node, "axes") + if (not axes) and len(reduce_mean_node.input) > 1: + axes = self.model.get_constant_value(reduce_mean_node.input[1]) + # Make sure only one axis as required by SimplifiedLayerNormalization spec. + if not axes or len(axes) != 1: return - self.nodes_to_remove.extend(sim_ln_nodes[:-1] if not starts_with_graph_input else sim_ln_nodes) + self.nodes_to_remove.extend(sim_ln_nodes) + self.nodes_to_remove.append(pow_or_mul_node) self.nodes_to_remove.append(node) normalize_node = helper.make_node( "SimplifiedLayerNormalization", - inputs=[root_input, node.input[layernorm_weight_index]], + inputs=[root_input, node.input[1 - return_indice[0]]], outputs=[node.output[0]], - name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), + name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="RMSNorm"), ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) - normalize_node.attribute.extend([helper.make_attribute("axis", -1)]) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) + normalize_node.attribute.extend([helper.make_attribute("axis", axes[0])]) normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)]) self.nodes_to_add.append(normalize_node) self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index dbd9e828198ca..3084b84278994 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -127,6 +127,19 @@ def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_i return parent_can_be_removed + def get_squeeze_or_unsqueeze_axes(self, node: NodeProto) -> Optional[ndarray]: + assert node.op_type in ["Squeeze", "Unsqueeze"] + + # For opset >= 13, axes is an input instead of an attribute. + if len(node.input) > 1: + return self.model.get_constant_value(node.input[1]) + + axes = None + for attr in node.attribute: + if attr.name == "axes": + axes = helper.get_attribute_value(attr) + return axes + @staticmethod def check_node_attribute(node, attribute_name: str, expected_value, default_value=None): """Verify that a node has expected value for an attribute. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index edef0d3ee5453..dc83f4dc220f0 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -203,35 +203,60 @@ This step will export stable diffusion 1.5 to ONNX model in float32 using script ``` curl https://raw.githubusercontent.com/huggingface/diffusers/v0.15.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py > convert_sd_onnx.py -python convert_sd_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./sd_v1_5/fp32 +python convert_sd_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./sd1.5_onnx/fp32 ``` For SDXL, use optimum to export the model: ``` pip install optimum diffusers onnx onnxruntime-gpu -optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx +optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sdxl_onnx/fp32 +``` + +#### Stable Diffusion 3.x and Flux 1.0 + +Stable Diffusion 3.x and Flux 1.0 requires transformers >= 4.45, and optimum > 1.23.3. +The default opset version for T5 is 12, which does not support bfloat16. To support bfloat16, please set opset version explicitly like below example. + +``` +git clone https://github.com/huggingface/optimum +cd optimum +pip install -e . + +optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers ./sd3_onnx/fp32 --opset 15 +optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-medium ./sd3.5_medium_onnx/fp32 --opset 15 +optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-large ./sd3.5_large_onnx/fp32 --opset 15 +optimum-cli export onnx --model black-forest-labs/FLUX.1-schnell ./flux1_schnell_onnx/fp32 --opset 15 +optimum-cli export onnx --model black-forest-labs/FLUX.1-dev ./flux1_dev_onnx/fp32 --opset 15 ``` ### Optimize ONNX Pipeline -Example to optimize the exported float32 ONNX models, and save to float16 models: +Example to optimize the exported float32 ONNX models, then save to float16 models: ``` -python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16 +python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd1.5_onnx/fp32 -o ./sd1.5_onnx/fp16 --float16 ``` -In all examples below, we run the scripts in source code directory. You can get source code like the following: +You can also run the script in source code directory like the following: ``` git clone https://github.com/microsoft/onnxruntime cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion + +python optimize_pipeline.py -i ./sdxl_onnx/fp32 -o ./sdxl_onnx/fp16 --float16 +python optimize_pipeline.py -i ./sd3_onnx/fp32 -o ./sd3_onnx/fp16 --float16 +python optimize_pipeline.py -i ./sd3.5_medium_onnx/fp32 -o ./sd3.5_medium_onnx/fp16 --float16 +python optimize_pipeline.py -i ./sd3.5_large_onnx/fp32 -o ./sd3.5_large_onnx/fp16 --float16 +python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16 --bfloat16 +python optimize_pipeline.py -i ./flux1_dev_onnx/fp32 -o ./flux1_dev_onnx/fp16 --float16 --bfloat16 ``` +When converting model to float16, some nodes has overflow risk and we can force those nodes to run in either float32 or bfloat16. +Option `--bfloat16` enables the later. If an operator does not support bfloat16, it will fallback to float32. For SDXL model, it is recommended to use a machine with 48 GB or more memory to optimize. -``` -python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16 -``` ### Run Benchmark +#### Run Benchmark with Optimum + The benchmark.py script will run a warm-up prompt twice, and measure the peak GPU memory usage in these two runs, then record them as first_run_memory_MB and second_run_memory_MB. Then it will run 5 runs to get average latency (in seconds), and output the results to benchmark_result.csv. Note that the first run might need more time and memory: For example, cuDNN convolution algorithm search or model compile happens in the first run. @@ -245,15 +270,15 @@ Before running benchmark on PyTorch, you need to be logged in via `huggingface-c Example to benchmark the optimized pipeline of stable diffusion 1.5 with batch size 1 on CUDA EP: ``` -python benchmark.py -p ./sd_v1_5/fp16 -b 1 -v 1.5 +python benchmark.py -p ./sd1.5_onnx/fp16 -b 1 -v 1.5 python benchmark.py -b 1 -v 1.5 ``` For the first command, '-p' specifies a directory of optimized ONNX pipeline as generated by optimize_pipeline.py. -For the second command without '-p', we will use OnnxruntimeCudaStableDiffusionPipeline to export and optimize ONNX models for clip, unet and vae decoder. +For the second command without '-p', we will use ORTPipelineForText2Image to export and optimize ONNX models for clip, unet and vae decoder. On ROCm EP, use the following command instead: ``` -python benchmark.py -p ./sd_v1_5/fp16 -b 1 --tuning --provider rocm -v 1.5 +python benchmark.py -p ./sd1.5_onnx/fp16 -b 1 --tuning --provider rocm -v 1.5 ``` For ROCm EP, you can substitute `python benchmark.py` with `python -m onnxruntime.transformers.models.stable_diffusion.benchmark` since @@ -263,6 +288,22 @@ For ROCm EP, the `--tuning` is mandatory because we heavily rely on tuning to fi The default parameters are stable diffusion version=1.5, height=512, width=512, steps=50, batch_count=5. Run `python benchmark.py --help` for more information. +#### Stable Diffusion 3.x and Flux 1.0 +Example of benchmark with optimum using CUDA provider on stable diffusion 3.5 medium and Flux 1.0: +``` +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v 3.0M -p sd3_onnx/fp32 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v 3.5M -p sd3.5_medium_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v 3.5L -p sd3.5_large_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp16 +python benchmark.py -e optimum --height 1024 --width 1024 --steps 30 -b 1 -v Flux.1D -p flux1_dev_onnx/fp16 +``` + +Benchmark PyTorch eager mode performance: +``` +python benchmark.py -e torch --height 1024 --width 1024 --steps 30 -b 1 -v 3.5L +python benchmark.py -e torch --height 1024 --width 1024 --steps 30 -b 1 -v Flux.1D +``` + ### Run Benchmark with xFormers Run PyTorch 1.13.1+cu117 with xFormers like the following diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 0708d57f040f8..0452cff235c11 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -9,6 +9,7 @@ import statistics import sys import time +from pathlib import Path import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs @@ -22,6 +23,11 @@ "2.0": "stabilityai/stable-diffusion-2", "2.1": "stabilityai/stable-diffusion-2-1", "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0", + "3.0M": "stabilityai/stable-diffusion-3-medium-diffusers", + "3.5M": "stabilityai/stable-diffusion-3.5-medium", + "3.5L": "stabilityai/stable-diffusion-3.5-large", + "Flux.1S": "black-forest-labs/FLUX.1-schnell", + "Flux.1D": "black-forest-labs/FLUX.1-dev", } PROVIDERS = { @@ -90,6 +96,24 @@ def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_c def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_torch_compile: bool, use_xformers: bool): + if "FLUX" in model_name: + from diffusers import FluxPipeline + + pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + if enable_torch_compile: + pipe.transformer.to(memory_format=torch.channels_last) + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) + return pipe + + if "stable-diffusion-3" in model_name: + from diffusers import StableDiffusion3Pipeline + + pipe = StableDiffusion3Pipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda") + if enable_torch_compile: + pipe.transformer.to(memory_format=torch.channels_last) + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) + return pipe + from diffusers import DDIMScheduler, StableDiffusionPipeline from torch import channels_last, float16 @@ -116,9 +140,9 @@ def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_tor return pipe -def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, disable_safety_checker: bool): +def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, steps: int, disable_safety_checker: bool): short_model_name = model_name.split("/")[-1].replace("stable-diffusion-", "sd") - return f"{engine}_{short_model_name}_b{batch_size}" + ("" if disable_safety_checker else "_safe") + return f"{engine}_{short_model_name}_b{batch_size}_s{steps}" + ("" if disable_safety_checker else "_safe") def run_ort_pipeline( @@ -193,6 +217,25 @@ def warmup(): } +def get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size) -> dict: + # Flux does not support negative prompt + kwargs = ( + ( + {"negative_prompt": negative_prompt} + if use_num_images_per_prompt + else {"negative_prompt": [negative_prompt] * batch_size} + ) + if not is_flux + else {} + ) + + # Fix the random seed so that we can inspect the output quality easily. + if torch.cuda.is_available(): + kwargs["generator"] = torch.Generator(device="cuda").manual_seed(123) + + return kwargs + + def run_torch_pipeline( pipe, batch_size: int, @@ -207,16 +250,14 @@ def run_torch_pipeline( ): prompts, negative_prompt = example_prompts() - # total 2 runs of warm up, and measure GPU memory for CUDA EP + import diffusers + + is_flux = isinstance(pipe, diffusers.FluxPipeline) + def warmup(): prompt, negative = warmup_prompts() - pipe( - prompt=[prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, - negative_prompt=[negative] * batch_size, - ) + extra_kwargs = get_negative_prompt_kwargs(negative, False, is_flux, batch_size) + pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs) # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory) first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) @@ -232,13 +273,13 @@ def warmup(): break torch.cuda.synchronize() inference_start = time.time() + extra_kwargs = get_negative_prompt_kwargs(negative_prompt, False, is_flux, batch_size) images = pipe( prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, - negative_prompt=[negative_prompt] * batch_size, - generator=None, # torch.Generator + **extra_kwargs, ).images torch.cuda.synchronize() @@ -289,7 +330,7 @@ def run_ort( load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") - image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, steps, disable_safety_checker) result = run_ort_pipeline( pipe, batch_size, @@ -322,33 +363,12 @@ def get_optimum_ort_pipeline( disable_safety_checker: bool = True, use_io_binding: bool = False, ): - from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline + from optimum.onnxruntime import ORTPipelineForText2Image if directory is not None and os.path.exists(directory): - if "xl" in model_name: - pipeline = ORTStableDiffusionXLPipeline.from_pretrained( - directory, - provider=provider, - session_options=None, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. - ) - else: - pipeline = ORTStableDiffusionPipeline.from_pretrained( - directory, - provider=provider, - use_io_binding=use_io_binding, - ) - elif "xl" in model_name: - pipeline = ORTStableDiffusionXLPipeline.from_pretrained( - model_name, - export=True, - provider=provider, - session_options=None, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. - ) - pipeline.save_pretrained(directory) + pipeline = ORTPipelineForText2Image.from_pretrained(directory, provider=provider, use_io_binding=use_io_binding) else: - pipeline = ORTStableDiffusionPipeline.from_pretrained( + pipeline = ORTPipelineForText2Image.from_pretrained( model_name, export=True, provider=provider, @@ -376,31 +396,27 @@ def run_optimum_ort_pipeline( memory_monitor_type, use_num_images_per_prompt=False, ): - from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline + print("Pipeline type", type(pipe)) + from optimum.onnxruntime.modeling_diffusion import ORTFluxPipeline - assert isinstance(pipe, (ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline)) + is_flux = isinstance(pipe, ORTFluxPipeline) prompts, negative_prompt = example_prompts() def warmup(): prompt, negative = warmup_prompts() + extra_kwargs = get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux, batch_size) if use_num_images_per_prompt: pipe( prompt=prompt, height=height, width=width, num_inference_steps=steps, - negative_prompt=negative, num_images_per_prompt=batch_count, + **extra_kwargs, ) else: - pipe( - prompt=[prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, - negative_prompt=[negative] * batch_size, - ) + pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs) # Run warm up, and measure GPU memory of two runs. # The first run has algo search for cuDNN/MIOpen, so it might need more memory. @@ -409,6 +425,8 @@ def warmup(): warmup() + extra_kwargs = get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size) + latency_list = [] for i, prompt in enumerate(prompts): if i >= num_prompts: @@ -420,16 +438,12 @@ def warmup(): height=height, width=width, num_inference_steps=steps, - negative_prompt=negative_prompt, num_images_per_prompt=batch_size, + **extra_kwargs, ).images else: images = pipe( - prompt=[prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, - negative_prompt=[negative_prompt] * batch_size, + prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs ).images inference_end = time.time() latency = inference_end - inference_start @@ -478,7 +492,10 @@ def run_optimum_ort( load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") - image_filename_prefix = get_image_filename_prefix("optimum", model_name, batch_size, disable_safety_checker) + full_model_name = model_name + "_" + Path(directory).name if directory else model_name + image_filename_prefix = get_image_filename_prefix( + "optimum", full_model_name, batch_size, steps, disable_safety_checker + ) result = run_optimum_ort_pipeline( pipe, batch_size, @@ -583,7 +600,7 @@ def warmup(): warmup() - image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -722,7 +739,7 @@ def warmup(): warmup() - image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -877,7 +894,7 @@ def warmup(): warmup() model_name = pipeline_info.name() - image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -972,7 +989,7 @@ def warmup(): warmup() model_name = pipeline.pipeline_info.name() - image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, steps, disable_safety_checker) latency_list = [] prompts, negative_prompt = example_prompts() @@ -1040,7 +1057,7 @@ def run_torch( load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") - image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) + image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, steps, disable_safety_checker) if not enable_torch_compile: with torch.inference_mode(): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh new file mode 100644 index 0000000000000..2c7785eb8f62f --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_flux.sh @@ -0,0 +1,126 @@ +#!/bin/bash +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- + +set -euo pipefail + +# Script to benchmark Flux models with ONNX and PyTorch +# Usage: bash benchmark_flux.sh + +# Validate inputs and environment +command -v python3 &>/dev/null || { echo "Python3 is required but not installed."; exit 1; } +command -v wget &>/dev/null || { echo "wget is required but not installed."; exit 1; } + +# Input arguments with defaults +install_dir="${1:-$HOME}" +onnx_dir="${2:-onnx_models}" + +# GPU settings +export CUDA_VISIBLE_DEVICES=0 + +# Function to log messages +log() { + echo -e "\033[1;32m[INFO]\033[0m $1" +} + +# Function to install CUDA 12.6 +install_cuda_12() { + log "Installing CUDA 12.6" + pushd "$install_dir" + wget -q https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath="$install_dir/cuda12.6" --silent --override --no-man-page + export PATH="$install_dir/cuda12.6/bin:$PATH" + export LD_LIBRARY_PATH="$install_dir/cuda12.6/lib64:$LD_LIBRARY_PATH" + popd +} + +# Function to install cuDNN 9.6 +install_cudnn_9() { + log "Installing cuDNN 9.6" + pushd "$install_dir" + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz + mkdir -p "$install_dir/cudnn9.6" + tar -Jxvf cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz -C "$install_dir/cudnn9.6" --strip=1 + export LD_LIBRARY_PATH="$install_dir/cudnn9.6/lib:$LD_LIBRARY_PATH" + popd +} + +# Function to install optimum +install_optimum() { + log "Installing Optimum" + optimum_dir="$install_dir/optimum" + if [ ! -d "$optimum_dir" ]; then + git clone https://github.com/huggingface/optimum "$optimum_dir" + fi + pushd "$optimum_dir" + pip show optimum &>/dev/null || pip install -e . + popd +} + +# Function to build and install ONNX Runtime +install_onnxruntime() { + log "Building ONNX Runtime" + pushd "$install_dir" + if [ ! -d onnxruntime ]; then + git clone https://github.com/microsoft/onnxruntime + fi + pushd onnxruntime + pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==2.2 + sh build.sh --config Release --build_dir build/cuda12 --parallel \ + --use_cuda --cuda_version 12.6 --cuda_home "$install_dir/cuda12.6" \ + --cudnn_home "$install_dir/cudnn9.6" \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --compile_no_warning_as_error \ + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF CMAKE_CUDA_ARCHITECTURES=native + + log "Installing ONNX Runtime" + pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl + popd + popd +} + +# Function to install GPU dependencies +install_gpu() { + log "Installing GPU dependencies" + [ ! -d "$install_dir/cuda12.6" ] && install_cuda_12 + [ ! -d "$install_dir/cudnn9.6" ] && install_cudnn_9 + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 + pip install diffusers==0.32.0 transformers==4.46.3 onnx==1.17.0 protobuf==5.29.2 py3nvml + install_onnxruntime + install_optimum +} + +# Function to run benchmarks +run_benchmark() { + local model=$1 + local dir=$2 + local version=$3 + local steps=$4 + local batch=$5 + + log "Running benchmark for model: $model" + mkdir -p "$dir" + [ ! -d "$dir/fp32" ] && optimum-cli export onnx --model "$model" "$dir/fp32" --opset 15 --task text-to-image + [ ! -d "$dir/fp16_fp32" ] && python optimize_pipeline.py -i "$dir/fp32" -o "$dir/fp16_fp32" --float16 + [ ! -d "$dir/fp16_bf16" ] && python optimize_pipeline.py -i "$dir/fp32" -o "$dir/fp16_bf16" --float16 --bfloat16 + python benchmark.py -e optimum --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" -p "$dir/fp16_fp32" + python benchmark.py -e optimum --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" -p "$dir/fp16_bf16" + python benchmark.py -e torch --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" + python benchmark.py -e torch --height 1024 --width 1024 --steps "$steps" -b "$batch" -v "$version" --enable_torch_compile +} + +# Main script execution +install_gpu + +log "Creating ONNX model directory: $onnx_dir" +mkdir -p "$onnx_dir" + +run_benchmark black-forest-labs/FLUX.1-schnell "$onnx_dir/flux1_schnell" Flux.1S 4 1 > "$onnx_dir/flux1_schnell_s4_b1.log" +run_benchmark black-forest-labs/FLUX.1-dev "$onnx_dir/flux1_dev" Flux.1D 50 1 > "$onnx_dir/flux1_dev_s50_b1.log" +run_benchmark stabilityai/stable-diffusion-3.5-large "$onnx_dir/sd3.5_large" 3.5L 50 1 > "$onnx_dir/sd3.5_large_s50_b1.log" +run_benchmark stabilityai/stable-diffusion-3.5-medium "$onnx_dir/sd3.5_medium" 3.5M 50 1 > "$onnx_dir/sd3.5_medium_s50_b1.log" + +log "Benchmark completed." diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index ffcfd6d9fd7e0..52d332848357f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -28,6 +28,8 @@ import onnx from fusion_options import FusionOptions from onnx_model_clip import ClipOnnxModel +from onnx_model_mmdit import MmditOnnxModel +from onnx_model_t5 import T5OnnxModel from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel from optimizer import optimize_by_onnxruntime, optimize_model @@ -46,11 +48,63 @@ def has_external_data(onnx_model_path): return False +def is_sd_3(source_dir: Path): + return (source_dir / "text_encoder_3").exists() + + +def is_sdxl(source_dir: Path): + return ( + (source_dir / "text_encoder_2").exists() + and not (source_dir / "text_encoder_3").exists() + and not (source_dir / "transformer").exists() + ) + + +def is_flux(source_dir: Path): + return ( + (source_dir / "text_encoder_2").exists() + and not (source_dir / "text_encoder_3").exists() + and (source_dir / "transformer").exists() + ) + + +def _classify_pipeline_type(source_dir: Path): + # May also check _class_name in model_index.json like `StableDiffusion3Pipeline` or `FluxPipeline` etc to classify. + if is_sd_3(source_dir): + return "sd3" + + if is_flux(source_dir): + return "flux" + + if is_sdxl(source_dir): + return "sdxl" + + # sd 1.x and 2.x + return "sd" + + +def _get_model_list(pipeline_type: str): + if pipeline_type == "sd3": + return ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"] + + if pipeline_type == "flux": + return ["text_encoder", "text_encoder_2", "transformer", "vae_encoder", "vae_decoder"] + + if pipeline_type == "sdxl": + return ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"] + + assert pipeline_type == "sd" + return ["text_encoder", "unet", "vae_encoder", "vae_decoder"] + + def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, + pipeline_type: str, + model_list: List[str], use_external_data_format: Optional[bool], float16: bool, + bfloat16: bool, force_fp32_ops: List[str], enable_runtime_optimization: bool, args, @@ -60,8 +114,10 @@ def _optimize_sd_pipeline( Args: source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. + model_list (List[str]): list of directory names with onnx model. use_external_data_format (Optional[bool]): use external data format. float16 (bool): use half precision + bfloat16 (bool): use bfloat16 as fallback if float16 is also provided. force_fp32_ops(List[str]): operators that are forced to run in float32. enable_runtime_optimization(bool): run graph optimization using Onnx Runtime. @@ -69,12 +125,15 @@ def _optimize_sd_pipeline( RuntimeError: input onnx model does not exist RuntimeError: output onnx model path existed """ + is_flux_pipeline = pipeline_type == "flux" model_type_mapping = { + "transformer": "mmdit", "unet": "unet", "vae_encoder": "vae", "vae_decoder": "vae", "text_encoder": "clip", - "text_encoder_2": "clip", + "text_encoder_2": "t5" if is_flux_pipeline else "clip", + "text_encoder_3": "t5", # t5-v1_1-xxl is used in SD 3.x text_encoder_3 and Flux text_encoder_2. "safety_checker": "unet", } @@ -82,6 +141,8 @@ def _optimize_sd_pipeline( "unet": UnetOnnxModel, "vae": VaeOnnxModel, "clip": ClipOnnxModel, + "t5": T5OnnxModel, + "mmdit": MmditOnnxModel, } force_fp32_operators = { @@ -91,9 +152,140 @@ def _optimize_sd_pipeline( "text_encoder": [], "text_encoder_2": [], "safety_checker": [], + "text_encoder_3": [], + "transformer": [], + } + + # The node block list is generated by running the fp32 model and get statistics of node inputs and outputs. + # Nodes with any input or output of float or double data type, but value ouf of range of float16 are candidates. + # python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp32_opt + # export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1 + # export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1 + # export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1 + # python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >stdout.txt 2>stderr.txt + # Warning: The node name might change in different export settings. See benchmark_flux.sh for the settings. + flux_node_block_list = { + "text_encoder_2": [ + "/encoder/block.10/layer.1/DenseReluDense/wo/MatMul", + "SkipLayerNorm_20", + "SkipLayerNorm_21", + "SkipLayerNorm_22", + "SkipLayerNorm_23", + "SkipLayerNorm_24", + "SkipLayerNorm_25", + "SkipLayerNorm_26", + "SkipLayerNorm_27", + "SkipLayerNorm_28", + "SkipLayerNorm_29", + "SkipLayerNorm_30", + "SkipLayerNorm_31", + "SkipLayerNorm_32", + "SkipLayerNorm_33", + "SkipLayerNorm_34", + "SkipLayerNorm_35", + "SkipLayerNorm_36", + "SkipLayerNorm_37", + "SkipLayerNorm_38", + "SkipLayerNorm_39", + "SkipLayerNorm_40", + "SkipLayerNorm_41", + "SkipLayerNorm_42", + "SkipLayerNorm_43", + "SkipLayerNorm_44", + "SkipLayerNorm_45", + "/encoder/block.23/layer.1/DenseReluDense/wo/MatMul", + "SkipLayerNorm_46", + ], + "vae_decoder": [ + "/decoder/mid_block/attentions.0/MatMul", + "/decoder/mid_block/attentions.0/Softmax", + ], + "transformer": [ + "/transformer_blocks.18/Mul_5", + "/transformer_blocks.18/Add_7", + "/Concat_1", + "LayerNorm_76", + "/single_transformer_blocks.0/Add", + "LayerNorm_77", + "/single_transformer_blocks.1/Add", + "LayerNorm_78", + "/single_transformer_blocks.2/Add", + "LayerNorm_79", + "/single_transformer_blocks.3/Add", + "LayerNorm_80", + "/single_transformer_blocks.4/Add", + "LayerNorm_81", + "/single_transformer_blocks.5/Add", + "LayerNorm_82", + "/single_transformer_blocks.6/Add", + "LayerNorm_83", + "/single_transformer_blocks.7/Add", + "LayerNorm_84", + "/single_transformer_blocks.8/Add", + "LayerNorm_85", + "/single_transformer_blocks.9/Add", + "LayerNorm_86", + "/single_transformer_blocks.10/Add", + "LayerNorm_87", + "/single_transformer_blocks.11/Add", + "LayerNorm_88", + "/single_transformer_blocks.12/Add", + "LayerNorm_89", + "/single_transformer_blocks.13/Add", + "LayerNorm_90", + "/single_transformer_blocks.14/Add", + "LayerNorm_91", + "/single_transformer_blocks.15/Add", + "LayerNorm_92", + "/single_transformer_blocks.16/Add", + "LayerNorm_93", + "/single_transformer_blocks.17/Add", + "LayerNorm_94", + "/single_transformer_blocks.18/Add", + "LayerNorm_95", + "/single_transformer_blocks.19/Add", + "LayerNorm_96", + "/single_transformer_blocks.20/Add", + "LayerNorm_97", + "/single_transformer_blocks.21/Add", + "LayerNorm_98", + "/single_transformer_blocks.22/Add", + "LayerNorm_99", + "/single_transformer_blocks.23/Add", + "LayerNorm_100", + "/single_transformer_blocks.24/Add", + "LayerNorm_101", + "/single_transformer_blocks.25/Add", + "LayerNorm_102", + "/single_transformer_blocks.26/Add", + "LayerNorm_103", + "/single_transformer_blocks.27/Add", + "LayerNorm_104", + "/single_transformer_blocks.28/Add", + "LayerNorm_105", + "/single_transformer_blocks.29/Add", + "LayerNorm_106", + "/single_transformer_blocks.30/Add", + "LayerNorm_107", + "/single_transformer_blocks.31/Add", + "LayerNorm_108", + "/single_transformer_blocks.32/Add", + "LayerNorm_109", + "/single_transformer_blocks.33/Add", + "LayerNorm_110", + "/single_transformer_blocks.34/Add", + "LayerNorm_111", + "/single_transformer_blocks.35/Add", + "LayerNorm_112", + "/single_transformer_blocks.36/Add", + "LayerNorm_113", + "/single_transformer_blocks.37/Add", + "/Shape", + "/Slice", + ], } - is_xl = (source_dir / "text_encoder_2").exists() + sd3_node_block_list = {"text_encoder_3": flux_node_block_list["text_encoder_2"]} if force_fp32_ops: for fp32_operator in force_fp32_ops: @@ -105,16 +297,21 @@ def _optimize_sd_pipeline( f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}" ) + op_counters = {} for name, model_type in model_type_mapping.items(): onnx_model_path = source_dir / name / "model.onnx" if not os.path.exists(onnx_model_path): - if name != "safety_checker": - logger.info("input onnx model does not exist: %s", onnx_model_path) + if name != "safety_checker" and name in model_list: + logger.warning("input onnx model does not exist: %s", onnx_model_path) # some model are optional so we do not raise error here. continue # Prepare output directory optimized_model_path = target_dir / name / "model.onnx" + if os.path.exists(optimized_model_path): + if not args.overwrite: + logger.warning("Skipped optimization since the target file existed: %s", optimized_model_path) + continue output_dir = optimized_model_path.parent output_dir.mkdir(parents=True, exist_ok=True) @@ -122,7 +319,7 @@ def _optimize_sd_pipeline( use_external_data_format = has_external_data(onnx_model_path) # Graph fusion before fp16 conversion, otherwise they cannot be fused later. - logger.info(f"Optimize {onnx_model_path}...") + logger.info("Optimize %s ...", onnx_model_path) args.model_type = model_type fusion_options = FusionOptions.parse(args) @@ -146,8 +343,28 @@ def _optimize_sd_pipeline( ) if float16: + model_node_block_list = ( + flux_node_block_list if is_flux_pipeline else sd3_node_block_list if pipeline_type == "sd3" else {} + ) + if name in model_node_block_list: + # Opset 12 does not support bfloat16. + # By default, optimum exports T5 model with opset 12. So we need to check the opset version. + use_bfloat16 = bfloat16 + if use_bfloat16: + for opset in m.model.opset_import: + if opset.domain in ["", "ai.onnx"] and opset.version < 13: + logger.warning( + "onnx model requires opset 13 or higher to use bfloat16. Fall back to float32." + ) + use_bfloat16 = False + + m.convert_float_to_float16( + keep_io_types=False, + node_block_list=model_node_block_list[name], + use_bfloat16_as_blocked_nodes_dtype=use_bfloat16, + ) # For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32. - if is_xl and name == "vae_decoder": + elif pipeline_type in ["sdxl"] and name in ["vae_decoder"]: logger.info("Skip converting %s to float16 to avoid NaN", name) else: logger.info("Convert %s to float16 ...", name) @@ -175,23 +392,26 @@ def _optimize_sd_pipeline( m = model_type_class_mapping[model_type](model) m.get_operator_statistics() - m.get_fused_operator_statistics() + op_counters[name] = m.get_fused_operator_statistics() m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format) logger.info("%s is optimized", name) logger.info("*" * 20) + return op_counters + -def _copy_extra_directory(source_dir: Path, target_dir: Path): +def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: List[str]): """Copy extra directory that does not have onnx model Args: source_dir (Path): source directory target_dir (Path): target directory + model_list (List[str]): list of directory names with onnx model. Raises: RuntimeError: source path does not exist """ - extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"] + extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"] for name in extra_dirs: source_path = source_dir / name @@ -199,6 +419,8 @@ def _copy_extra_directory(source_dir: Path, target_dir: Path): continue target_path = target_dir / name + if target_path.exists(): + shutil.rmtree(target_path) shutil.copytree(source_path, target_path) logger.info("%s => %s", source_path, target_path) @@ -213,8 +435,7 @@ def _copy_extra_directory(source_dir: Path, target_dir: Path): logger.info("%s => %s", source_path, target_path) # Some directory are optional - onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"] - for onnx_model_dir in onnx_model_dirs: + for onnx_model_dir in model_list: source_path = source_dir / onnx_model_dir / "config.json" target_path = target_dir / onnx_model_dir / "config.json" if source_path.exists(): @@ -235,20 +456,24 @@ def optimize_stable_diffusion_pipeline( if os.path.exists(output_dir): if overwrite: shutil.rmtree(output_dir, ignore_errors=True) - else: - raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.") source_dir = Path(input_dir) target_dir = Path(output_dir) target_dir.mkdir(parents=True, exist_ok=True) - _copy_extra_directory(source_dir, target_dir) + pipeline_type = _classify_pipeline_type(source_dir) + model_list = _get_model_list(pipeline_type) - _optimize_sd_pipeline( + _copy_extra_directory(source_dir, target_dir, model_list) + + return _optimize_sd_pipeline( source_dir, target_dir, + pipeline_type, + model_list, use_external_data_format, float16, + args.bfloat16, args.force_fp32_ops, enable_runtime_optimization, args, @@ -283,10 +508,18 @@ def parse_arguments(argv: Optional[List[str]] = None): "--float16", required=False, action="store_true", - help="Output models of half or mixed precision.", + help="Output models of float16, except some nodes falls back to float32 or bfloat16 to avoid overflow.", ) parser.set_defaults(float16=False) + parser.add_argument( + "--bfloat16", + required=False, + action="store_true", + help="Allow bfloat16 as fallback if --float16 is also provided.", + ) + parser.set_defaults(bfloat16=False) + parser.add_argument( "--force_fp32_ops", required=False, @@ -339,8 +572,11 @@ def parse_arguments(argv: Optional[List[str]] = None): def main(argv: Optional[List[str]] = None): args = parse_arguments(argv) + logger.info("Arguments: %s", str(args)) - optimize_stable_diffusion_pipeline( + + # Return op counters for testing purpose. + return optimize_stable_diffusion_pipeline( args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args ) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index fe80a08829263..2a6f9c3d758db 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -233,15 +233,21 @@ def get_nodes_by_op_type(self, op_type): nodes.append(node) return nodes - def get_children(self, node, input_name_to_nodes=None): + def get_children(self, node, input_name_to_nodes=None, output_index=None): if input_name_to_nodes is None: input_name_to_nodes = self.input_name_to_nodes() children = [] - for output in node.output: - if output in input_name_to_nodes: - for node in input_name_to_nodes[output]: - children.append(node) # noqa: PERF402 + if output_index is not None: + if output_index < len(node.output): + output = node.output[output_index] + if output in input_name_to_nodes: + children = list(input_name_to_nodes[output]) + else: + for output in node.output: + if output in input_name_to_nodes: + children.extend(input_name_to_nodes[output]) + return children def get_parents(self, node, output_name_to_node=None): @@ -436,48 +442,63 @@ def match_child_path( self, node, child_op_types, - child_output_index=None, - return_indice=None, + edges: Optional[List[Tuple[int, int]]] = None, + input_name_to_nodes=None, exclude=[], # noqa: B006 ): """ Find a sequence of input edges based on constraints on parent op_type and index. - When input_index is None, we will find the first parent node based on constraints, - and return_indice will be appended the corresponding input index. + Note that we use greedy approach and only consider the first matched child, so it has chance to miss matching. Args: node (str): current node name. child_op_types (str): constraint of child node op_type of each input edge. - child_output_index (list): constraint of input index of each input edge. None means no constraint. - return_indice (list): a list to append the input index - When there is no constraint on input index of an edge. + edges (list): each edge is represented by two integers: output index of parent node, input index of child node. + None means no constraint. + exclude(list): list of nodes that are excluded (not allowed to match as child). Returns: children: a list of matched children node. """ - if child_output_index is not None: - assert len(child_output_index) == len(child_op_types) + if edges is not None: + assert len(edges) == len(child_op_types) + for edge in edges: + assert ( + isinstance(edge, tuple) and len(edge) == 2 and isinstance(edge[0], int) and isinstance(edge[1], int) + ) + + if input_name_to_nodes is None: + input_name_to_nodes = self.input_name_to_nodes() current_node = node matched_children = [] for i, op_type in enumerate(child_op_types): matched_child = None - node_children = self.get_children(current_node) - for child_i, child in enumerate(node_children): + + if edges is None: + children_nodes = self.get_children(current_node, input_name_to_nodes=input_name_to_nodes) + else: + children_nodes = self.get_children( + current_node, input_name_to_nodes=input_name_to_nodes, output_index=edges[i][0] + ) + + for child in children_nodes: if child.op_type == op_type and child not in exclude: - if child_output_index is not None and child_output_index[i] != child_i: - logger.debug( - f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}", - stack_info=True, - ) - return None + if edges is not None and child.input[edges[i][1]] != current_node.output[edges[i][0]]: + continue + + # Here we use greedy approach and only consider the first matched child. + # TODO: match recursively if we encounter cases that the correct child is not the first matched. matched_child = child + break + if matched_child is None: - logger.debug(f"Failed to match child op_type={op_type}", stack_info=True) + logger.debug(f"Failed to match child {i} op_type={op_type}", stack_info=True) return None matched_children.append(matched_child) current_node = matched_child + return matched_children def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True): diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py index 388d058c7856c..725be3c762e5a 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_clip.py +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -27,6 +27,7 @@ def get_fused_operator_statistics(self): "Gelu", "LayerNormalization", "QuickGelu", + "BiasGelu", "SkipLayerNormalization", ] for op in ops: diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py new file mode 100644 index 0000000000000..4c9b19c0c97ca --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -0,0 +1,113 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import logging +from typing import Optional + +from fusion_layernorm import FusionLayerNormalization +from fusion_mha_mmdit import FusionMultiHeadAttentionMMDit +from fusion_options import FusionOptions +from import_utils import is_installed +from onnx import ModelProto +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class MmditOnnxModel(BertOnnxModel): + def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): + """Initialize Multimodal Diffusion Transformer (MMDiT) ONNX Model. + + Args: + model (ModelProto): the ONNX model + num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically). + hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically). + """ + assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) + super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + + def postprocess(self): + self.prune_graph() + self.remove_unused_constant() + + def fuse_layer_norm(self): + layernorm_support_broadcast = True + logger.warning( + "The optimized model requires LayerNormalization with broadcast support. " + "Please use onnxruntime-gpu>=1.21 for inference." + ) + fusion = FusionLayerNormalization( + self, check_constant_and_dimension=not layernorm_support_broadcast, force=True + ) + fusion.apply() + + def fuse_multi_head_attention(self): + fusion = FusionMultiHeadAttentionMMDit(self) + fusion.apply() + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + assert not add_dynamic_axes + + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 5 + progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): + if (options is not None) and not options.enable_shape_inference: + self.disable_shape_inference() + + # Remove cast nodes that having same data type of input and output based on symbolic shape inference. + self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_layer_norm: + self.fuse_layer_norm() + self.fuse_simplified_layer_norm() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_gelu: + self.fuse_gelu() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_attention: + self.fuse_multi_head_attention() + if progress_bar: + progress_bar.update(1) + + self.postprocess() + if progress_bar: + progress_bar.update(1) + + logger.info(f"opset version: {self.get_opset_version()}") + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "FastGelu", + "MultiHeadAttention", + "LayerNormalization", + "SimplifiedLayerNormalization", + ] + + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators:{op_count}") + return op_count diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 9cc4878e8022d..70742bb5f52e3 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -75,9 +75,10 @@ def create_attention_node( k_weight = self.model.get_initializer(k_matmul.input[1]) v_weight = self.model.get_initializer(v_matmul.input[1]) - if q_weight is None: + if q_weight is None or k_weight is None or v_weight is None: + matmul = q_matmul if q_weight is None else k_matmul if k_weight is None else v_matmul print( - f"{q_matmul.input[1]} is not an initializer. " + f"{matmul.input[1]} is not an initializer. " "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion" ) return None @@ -222,9 +223,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no return qkv_nodes = self.model.match_parent_path( - normalize_node, - ["MatMul", "Reshape", "Transpose", "MatMul"], - [1, 0, 0, 0], + normalize_node, ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], output_name_to_node ) if qkv_nodes is None: return @@ -235,6 +234,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no reshape_qkv, ["Concat", "Unsqueeze", "Gather", "Shape"], [1, 0, 0, 0], + output_name_to_node, ) if qkv_shape_nodes is None: return @@ -244,6 +244,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0], + output_name_to_node, ) if v_nodes is None: return @@ -254,28 +255,64 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0], + output_name_to_node, ) if qk_nodes is None: return _, add_qk, matmul_qk = qk_nodes - mask_index = None mask_nodes = self.model.match_parent_path( add_qk, ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [1, 1, 0, 1, 0, 0], + output_name_to_node, ) + + is_pattern_for_one_graph_input = mask_nodes is None if mask_nodes is None: - return - mul_node = mask_nodes[1] - if mask_nodes[1].op_type != "Mul": - return + # Pattern for SD3 and Flux. + mask_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Slice", "Mul", "Sub", "Unsqueeze", "Unsqueeze"], + [1, 1, 0, 0, 1, 0], + output_name_to_node, + ) + if mask_nodes is None: + return + mul_node = mask_nodes[2] + else: + mul_node = mask_nodes[1] _, mul_val = self.model.get_constant_input(mul_node) - if mul_val != -10000: - self.mask_filter_value = mul_val + if mul_val is None: + return - mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) + if mul_val != -10000: + self.mask_filter_value = float(mul_val) + + # If the mask is derived from shape of input_ids, it means there is no padding mask. + mask_nodes_2 = self.model.match_parent_path( + mask_nodes[-1], + ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"], + [0, 0, 0, 0, 0], + output_name_to_node, + ) + mask_nodes_3 = self.model.match_parent_path( + mask_nodes[-1], + ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"], + [0, 0, 1, 0, 0], + output_name_to_node, + ) + if ( + mask_nodes_2 is not None + and any(input.name == mask_nodes_2[-1].input[0] for input in self.model.graph().input) + and mask_nodes_3 is not None + and mask_nodes_2[-1].input[0] == mask_nodes_3[-1].input[0] + and len(mask_nodes_2[1].input) == 2 + ): + mask_index = "" + else: + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) res_pos_bias = None rpb_nodes = self.model.match_parent_path( @@ -283,10 +320,17 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no ["Add", "RelativePositionBias"], [1, 0], ) + if rpb_nodes is None and is_pattern_for_one_graph_input: + # Pattern for SD3 and Flux. + rpb_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Slice", "RelativePositionBias"], + [1, 0, 0], + ) if rpb_nodes is None: return - rpb_add_node = rpb_nodes[0] - res_pos_bias = rpb_add_node.input[0] + + res_pos_bias = rpb_nodes[-1].output[0] k_nodes = self.model.match_parent_path( matmul_qk, @@ -332,13 +376,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name - self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(qk_nodes) - self.nodes_to_remove.extend(k_nodes[:-1]) - if v_nodes is not None: - self.nodes_to_remove.extend(v_nodes[:-1]) - self.nodes_to_remove.extend(q_nodes[:-1]) - + self.nodes_to_remove.append(reshape_qkv) self.prune_graph = True def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_node): @@ -591,12 +629,7 @@ def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_no self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name - self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(qk_nodes) - self.nodes_to_remove.extend(k_nodes[:-1]) - if v_nodes is not None: - self.nodes_to_remove.extend(v_nodes[:-1]) - self.nodes_to_remove.extend(q_nodes[:-1]) + self.nodes_to_remove.append(reshape_qkv) self.prune_graph = True @@ -605,7 +638,6 @@ class FusionRelativePositionBiasBlock(Fusion): def __init__(self, model: OnnxModel, max_distance: int): super().__init__(model, "RelativePositionBias", ["Add", "Slice"]) self.max_distance = max_distance - # bidirectional=(not self.is_decoder) self.is_bidirectional = False def fuse(self, node, input_name_to_nodes, output_name_to_node): @@ -615,11 +647,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return compute_bias_nodes = self.model.match_parent_path( - node, ["Unsqueeze", "Transpose", "Gather", "Where"], [0, 0, 0, 1] + node, ["Unsqueeze", "Transpose", "Gather", "Where"], [0, 0, 0, 1], output_name_to_node ) if compute_bias_nodes is None: compute_bias_nodes = self.model.match_parent_path( - node, ["Unsqueeze", "Transpose", "Gather", "Add", "Where"], [0, 0, 0, 1, 1] + node, ["Unsqueeze", "Transpose", "Gather", "Add", "Where"], [0, 0, 0, 1, 1], output_name_to_node ) if compute_bias_nodes is None: return @@ -632,20 +664,29 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): where, ["Min", "ConstantOfShape", "Shape", "Add", "Cast", "Mul", "Div", "Log", "Div"], [2, 1, 0, 0, 0, 0, 0, 0, 0], + output_name_to_node, ) if compute_buckets_nodes is None: return + # It is possible to deduce max_distance from a Div node: + # The value of self.model.get_constant_value(compute_buckets_nodes[-3].input[1]) is close to + # math.log(max_distance / (relative_attention_num_buckets // (4 if is_bidirectional else 2))) + # See https://github.com/huggingface/transformers/blob/608e163b527eaee41e650ffb9eb4c422d2679902/src/transformers/models/t5/modeling_t5.py#L397. + # Most t5 models use max_distance=128, so we hardcode it unitl we see a model with different value. + # TODO: maybe add a sanity check here. + div = compute_buckets_nodes[-1] range_nodes = self.model.match_parent_path( div, ["Cast", "Neg", "Min", "ConstantOfShape", "Shape", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 1, 0, 0, 0, 0], + output_name_to_node, ) if range_nodes is None: range_nodes = self.model.match_parent_path( - div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0] + div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0], output_name_to_node ) self.is_bidirectional = True if range_nodes is None: @@ -653,17 +694,20 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): range_node = range_nodes[-1] - self.nodes_to_remove.extend(compute_bias_nodes) - self.nodes_to_remove.extend(compute_buckets_nodes) - self.nodes_to_remove.extend(range_nodes) + self.nodes_to_remove.append(unsqueeze) + self.prune_graph = True - node_name_prefix = "encoder" if self.is_bidirectional else "decoder" + node_name = self.model.create_node_name( + "RelativePositionBias", name_prefix="RelPosBias_" + ("encoder" if self.is_bidirectional else "decoder") + ) table_weight_i = self.model.get_initializer(gather.input[0]) + if table_weight_i is None: + return table_weight = NumpyHelper.to_array(table_weight_i) table_weight_t = np.transpose(table_weight) bias_table = helper.make_tensor( - name=self.model.create_node_name("bias_table_weight", name_prefix=node_name_prefix), + name=node_name + "_bias_table_weight", data_type=TensorProto.FLOAT, dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]], vals=table_weight_t.tobytes(), @@ -677,7 +721,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): "RelativePositionBias", inputs=inputs, outputs=outputs, - name=self.model.create_node_name("RelativePositionBias", name_prefix=node_name_prefix), + name=node_name, ) rpb_node.domain = "com.microsoft" rpb_node.attribute.extend([helper.make_attribute("max_distance", self.max_distance)]) @@ -688,14 +732,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): class T5OnnxModel(BertOnnxModel): - def __init__(self, model, num_heads, hidden_size): + def __init__(self, model, num_heads: int = 0, hidden_size: int = 0): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) + + # When the model has only one input (input_ids), there is no padding mask. + if len(self.model.graph.input) == 1: + from fusion_options import AttentionMaskFormat + + self.attention_mask.mask_format = AttentionMaskFormat.NoMask + self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask) self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self) self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self) - # TODO: consider retrieve max_distance from model. - # math.log(max_distance / (num_buckets // 2)) self.rpb_fusion = FusionRelativePositionBiasBlock(self, 128) def fuse_attention(self): @@ -704,9 +753,65 @@ def fuse_attention(self): def fuse_layer_norm(self): self.layer_norm_fusion.apply() - def fuse_skip_layer_norm(self): + def fuse_skip_layer_norm(self, shape_infer=True): self.skip_layer_norm_fusion.apply() + def adjust_rel_pos_bis_length_input(self): + # For T5 encoder, it uses complex logic to compute the query and key length when there is only one graph input (input_ids) + # We can directly get the length from shape (the 2nd dimension) of input_ids. + for node in self.nodes(): + if node.op_type == "RelativePositionBias": + nodes = self.match_parent_path( + node, + [ + "Gather", + "Shape", + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "SimplifiedLayerNormalization", + "Gather", + ], + [1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + ) + # TODO: more validation on node attributes + if nodes is not None: + graph_input_names = [input.name for input in self.model.graph.input] + if nodes[-1].input[1] in graph_input_names: + node_name = self.create_node_name("Shape", name_prefix="Added_Shape_") + shape_node = helper.make_node( + "Shape", + inputs=[nodes[-1].input[1]], + outputs=[node_name + "_Output"], + name=node_name, + ) + + indices_1 = helper.make_tensor( + name="Constant_Index_1", + data_type=TensorProto.INT64, + dims=[1], # Shape of the tensor + vals=[1], # Tensor values + ) + self.add_initializer(indices_1) + + gather = helper.make_node( + "Gather", + inputs=[node_name + "_Output", "Constant_Index_1"], + outputs=[node_name + "_Output_Gather_1"], + name=self.create_node_name("Gather", name_prefix="Added_Gather_"), + axis=0, + ) + + self.add_node(shape_node) + self.add_node(gather) + node.input[1] = node_name + "_Output_Gather_1" + node.input[2] = node_name + "_Output_Gather_1" + + break + # Remove get_extended_attention_mask() since it generates all zeros. def remove_extended_mask_decoder_init(self): nodes_to_remove = [] @@ -787,5 +892,6 @@ def postprocess(self): # remove get_extended_attention_mask() since it generates all zeros. self.remove_extended_mask_decoder_init() self.remove_extended_mask_decoder() + self.adjust_rel_pos_bis_length_input() self.prune_graph() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 933bd785dc00d..a83c54e345d7d 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -35,6 +35,7 @@ from onnx_model_clip import ClipOnnxModel from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel +from onnx_model_mmdit import MmditOnnxModel from onnx_model_phi import PhiOnnxModel from onnx_model_sam2 import Sam2OnnxModel from onnx_model_t5 import T5OnnxModel @@ -66,6 +67,7 @@ "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), + "mmdit": (MmditOnnxModel, "pytorch", 1), } @@ -237,7 +239,9 @@ def optimize_by_fusion( Returns: object of an optimizer class. """ - if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2"] and (num_heads == 0 or hidden_size == 0): + if model_type not in ["bert", "t5", "swin", "unet", "vae", "clip", "sam2", "mmdit"] and ( + num_heads == 0 or hidden_size == 0 + ): logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") if model_type not in MODEL_TYPES: diff --git a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py index dca250f39fae2..692382a12da9f 100644 --- a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py +++ b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py @@ -29,6 +29,8 @@ TINY_MODELS = { "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", + "stable-diffusion-3": "optimum-internal-testing/tiny-random-stable-diffusion-3", + "flux": "tlwu/tiny-random-flux", } @@ -114,157 +116,287 @@ def test_clip_sd(self): float16=True, ) - @pytest.mark.slow - def test_clip_sdxl(self): - save_directory = "tiny-random-stable-diffusion-xl" - if os.path.exists(save_directory): - shutil.rmtree(save_directory, ignore_errors=True) - - model_type = "stable-diffusion-xl" - model_name = TINY_MODELS[model_type] - - from optimum.onnxruntime import ORTStableDiffusionXLPipeline - - base = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) - base.save_pretrained(save_directory) - - clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx") - optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx") - self.verify_clip_optimizer( - clip_onnx_path, - optimized_clip_onnx_path, - expected_counters={ - "EmbedLayerNormalization": 0, - "Attention": 5, - "SkipLayerNormalization": 10, - "LayerNormalization": 1, - "Gelu": 0, - "BiasGelu": 5, - }, - ) - - clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "model.onnx") - optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "opt.onnx") - self.verify_clip_optimizer( - clip_onnx_path, - optimized_clip_onnx_path, - expected_counters={ - "EmbedLayerNormalization": 0, - "Attention": 5, - "SkipLayerNormalization": 10, - "LayerNormalization": 1, - "Gelu": 0, - "BiasGelu": 5, - }, - ) - - @pytest.mark.slow - def test_optimize_sdxl_fp32(self): - save_directory = "tiny-random-stable-diffusion-xl" - if os.path.exists(save_directory): - shutil.rmtree(save_directory, ignore_errors=True) - model_type = "stable-diffusion-xl" - model_name = TINY_MODELS[model_type] - - from optimum.onnxruntime import ORTStableDiffusionXLPipeline +class TestStableDiffusionOrFluxPipelineOptimization(unittest.TestCase): + def verify_pipeline_optimization( + self, + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16, + atol, + disable_group_norm=False, + ): + from optimum.onnxruntime import ORTPipelineForText2Image - baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) - if not os.path.exists(save_directory): - baseline.save_pretrained(save_directory) + if os.path.exists(export_onnx_dir): + shutil.rmtree(export_onnx_dir, ignore_errors=True) - batch_size, num_images_per_prompt, height, width = 2, 2, 64, 64 - latents = baseline.prepare_latents( - batch_size * num_images_per_prompt, - baseline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) + baseline = ORTPipelineForText2Image.from_pretrained(model_name, export=True, provider="CUDAExecutionProvider") + if not os.path.exists(export_onnx_dir): + baseline.save_pretrained(export_onnx_dir) - optimized_directory = "tiny-random-stable-diffusion-xl-optimized" argv = [ "--input", - save_directory, + export_onnx_dir, "--output", - optimized_directory, - "--disable_group_norm", - "--disable_bias_splitgelu", + optimized_onnx_dir, "--overwrite", + "--disable_bias_splitgelu", ] - optimize_stable_diffusion(argv) - treatment = ORTStableDiffusionXLPipeline.from_pretrained(optimized_directory, provider="CUDAExecutionProvider") + if disable_group_norm: + argv.append("--disable_group_norm") + + if is_float16: + argv.append("--float16") + + op_counters = optimize_stable_diffusion(argv) + print(op_counters) + + for name in expected_op_counters: + self.assertIn(name, op_counters) + for op, count in expected_op_counters[name].items(): + self.assertIn(op, op_counters[name]) + self.assertEqual(op_counters[name][op], count, f"Expected {count} {op} in {name}") + + treatment = ORTPipelineForText2Image.from_pretrained(optimized_onnx_dir, provider="CUDAExecutionProvider") + batch_size, num_images_per_prompt, height, width = 1, 1, 64, 64 inputs = { "prompt": ["starry night by van gogh"] * batch_size, - "num_inference_steps": 3, + "num_inference_steps": 20, "num_images_per_prompt": num_images_per_prompt, "height": height, "width": width, - "guidance_rescale": 0.1, "output_type": "np", } - ort_outputs_1 = baseline(latents=latents, **inputs) - ort_outputs_2 = treatment(latents=latents, **inputs) - self.assertTrue(np.allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-3)) + seed = 123 + np.random.seed(seed) + import torch + + baseline_outputs = baseline(**inputs, generator=torch.Generator(device="cuda").manual_seed(seed)) + + np.random.seed(seed) + treatment_outputs = treatment(**inputs, generator=torch.Generator(device="cuda").manual_seed(seed)) + + self.assertTrue(np.allclose(baseline_outputs.images[0], treatment_outputs.images[0], atol=atol)) @pytest.mark.slow - def test_optimize_sdxl_fp16(self): - """This tests optimized fp16 pipeline, and result is deterministic for a given seed""" - save_directory = "tiny-random-stable-diffusion-xl" - if os.path.exists(save_directory): - shutil.rmtree(save_directory, ignore_errors=True) + def test_sd(self): + """This tests optimization of stable diffusion 1.x pipeline""" + model_name = TINY_MODELS["stable-diffusion"] + + expected_op_counters = { + "unet": { + "Attention": 6, + "MultiHeadAttention": 6, + "LayerNormalization": 6, + "SkipLayerNormalization": 12, + "BiasSplitGelu": 0, + "GroupNorm": 0, + "SkipGroupNorm": 0, + "NhwcConv": 47, + "BiasAdd": 0, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 13}, + "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 17}, + "text_encoder": { + "Attention": 5, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 5, + "BiasGelu": 0, + "SkipLayerNormalization": 10, + }, + } - model_type = "stable-diffusion-xl" - model_name = TINY_MODELS[model_type] + export_onnx_dir = "tiny-random-sd" + optimized_onnx_dir = "tiny-random-sd-optimized-fp32" + # Disable GroupNorm due to limitation of current cuda kernel implementation. + self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=False, + atol=5e-3, + disable_group_norm=True, + ) - from optimum.onnxruntime import ORTStableDiffusionXLPipeline + expected_op_counters["unet"].update({"Attention": 0, "MultiHeadAttention": 12}) + optimized_onnx_dir = "tiny-random-sd-optimized-fp16" + self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=True, + atol=5e-2, + disable_group_norm=True, + ) - baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) - if not os.path.exists(save_directory): - baseline.save_pretrained(save_directory) + @pytest.mark.slow + def test_sdxl(self): + """This tests optimization of SDXL pipeline""" + model_name = TINY_MODELS["stable-diffusion-xl"] + + expected_op_counters = { + "unet": { + "Attention": 12, + "MultiHeadAttention": 12, + "LayerNormalization": 6, + "SkipLayerNormalization": 30, + "BiasSplitGelu": 0, + "GroupNorm": 0, + "SkipGroupNorm": 0, + "NhwcConv": 35, + "BiasAdd": 0, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 13}, + "vae_decoder": {"Attention": 0, "GroupNorm": 0, "SkipGroupNorm": 0, "NhwcConv": 17}, + "text_encoder": { + "Attention": 5, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 0, + "BiasGelu": 5, + "SkipLayerNormalization": 10, + }, + "text_encoder_2": { + "Attention": 5, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 0, + "BiasGelu": 5, + "SkipLayerNormalization": 10, + }, + } - optimized_directory = "tiny-random-stable-diffusion-xl-optimized-fp16" - argv = [ - "--input", - save_directory, - "--output", - optimized_directory, - "--disable_group_norm", - "--disable_bias_splitgelu", - "--float16", - "--overwrite", - ] - optimize_stable_diffusion(argv) + export_onnx_dir = "tiny-random-sdxl" + optimized_onnx_dir = "tiny-random-sdxl-optimized-fp32" + # Disable GroupNorm due to limitation of current cuda kernel implementation. + self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=False, + atol=5e-3, + disable_group_norm=True, + ) - fp16_pipeline = ORTStableDiffusionXLPipeline.from_pretrained( - optimized_directory, provider="CUDAExecutionProvider" + expected_op_counters["unet"].update({"Attention": 0, "MultiHeadAttention": 24}) + optimized_onnx_dir = "tiny-random-sdxl-optimized-fp16" + self.verify_pipeline_optimization( + model_name, + export_onnx_dir, + optimized_onnx_dir, + expected_op_counters, + is_float16=True, + atol=5e-2, + disable_group_norm=True, ) - batch_size, num_images_per_prompt, height, width = 1, 1, 64, 64 - inputs = { - "prompt": ["starry night by van gogh"] * batch_size, - "num_inference_steps": 3, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_rescale": 0.1, - "output_type": "latent", + + @pytest.mark.slow + def test_sd3(self): + """This tests optimization of stable diffusion 3 pipeline""" + model_name = TINY_MODELS["stable-diffusion-3"] + + expected_op_counters = { + "transformer": { + "FastGelu": 3, + "MultiHeadAttention": 2, + "LayerNormalization": 8, + "SimplifiedLayerNormalization": 0, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 10, "SkipGroupNorm": 3, "NhwcConv": 17}, + "vae_decoder": {"Attention": 0, "GroupNorm": 14, "SkipGroupNorm": 7, "NhwcConv": 25}, + "text_encoder": { + "Attention": 2, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 2, + "SkipLayerNormalization": 4, + }, + "text_encoder_2": { + "Attention": 2, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 0, + "SkipLayerNormalization": 4, + }, + "text_encoder_3": { + "Attention": 2, + "MultiHeadAttention": 0, + "Gelu": 0, + "FastGelu": 2, + "BiasGelu": 0, + "GemmFastGelu": 0, + "LayerNormalization": 0, + "SimplifiedLayerNormalization": 2, + "SkipLayerNormalization": 0, + "SkipSimplifiedLayerNormalization": 3, + }, } - seed = 123 - np.random.seed(seed) - ort_outputs_1 = fp16_pipeline(**inputs) + export_onnx_dir = "tiny-random-stable-diffusion-3" + optimized_onnx_dir = "tiny-random-stable-diffusion-3-optimized-fp32" + self.verify_pipeline_optimization( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=False, atol=5e-3 + ) - np.random.seed(seed) - ort_outputs_2 = fp16_pipeline(**inputs) + optimized_onnx_dir = "tiny-random-stable-diffusion-3-optimized-fp16" + self.verify_pipeline_optimization( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=True, atol=5e-2 + ) - np.random.seed(seed) - ort_outputs_3 = fp16_pipeline(**inputs) + @pytest.mark.slow + def test_flux(self): + """This tests optimization of flux pipeline""" + model_name = TINY_MODELS["flux"] + + expected_op_counters = { + "transformer": { + "FastGelu": 8, + "MultiHeadAttention": 6, + "LayerNormalization": 13, + "SimplifiedLayerNormalization": 16, + }, + "vae_encoder": {"Attention": 0, "GroupNorm": 10, "SkipGroupNorm": 3, "NhwcConv": 17}, + "vae_decoder": {"Attention": 0, "GroupNorm": 14, "SkipGroupNorm": 7, "NhwcConv": 25}, + "text_encoder": { + "Attention": 2, + "Gelu": 0, + "LayerNormalization": 1, + "QuickGelu": 2, + "SkipLayerNormalization": 4, + }, + "text_encoder_2": { + "Attention": 2, + "MultiHeadAttention": 0, + "Gelu": 0, + "FastGelu": 2, + "BiasGelu": 0, + "GemmFastGelu": 0, + "LayerNormalization": 0, + "SimplifiedLayerNormalization": 2, + "SkipLayerNormalization": 0, + "SkipSimplifiedLayerNormalization": 3, + }, + } - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + export_onnx_dir = "tiny-random-flux" + optimized_onnx_dir = "tiny-random-flux-optimized-fp32" + self.verify_pipeline_optimization( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=False, atol=1e-3 + ) + + optimized_onnx_dir = "tiny-random-flux-optimized-fp16" + self.verify_pipeline_optimization( + model_name, export_onnx_dir, optimized_onnx_dir, expected_op_counters, is_float16=True, atol=5e-2 + ) if __name__ == "__main__":