-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] rope optimization #27907
[GPU] rope optimization #27907
Conversation
switch (impl_param.get_input_layout(0).data_type) { | ||
case data_types::f16: | ||
params.vec_size = 16; | ||
break; | ||
case data_types::f32: | ||
params.vec_size = 8; | ||
break; | ||
default: | ||
params.vec_size = 1; | ||
break; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this vec_size is a parameter of specific kernel while ocl primitive_impl can be used for multiple kernels. So suggestion is to move it to rope_kernel_ref.cpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
} else if (params.is_chatglm) { | ||
if (params.support_2d_rope) { | ||
// input [batch_size, seq_length] | ||
// output [batch_size, head_count, seq_length, half_rotary_ndims] | ||
dispatchData.gws = {input.Batch().v * params.head_cnt, | ||
input.Feature().v, | ||
params.rotary_ndims / 2ul}; | ||
params.rotary_ndims / 2ul / params.vec_size}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if half rotary ndims is not divisible by vec_size? I think you should either add fallback to vec_size 1 for such case or add tail processing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fallback to vec_size=1 if half rotary ndims is not divisible by vec_size.
JitConstants GetJitConstants(const rope_params& params, DispatchData dispatchData) const override; | ||
DispatchData SetDefault(const rope_params& params) const override; | ||
private: | ||
mutable size_t vec_size; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have a single object instance of this class used for all the Rope layers in the model. To prevent any issues with Rope having different data types or rotary_ndims (and using different vec_size), it's better to introduce some function like size_t get_vec_size(...)
and call it directly from GetJitConstants()
and SetDefault()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sense! Updated it!
#include <string> | ||
|
||
namespace kernel_selector { | ||
ParamsKey RoPEKernelOpt::GetSupportedKey() const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need a separate opt kernel instance? I think it's a big code duplication as opt kernel with vec_size 1 is identical to ref. Suggest keeping single kernel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought that a simple reference kernel implement would be helpful to reference if there is new rope type need to be added.
Anyhow, keep one rope kernel implement also make sense!
### Details: - Optimize rope opencl kernel to improve its performance - Test result shows it can improve RoPE performance about 50% in average. <html xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=OneNote.File> <meta name=Generator content="Microsoft OneNote 15"> </head> <body lang=en-US style='font-family:Calibri;font-size:11.0pt'> <!--StartFragment--> <div style='direction:ltr'> batch=128, seq_length = 7 | base latency(ns) | optimized latency(ns) | latency decreased | | -- | -- | -- | -- | -- | -- rope_ref_5266667119713786613_0_0__sa, | 921352 | 872395 | 5.31% | RoPETestQwen7b | f32 rope_ref_2672092794364911740_0_0__sa, | 1724374 | 514790 | 70.15% | RoPETestChatGLM | f32 rope_ref_8061762790816124098_0_0__sa, | 633019 | 127186 | 79.91% | RoPETestQwen7b | f16 rope_ref_4392014836945391706_0_0__sa, | 629791 | 518749 | 17.63% | RoPETestLlama2 | f32 rope_ref_13829176589243505378_0_0__sa, | 870312 | 259583 | 70.17% | RoPETestChatGLM | f32 rope_ref_6813544162411765619_0_0__sa, | 749895 | 421875 | 43.74% | RoPETestChatGLM | f16 rope_ref_15054358246334082928_0_0__sa, | 637708 | 45208 | 92.91% | RoPETestFlux | f32 rope_ref_3898891400599565440_0_0__sa, | 378333 | 335937 | 11.21% | RoPETestRotateHalfWithoutTranspose | f32 rope_ref_18119704851383556529_0_0__sa, | 371250 | 208645 | 43.80% | RoPETestChatGLM | f16 rope_ref_17460680473512025171_0_0__sa, | 299166 | 98958 | 66.92% | RoPETestFlux | f16 </div> <!--EndFragment--> </body> </html> ![image](https://github.com/user-attachments/assets/4328b1a7-18ec-485f-abd0-b0fe16785854) ### Tickets: - *CVS-157438*
Details:
Tickets: