Use scaled mm for untuned fp8 gemm (#50)
* update quark quantizer command

* typo

* Using scaled_mm for untuned gemm

* remove comment

* fix yapf
charlifu authored Jun 14, 2024
1 parent d3da246 commit 0438499
Showing 1 changed file with 33 additions and 37 deletions.
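Before the diff, a minimal sketch of what the new untuned path does. The shapes, scales, and fp8 dtype below are illustrative assumptions (float8_e4m3fnuz is the ROCm fp8 variant; CUDA builds use float8_e4m3fn), and on the PyTorch builds current when this landed, torch._scaled_mm returned an (output, amax) tuple, which is why the diff unpacks res, _:

import torch

# Hypothetical GEMM shapes and per-tensor scales, for illustration only.
m, n, k = 16, 4096, 4096
x8 = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fnuz)
w8 = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fnuz)
asf = torch.tensor(1.0, device="cuda")  # per-tensor activation scale (float32)
wsf = torch.tensor(1.0, device="cuda")  # per-tensor weight scale (float32)

# _scaled_mm wants the second operand column-major, hence the .t().
# PyTorch 2.3-era builds return (output, amax); newer releases return
# only the output tensor.
res, _ = torch._scaled_mm(x8,
                          w8.t(),
                          out_dtype=torch.float16,
                          scale_a=asf,
                          scale_b=wsf)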
70 changes: 33 additions & 37 deletions vllm/model_executor/layers/quantization/fp8_rocm.py
@@ -24,10 +24,8 @@
 class Fp8RocmConfig(QuantizationConfig):
 
     def __init__(self) -> None:
-        # self.quantized_weights_path = config["quantized_weights"]
         self._tuned = {}
         gemm_type = os.getenv("FP8_GEMM", "fp8_16")
-        #print(f"Integral Cross factor = {self.factor}")
         if gemm_type == "fp8_8":
            self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8
            tuned_filename = "/tmp/tuned_fp8_8.csv"
@@ -220,23 +218,15 @@ def apply_fp8_16(
 
         algo = self._config._tuned.get((m, n, k))
         if algo is None:
-            import os
-
-            if os.getenv("TUNE_FP8") == "1":
-                try:
-                    df = pd.read_csv("/tmp/fp8_shapes.csv")
-                except (IOError, pd.errors.EmptyDataError,
-                        pd.errors.ParserError):
-                    df = pd.DataFrame(columns=["M", "N", "K"])
-                df = pd.concat(
-                    [df, pd.DataFrame({
-                        "M": [m],
-                        "N": [n],
-                        "K": [k]
-                    })]).drop_duplicates()
-                df.to_csv("/tmp/fp8_shapes.csv", index=False)
-            algo = 0
-        res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
+            _save_shape(m, n, k)
+            res, _ = torch._scaled_mm(x8,
+                                      weight.t(),
+                                      out_dtype=x.dtype,
+                                      scale_a=asf,
+                                      scale_b=wsf,
+                                      bias=bias)
+        else:
+            res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
         return res
 
     def apply_fp8_8(
@@ -257,24 +247,16 @@ def apply_fp8_8(
 
         algo = self._config._tuned.get((m, n, k))
         if algo is None:
-            import os
-
-            if os.getenv("TUNE_FP8") == "1":
-                try:
-                    df = pd.read_csv("/projects/fp8_shapes.csv")
-                except (IOError, pd.errors.EmptyDataError,
-                        pd.errors.ParserError):
-                    df = pd.DataFrame(columns=["M", "N", "K"])
-                df = pd.concat(
-                    [df, pd.DataFrame({
-                        "M": [m],
-                        "N": [n],
-                        "K": [k]
-                    })]).drop_duplicates()
-                df.to_csv("/tmp/fp8_shapes.csv", index=False)
-            algo = 0
-
-        res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
+            _save_shape(m, n, k)
+            res, _ = torch._scaled_mm(x8,
+                                      weight.t(),
+                                      out_dtype=x8.dtype,
+                                      scale_a=asf,
+                                      scale_b=wsf,
+                                      scale_result=osf,
+                                      bias=bias)
+        else:
+            res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
         res16 = torch.empty_like(res, dtype=torch.float16)
         vllm_ops.convert_fp8(res16, res, 1 / osf)
         return res16
@@ -308,3 +290,17 @@ def _per_tensor_dequantize(tensor: torch.Tensor,
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
+
+
+def _save_shape(m, n, k):
+    if os.getenv("TUNE_FP8") == "1":
+        try:
+            df = pd.read_csv("/tmp/fp8_shapes.csv")
+        except (IOError, pd.errors.EmptyDataError, pd.errors.ParserError):
+            df = pd.DataFrame(columns=["M", "N", "K"])
+        df = pd.concat([df, pd.DataFrame({
+            "M": [m],
+            "N": [n],
+            "K": [k]
+        })]).drop_duplicates()
+        df.to_csv("/tmp/fp8_shapes.csv", index=False)
