From 043849904eed0b856d7cb50ce866d04bd766ea0e Mon Sep 17 00:00:00 2001
From: Charlie Fu
Date: Fri, 14 Jun 2024 10:33:37 -0500
Subject: [PATCH] Use scaled mm for untuned fp8 gemm (#50)

* update quark quantizer command

* typo

* Using scaled_mm for untuned gemm

* remove comment

* fix yapf
---
 .../layers/quantization/fp8_rocm.py | 70 +++++++++----------
 1 file changed, 33 insertions(+), 37 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py
index ddb83a6ed452e..ddd3b304280e7 100644
--- a/vllm/model_executor/layers/quantization/fp8_rocm.py
+++ b/vllm/model_executor/layers/quantization/fp8_rocm.py
@@ -24,10 +24,8 @@ class Fp8RocmConfig(QuantizationConfig):
 
     def __init__(self) -> None:
-        # self.quantized_weights_path = config["quantized_weights"]
         self._tuned = {}
         gemm_type = os.getenv("FP8_GEMM", "fp8_16")
-        #print(f"Integral Cross factor = {self.factor}")
         if gemm_type == "fp8_8":
             self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8
             tuned_filename = "/tmp/tuned_fp8_8.csv"
@@ -220,23 +218,15 @@ def apply_fp8_16(
 
         algo = self._config._tuned.get((m, n, k))
         if algo is None:
-            import os
-
-            if os.getenv("TUNE_FP8") == "1":
-                try:
-                    df = pd.read_csv("/tmp/fp8_shapes.csv")
-                except (IOError, pd.errors.EmptyDataError,
-                        pd.errors.ParserError):
-                    df = pd.DataFrame(columns=["M", "N", "K"])
-                df = pd.concat(
-                    [df, pd.DataFrame({
-                        "M": [m],
-                        "N": [n],
-                        "K": [k]
-                    })]).drop_duplicates()
-                df.to_csv("/tmp/fp8_shapes.csv", index=False)
-            algo = 0
-        res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
+            _save_shape(m, n, k)
+            res, _ = torch._scaled_mm(x8,
+                                      weight.t(),
+                                      out_dtype=x.dtype,
+                                      scale_a=asf,
+                                      scale_b=wsf,
+                                      bias=bias)
+        else:
+            res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
         return res
 
     def apply_fp8_8(
@@ -257,24 +247,16 @@ def apply_fp8_8(
 
         algo = self._config._tuned.get((m, n, k))
         if algo is None:
-            import os
-
-            if os.getenv("TUNE_FP8") == "1":
-                try:
-                    df = pd.read_csv("/projects/fp8_shapes.csv")
-                except (IOError, pd.errors.EmptyDataError,
-                        pd.errors.ParserError):
-                    df = pd.DataFrame(columns=["M", "N", "K"])
-                df = pd.concat(
-                    [df, pd.DataFrame({
-                        "M": [m],
-                        "N": [n],
-                        "K": [k]
-                    })]).drop_duplicates()
-                df.to_csv("/tmp/fp8_shapes.csv", index=False)
-            algo = 0
-
-        res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
+            _save_shape(m, n, k)
+            res, _ = torch._scaled_mm(x8,
+                                      weight.t(),
+                                      out_dtype=x8.dtype,
+                                      scale_a=asf,
+                                      scale_b=wsf,
+                                      scale_result=osf,
+                                      bias=bias)
+        else:
+            res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
         res16 = torch.empty_like(res, dtype=torch.float16)
         vllm_ops.convert_fp8(res16, res, 1 / osf)
         return res16
@@ -308,3 +290,17 @@ def _per_tensor_dequantize(tensor: torch.Tensor,
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
+
+
+def _save_shape(m, n, k):
+    if os.getenv("TUNE_FP8") == "1":
+        try:
+            df = pd.read_csv("/tmp/fp8_shapes.csv")
+        except (IOError, pd.errors.EmptyDataError, pd.errors.ParserError):
+            df = pd.DataFrame(columns=["M", "N", "K"])
+        df = pd.concat([df, pd.DataFrame({
+            "M": [m],
+            "N": [n],
+            "K": [k]
+        })]).drop_duplicates()
+        df.to_csv("/tmp/fp8_shapes.csv", index=False)
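
A minimal, self-contained sketch of the new untuned fallback path, for reference.
It assumes a ROCm PyTorch build (circa 2.3) where torch._scaled_mm accepts
per-tensor scale tensors and returns an (output, amax) tuple, which is what the
"res, _ = ..." unpacking in the patch relies on. The fp8 dtype, shapes, and unit
scales below are illustrative, not taken from the patch:

    import torch

    m, n, k = 16, 32, 64  # _scaled_mm requires dims to be multiples of 16
    # MI300-series ROCm fp8 dtype; CUDA builds would use torch.float8_e4m3fn
    fp8 = torch.float8_e4m3fnuz
    x8 = torch.randn(m, k, device="cuda").to(fp8)      # fp8 activations
    weight = torch.randn(n, k, device="cuda").to(fp8)  # fp8 weight, stored (n, k)
    asf = torch.tensor(1.0, device="cuda")             # activation scale factor
    wsf = torch.tensor(1.0, device="cuda")             # weight scale factor

    # weight.t() yields the column-major (k, n) operand _scaled_mm expects,
    # mirroring the apply_fp8_16 call above; output is dequantized to fp16.
    res, _ = torch._scaled_mm(x8, weight.t(), out_dtype=torch.float16,
                              scale_a=asf, scale_b=wsf)

With TUNE_FP8=1, untuned (m, n, k) shapes are still appended to
/tmp/fp8_shapes.csv (via the new _save_shape helper) for later offline tuning.
Centralizing this in _save_shape also removes a path mismatch in the old fp8_8
branch, which read from /projects/fp8_shapes.csv but wrote to
/tmp/fp8_shapes.csv.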