Stable Diffusion 3.x and Flux Optimization #22986

Merged · 28 commits · Jan 14, 2025
Changes from 11 commits

Commits:
6fb7369 initial (tianleiwu, Nov 22, 2024)
9b2dcc0 sd3.x and flux (tianleiwu, Dec 2, 2024)
7f925ce update FastGelu and RMSNorm fusions (tianleiwu, Dec 5, 2024)
cf259e1 support Reciprocal in RMSNorm fusion (tianleiwu, Dec 6, 2024)
b38f12e match_child_path interface change (tianleiwu, Dec 13, 2024)
a58b68c clean up (tianleiwu, Dec 13, 2024)
c7317cb MHA fusion for MMDit (tianleiwu, Dec 13, 2024)
2f5b9b9 cuda layernorm support broadcast (tianleiwu, Dec 15, 2024)
699a64c force fuse layernorm (tianleiwu, Dec 15, 2024)
c1d0160 refactoring (tianleiwu, Dec 15, 2024)
1b9ea54 mha fusion for flux (tianleiwu, Dec 19, 2024)
5528276 remove transpose for query (tianleiwu, Dec 20, 2024)
89950d1 t5 optimization and mixed precision conversion (tianleiwu, Dec 23, 2024)
c869151 fix node name (tianleiwu, Dec 23, 2024)
84b1a51 Add option to use bfloat16 (tianleiwu, Dec 24, 2024)
b7041d1 fix attention (tianleiwu, Dec 25, 2024)
455a3ea update node block list of t5 encoder (tianleiwu, Dec 25, 2024)
dad0ac4 benchmark torch eager mode (tianleiwu, Dec 25, 2024)
8400558 update comment (tianleiwu, Dec 25, 2024)
9e43e20 benchmark torch compile (tianleiwu, Dec 26, 2024)
4bf9f25 refine benchmark_flux.sh (tianleiwu, Dec 26, 2024)
913c6ed Merge branch 'main' into tlwu/sd3_optimum (tianleiwu, Jan 6, 2025)
a47b6af undo layer norm kernel (tianleiwu, Jan 10, 2025)
55178d6 CMAKE_CUDA_ARCHITECTURES=native (tianleiwu, Jan 11, 2025)
dac8ea7 Merge branch 'main' into tlwu/sd3_optimum (tianleiwu, Jan 11, 2025)
ebade48 add tests (tianleiwu, Jan 12, 2025)
fd227bb update tests (tianleiwu, Jan 12, 2025)
87bd3ec undo some change (move to another PR) (tianleiwu, Jan 14, 2025)
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(double)epsilon_, // epsilon
reinterpret_cast<const CudaT*>(gamma->Data<T>()), // gamma
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
0, // broadcast stride for gamma/beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
33 changes: 25 additions & 8 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
@@ -44,19 +44,36 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());

const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
auto x_num_dims = x_shape.NumDimensions();
const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));

const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;

int broadcast = 0;
if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
// Handle a special case for MMDiT where scale and bias need to be broadcast.
// X has shape (B, S, D), scale and bias have shape (B, 1, D), and we store S as the broadcast stride.
if (x_num_dims == 3 && axis == 2 && n2 > 1 &&
scale->Shape().NumDimensions() == x_num_dims &&
scale->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
scale->Shape().GetDims()[1] == 1 &&
scale->Shape().GetDims()[2] == x_shape.GetDims()[2] &&
bias->Shape().NumDimensions() == x_num_dims &&
bias->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
bias->Shape().GetDims()[1] == 1 &&
bias->Shape().GetDims()[2] == x_shape.GetDims()[2]) {
broadcast = static_cast<int>(x_shape.GetDims()[1]);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
}

// Outputs
@@ -65,7 +82,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con

// Mean and variance
std::vector<int64_t> mean_inv_std_var_dim;
for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
if (i < axis) {
mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
} else {
@@ -94,7 +111,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
}

HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
X_data, n1, n2, epsilon_, scale_data, bias_data);
X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
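For reference, the broadcast case accepted by the check above can be written in a few lines of NumPy. This is only a sketch of the intended semantics for the shapes named in the comment (X of shape (B, S, D), gamma/beta of shape (B, 1, D)); the helper name and the test shapes are illustrative, not part of ONNX Runtime.

```python
# Minimal NumPy sketch of the broadcast LayerNormalization case: X is (B, S, D),
# gamma/beta are (B, 1, D), and every row of a batch reuses that batch's gamma/beta row.
# Illustrative only; not the ONNX Runtime code path.
import numpy as np


def layer_norm_row_broadcast(x, gamma, beta, epsilon=1e-5):
    """x: (B, S, D); gamma, beta: (B, 1, D); normalization over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normalized = (x - mean) / np.sqrt(var + epsilon)
    # NumPy broadcasting expands (B, 1, D) across the S axis; the CUDA kernel
    # achieves the same effect with an integer broadcast stride equal to S.
    return normalized * gamma + beta


B, S, D = 2, 4, 8
x = np.random.rand(B, S, D).astype(np.float32)
gamma = np.random.rand(B, 1, D).astype(np.float32)
beta = np.random.rand(B, 1, D).astype(np.float32)
print(layer_norm_row_broadcast(x, gamma, beta).shape)  # (2, 4, 8)
```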
17 changes: 12 additions & 5 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
@@ -334,6 +334,7 @@ __global__ void cuApplyLayerNorm(
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta,
int broadcast,
const T* __restrict__ skip,
const T* __restrict__ bias,
T* __restrict__ skip_input_bias_add_output) {
@@ -366,8 +367,13 @@
curr += static_cast<U>(skip_vals[i]);
}

U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0;
// The ONNX LayerNormalization operator supports broadcasting:
// gamma and beta shall be unidirectionally broadcastable to tensor X.
// Here we support a special case for transformer models where X is (B, S, D) and gamma/beta are (B, 1, D).
int index = (broadcast > 0) ? ((i1 / broadcast) * n2 + i) : i;
U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[index] : (U)0;

if (simplified) {
ovals[i] = static_cast<V>(gamma_i * c_inv_std_dev * curr);
} else {
@@ -409,6 +415,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast,
const T* skip,
const T* bias,
T* skip_input_bias_add_output) {
@@ -442,15 +449,15 @@
input,
n1, n2,
U(epsilon),
gamma, beta,
gamma, beta, broadcast,
skip, bias, skip_input_bias_add_output);
}

#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \
template void HostApplyLayerNorm<T, U, V, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \
U* mean, U* inv_std_dev, const T* input, int n1, int n2, \
double epsilon, const V* gamma, const V* beta, const T* skip, \
const T* bias, T* skip_input_bias_add_output);
double epsilon, const V* gamma, const V* beta, int broadcast, \
const T* skip, const T* bias, T* skip_input_bias_add_output);

LAYERNORM_LINEAR_IMPL(float, float, float, true)
LAYERNORM_LINEAR_IMPL(half, float, half, true)
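The kernel addresses gamma and beta with a flattened index derived from the broadcast stride. The small Python sketch below (a hypothetical helper, not part of the kernel) mirrors the `(i1 / broadcast) * n2 + i` computation above, assuming n1 = B * S rows, n2 = D columns, and broadcast = S.

```python
# Hypothetical sketch of the gamma/beta indexing used when broadcast > 0.
# Rows are i1 = 0..n1-1 with n1 = B * S, columns are i = 0..n2-1 with n2 = D,
# and broadcast holds S, so rows of the same batch share one gamma/beta row.
def gamma_beta_index(i1: int, i: int, n2: int, broadcast: int) -> int:
    if broadcast > 0:
        return (i1 // broadcast) * n2 + i
    return i  # no broadcast: gamma/beta have exactly n2 elements


# Example with B=2, S=3, D=4 (n1=6, n2=4, broadcast=3):
# rows 0-2 read gamma[0:4], rows 3-5 read gamma[4:8].
assert gamma_beta_index(2, 1, n2=4, broadcast=3) == 1
assert gamma_beta_index(3, 1, n2=4, broadcast=3) == 5
```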
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
@@ -41,6 +41,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast = 0, // broadcast stride for gamma/beta
const T* skip = nullptr,
const T* bias = nullptr,
T* skip_input_bias_add_output = nullptr);
39 changes: 0 additions & 39 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -399,45 +399,6 @@ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
self.node_name_to_graph_name[gather_v_name] = self.this_graph_name

def transpose_kv(self, past_k: str, past_v: str):
"""Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H)

Args:
past_k (str): name of past K value of shape (B,N,P,H)
past_v (str): name of past V value of shape (B,N,P,H)

Returns:
past_k_transpose (str): name of past K value of shape (B,P,N,H)
past_v_transpose (str): name of past V value of shape (B,P,N,H)
"""
past_k_transpose = (past_k + "_transposed").replace(".", "_")
past_v_transpose = (past_v + "_transposed").replace(".", "_")
transpose_k_name = self.model.create_node_name("Transpose")
transpose_v_name = self.model.create_node_name("Transpose")

transpose_k = helper.make_node(
"Transpose",
inputs=[past_k],
outputs=[past_k_transpose],
name=transpose_k_name,
perm=[0, 2, 1, 3],
)
transpose_v = helper.make_node(
"Transpose",
inputs=[past_v],
outputs=[past_v_transpose],
name=transpose_v_name,
perm=[0, 2, 1, 3],
)

# Add reshape nodes to graph
self.nodes_to_add.append(transpose_k)
self.nodes_to_add.append(transpose_v)
self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name
self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name

return past_k_transpose, past_v_transpose

def create_combined_qkv_bias(
self,
q_add: NodeProto,
122 changes: 122 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_fastgelu.py
@@ -26,6 +26,9 @@
if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
return

if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node):
return

def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
"""
Fuse Gelu with tanh into one node:
@@ -358,3 +361,122 @@
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True

def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
"""
This pattern is from the Stable Diffusion 3.5 model.
Fuse Gelu with tanh into one node:
+-----------------+------------------+
| | |
| v v
[root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul -->
| (A=0.0447) (A=0.7978) (A=1) ^ (A=0.5)
| |
+-------------------------------------------------------------------------+
Note that the constant input of Add and Mul can be either the first or the second input.
"""
if tanh_node.output[0] not in input_name_to_nodes:
return

children = input_name_to_nodes[tanh_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return
add_after_tanh = children[0]

if not self.model.has_constant_input(add_after_tanh, 1.0):
return

if add_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[add_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_after_tanh = children[0]

if mul_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[mul_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_half = children[0]
if not self.model.has_constant_input(mul_half, 0.5):
return

root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1]

mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
return

i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
if i < 0:
return

add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
if add_before_tanh is None:
return

if add_before_tanh.input[0] == root_input:
another = 1
elif add_before_tanh.input[1] == root_input:
another = 0
else:
return

mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node)
if mul_after_pow is None:
return

i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
if i < 0:
return

mul = self.model.match_parent(mul_after_pow, "Mul", 0 if i == 1 else 1, output_name_to_node)
if mul is None:
return

if mul.input[0] == root_input:
another = 1
elif mul.input[1] == root_input:
another = 0
else:
return

mul2 = self.model.match_parent(mul, "Mul", another, output_name_to_node)
if mul2 is None:
return

if mul2.input[0] != root_input or mul2.input[1] != root_input:
return

subgraph_nodes = [
mul2,
mul,
mul_after_pow,
add_before_tanh,
mul_before_tanh,
tanh_node,
add_after_tanh,
mul_after_tanh,
mul_half,
]

if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
[mul_half.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_input],
outputs=mul_half.output,
name=self.model.create_node_name("FastGelu"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True
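As a sanity check on the pattern documented in `fuse_4`, the subgraph it matches is the tanh approximation of GELU that the fused FastGelu node computes. The NumPy snippet below is illustrative only; the exact constants embedded in a real model may differ slightly from the rounded 0.0447/0.7978 values matched with a delta.

```python
# Illustrative NumPy check that the node sequence matched by fuse_4
# (Mul -> Mul -> Mul -> Add -> Mul -> Tanh -> Add -> Mul -> Mul) equals the
# tanh-approximated GELU formula used by the fused FastGelu node.
import numpy as np


def decomposed_pattern(x):
    x_cubed = x * x * x                     # Mul, Mul
    inner = x + 0.044715 * x_cubed          # Mul (A=0.0447), Add
    t = np.tanh(0.7978845608 * inner)       # Mul (A=0.7978), Tanh
    return x * (t + 1.0) * 0.5              # Add (A=1), Mul, Mul (A=0.5)


def fast_gelu(x):
    return 0.5 * x * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))


x = np.linspace(-4.0, 4.0, 9, dtype=np.float32)
assert np.allclose(decomposed_pattern(x), fast_gelu(x), atol=1e-6)
```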