diff --git a/msamp/te/extension.py b/msamp/te/extension.py index a67dd7a7..a831240f 100644 --- a/msamp/te/extension.py +++ b/msamp/te/extension.py @@ -124,6 +124,7 @@ def override(): """Override transformer engine extension functions.""" tex.fused_cast_transpose = TeExtensionOverrider.fused_cast_transpose te.cpp_extensions.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8 + te.module.linear.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8 te.cpp_extensions.fp8_cast_transpose_fused = TeExtensionOverrider.fp8_cast_transpose_fused