
LayerNormalization broadcast (limited support for axis=2) #23297

Merged 11 commits into main from tlwu/layer_norm_broadcast on Jan 11, 2025

Conversation

tianleiwu (Contributor) commented Jan 9, 2025

Description

The LayerNormalization spec supports broadcasting: tensors Scale and B should be unidirectionally broadcastable to tensor X.
https://onnx.ai/onnx/operators/onnx__LayerNormalization.html
However, the current implementation only allows the scale and bias shapes to be X.shape()[axis:].

Example of input tensors normalized with axis=2:

| X shape | Scale shape | B shape | Before | After |
| - | - | - | - | - |
| (B, S, D) | (D) | (D) | Supported | Supported |
| (B, S, D) | (1, 1, D) | (1, 1, D) | Supported | Supported |
| (B, S, D) | (B, 1, D) | (B, 1, D) | Not Supported | Supported |
| (B, S, D) | (1, S, D) | (1, S, D) | Not Supported | Supported |
| (B, S, D) | (B, S, D) | (B, S, D) | Not Supported | Supported |

Here we add limited support: axis=2; scale and bias have the same shape; scale, bias, and X have the same number of dimensions. This covers common use cases in LLM and vision models (see the sketch below).
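For illustration, here is a minimal NumPy sketch of the broadcast semantics (the shapes and function name are assumptions for this example, not the ONNX Runtime kernel):

```
import numpy as np

def layer_norm_axis2(x, scale, bias, epsilon=1e-5):
    # Normalize over the last dimension (axis=2 for a rank-3 input),
    # then apply scale and bias, which broadcast against x.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias

# Hypothetical shapes: previously only (D,) or (1, 1, D) scale/bias were
# accepted; with this change, shapes such as (B, 1, D), (1, S, D), or
# (B, S, D) broadcast as well.
B, S, D = 2, 4, 8
x = np.random.rand(B, S, D).astype(np.float32)
scale = np.random.rand(B, 1, D).astype(np.float32)
bias = np.random.rand(B, 1, D).astype(np.float32)
print(layer_norm_axis2(x, scale, bias).shape)  # (2, 4, 8)
```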

Motivation and Context

Support Stable Diffusion 3.x and Flux model.

tianleiwu marked this pull request as draft January 9, 2025 07:26
tianleiwu marked this pull request as ready for review January 9, 2025 22:28
tianleiwu requested a review from jiafatom January 9, 2025 22:29
tianleiwu merged commit 73f5b0c into main Jan 11, 2025
98 checks passed
tianleiwu deleted the tlwu/layer_norm_broadcast branch January 11, 2025 05:57
guschmue pushed a commit that referenced this pull request Jan 12, 2025
tianleiwu added a commit that referenced this pull request Jan 14, 2025
### Description

It has dependency on the following PRs:
- #23297

Optimize the ONNX pipeline for Stable Diffusion 3.x and Flux 1.0 models
(fp32 or fp16).
- [x] Update optimize_pipeline script
- [x] Update benchmark script
- [x] Update document about Stable Diffusion 3.x and Flux 1.0 models
- [x] Add graph optimizations for MMDit model
  - [x] FastGelu fusion
  - [x]  RMSNorm fusion
  - [x]  MultiHeadAttention fusion
- [x] Add graph optimizations for Flux transformer models
  - [x]  MultiHeadAttention fusion
- [x] Update graph optimizations for t5
- [x] Add tests

Optimize the ONNX pipeline for Stable Diffusion 3.x and Flux 1.0 models:
```
python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16

  Optimize flux1_schnell_onnx/fp32/transformer/model.onnx ...
  Fused LayerNormalization: 115
  Fused SimplifiedLayerNormalization: 152
  Fused FastGelu: 76
  Fused MultiHeadAttention: 57
```
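A quick sanity check that the optimized model loads under ONNX Runtime (the model path below is an assumption based on the output directory above; real pipelines typically run through Optimum):

```
import onnxruntime as ort

# Hypothetical path produced by the optimize_pipeline.py command above.
model_path = "./flux1_schnell_onnx/fp16/transformer/model.onnx"

# Prefer CUDA when available, falling back to CPU.
session = ort.InferenceSession(
    model_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print([inp.name for inp in session.get_inputs()])
```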

### H100 Benchmark Results

* GPU: NVIDIA H100 80GB HBM3
* Image Size: 1024x1024
* Batch Size: 1

Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB)
-- | -- | -- | -- | -- | --
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 8.198 | 37,603
Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 10.762 | 41,469
Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 10.891 | 43,545
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 12.339 | 36,651
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 0.775 | 37,857
Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 0.931 | 41,433
Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 0.939 | 43,809
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 1.120 | 36,629
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 7.466 | 32,217
SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 10.275 | 36,609
SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 10.283 | 36,729
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 11.615 | 31,517
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 3.240 | 21,143
SD 3.5 Medium | 50 | FP16+BF16 | Optimum (ORT) | 4.799 | 25,097
SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 4.838 | 25,109
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 5.582 | 20,489

### A100 Benchmark Results

* GPU: A100-SXM4-80GB
* Image Size: 1024x1024
* Batch Size: 1

Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB)
-- | -- | -- | -- | -- | --
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 17.593 | 37,723
Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 21.918 | 41,348
Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 22.060 | 44,860
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 24.267 | 36,847
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 1.627 | 37,881
Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 1.884 | 41,537
Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 1.902 | 44,858
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 2.162 | 36,831
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 15.881 | 32,307
SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 19.837 | 36,451
SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 19.964 | 36,461
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 22.477 | 31,513
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 6.476 | 21,341
SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 8.775 | 25,183
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 10.057 | 20,433

### Future Work

* Triton kernels for matrix multiplication and auto-tuning.
* FP8/INT8 quantization.

### Motivation and Context

SD 3.5 Architecture:

https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/resolve/main/mmdit-x.png