Commit

Merge pull request #7 from ROCmSoftwarePlatform/vllm_v4
V4 update
dllehr-amd authored Nov 6, 2023
2 parents ae21b27 + 6132216 commit 656dd08
Showing 5 changed files with 46 additions and 80 deletions.
12 changes: 10 additions & 2 deletions docker/Dockerfile
@@ -30,13 +30,21 @@ RUN git clone https://streamhsa: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@github.
cd vllm-private &&\
pip install -r requirements.txt &&\
pip install typing-extensions==4.8.0 &&\
sed -i 's/gpu_memory_utilization: float = 0.9/gpu_memory_utilization: float = 0.4/g' vllm/engine/arg_utils.py &&\
python setup.py build && python setup.py develop; exit 0
RUN pip install pyarrow Ray
RUN pip install pandas
RUN pip install pyarrow Ray pandas==2.0 numpy==1.20.3

# RUN git clone -b N1 https://streamhsa:[email protected]/amgddm/rocBLAS-internal.git && \
# cd rocBLAS-internal && ./install.sh -idc -a gfx942 &&\
# cd ../ && rm -rf rocBLAS-internal
RUN cd ~ && wget --no-check-certificate \
https://compute-artifactory.amd.com/artifactory/list/rocm-osdb-20.04-deb/compute-rocm-dkms-no-npi-hipclang-12845/hsa-rocr_1.12.0.60000-crdnnh.12845%7E20.04_amd64.deb &&\
dpkg-deb -x hsa-rocr_1.12.0.60000-crdnnh.12845~20.04_amd64.deb . && \
cp opt/rocm-6.0.0-12845/lib/* /opt/rocm-6.0.0-12969/lib/ && rm -rf opt/rocm-6.0.0-12845

ENV HSA_OVERRIDE_GFX_VERSION=9.4.1

RUN rm /opt/rocm/share/rccl/msccl-algorithms/allreduce-allpairs-8n-simple.xml

COPY run_vllm.sh $WORKSPACE_DIR/
COPY run_vllm_tp1.sh $WORKSPACE_DIR/
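
Note: the sed step in this Dockerfile (and the one added to Dockerfile_v3 below) pins vLLM's default gpu_memory_utilization to 0.4 inside the image. The same cap can also be requested per run through vLLM's Python entry point instead of rewriting vllm/engine/arg_utils.py; a minimal sketch, assuming the stock vLLM LLM API and the /data/llama2-70b-chat model path used by the run scripts:

    from vllm import LLM, SamplingParams

    # Reserve only 40% of GPU memory for the engine (weights + KV cache),
    # mirroring the value the sed patch bakes into arg_utils.py.
    llm = LLM(model="/data/llama2-70b-chat",
              tensor_parallel_size=1,
              gpu_memory_utilization=0.4)

    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(out[0].outputs[0].text)
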
1 change: 1 addition & 0 deletions docker/Dockerfile_v3
@@ -30,6 +30,7 @@ RUN git clone https://streamhsa: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@github.
cd vllm-private &&\
pip install -r requirements.txt &&\
pip install typing-extensions==4.8.0 &&\
sed -i 's/gpu_memory_utilization: float = 0.9/gpu_memory_utilization: float = 0.4/g' vllm/engine/arg_utils.py &&\
python setup.py build && python setup.py develop; exit 0
RUN pip install pyarrow Ray pandas==2.0 numpy==1.20.3

70 changes: 23 additions & 47 deletions docker/run_vllm.sh
@@ -1,55 +1,31 @@
#!/bin/bash
# parameter default values
tp=1
VLLM_DIR=$HOME/vllm
GRAD_DIR=$HOME/gradlib
MODEL=/data/llama-2-13b-chat-hf

# print usage of the parameters
usage() {
echo "Usage: $0 [--tp <n>] [--vllm_dir <path>] [--gradlib_dir <path>] [--model <path>]"
exit 1
}

# parse parameters
while [[ "$#" -gt 0 ]]; do
case $1 in
--tp) tp="$2"; shift ;;
--vllm_dir) VLLM_DIR="$2"; shift ;;
--gradlib_dir) GRAD_DIR="$2"; shift ;;
--model) MODEL="$2"; shift ;;
*) usage ;; # Any other argument will show usage information.
esac
shift # Move to next argument
done

# print parameter settings
echo "tensor parallel: $tp"
echo "vllm_dir: $VLLM_DIR"
echo "gradlib_dir: $GRAD_DIR"
echo "model: $MODEL"

SLOT_DIR=/workspace
VLLM_DIR=$SLOT_DIR/vllm-private
GRAD_DIR=$SLOT_DIR/gradlib
MODEL=/data/llama2-70b-chat
#enable to use triton flash attention
export VLLM_USE_TRITON=1
export VLLM_USE_HIPGRAPH=1

export VLLM_USE_HIPGRAPH=
#export LD_LIBRARY_PATH=/var/lib/jenkins/rccl/build
#set Tensor Parallelism
echo "tuned_gemm_csv: ./tuned_tp$tp.csv" > $VLLM_DIR/tuned_perf_tp$tp.yaml
if [ ! -f $VLLM_DIR/tuned_tp$tp.csv ] ;
then
echo "INFO: No Tuned configs detected. Generating now"
cd $GRAD_DIR
python gemm_tuner.py --model_dir $MODEL --output ../vllm/tuned_tp$tp.csv --tp $tp
fi
export VLLM_PERF_YAML=./tuned_perf_tp$tp.yaml

cd $VLLM_DIR
for gen_len in 1 32;
for tp in 1;
do
for input_len in 512 1024 2048 3072;
echo "tuned_gemm_csv: ./tuned_tp$tp.csv" > $VLLM_DIR/tuned_perf_tp$tp.yaml
if [ ! -f $VLLM_DIR/tuned_tp$tp.csv ] ;
then
echo "INFO: No Tuned configs detected. Generating now"
cd $GRAD_DIR
python gemm_tuner.py --model_dir $MODEL --output $VLLM_DIR/tuned_tp$tp.csv --tp $tp
fi
export VLLM_PERF_YAML=./tuned_perf_tp$tp.yaml

cd $VLLM_DIR
for gen_len in 1 32;
do
echo "================================= RUNNING $MODEL tp$tp $input_len $gen_len ==============================================="
python benchmarks/benchmark_latency.py --model $MODEL --input-len $input_len --output-len $gen_len --batch-size 1 --tensor-parallel-size $tp --num-iters 5
for input_len in 512 1024 2048 3072;
do
echo "================================= RUNNING $MODEL $input_len $gen_len ==============================================="
python benchmarks/benchmark_latency.py --model $MODEL --input-len $input_len --output-len $gen_len --batch-size 1 --tensor-parallel-size $tp --num-iters 5
done
done
done
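
The updated run_vllm.sh drops the argument parsing and instead hard-codes a latency sweep over tensor-parallel size, generation length, and prompt length. For reference, the same sweep as a small Python driver; a sketch, assuming it is launched from the vllm-private checkout and that the GEMM-tuned CSV for the chosen tp has already been generated:

    import subprocess

    MODEL = "/data/llama2-70b-chat"

    # Same grid as the script: tp in {1}, gen_len in {1, 32},
    # input_len in {512, 1024, 2048, 3072}, batch size 1, 5 timed iterations.
    for tp in (1,):
        for gen_len in (1, 32):
            for input_len in (512, 1024, 2048, 3072):
                print(f"=== RUNNING {MODEL} tp{tp} {input_len} {gen_len} ===")
                subprocess.run(
                    ["python", "benchmarks/benchmark_latency.py",
                     "--model", MODEL,
                     "--input-len", str(input_len),
                     "--output-len", str(gen_len),
                     "--batch-size", "1",
                     "--tensor-parallel-size", str(tp),
                     "--num-iters", "5"],
                    check=True,
                )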

2 changes: 1 addition & 1 deletion docker/run_vllm_tp1.sh
@@ -5,7 +5,7 @@ GRAD_DIR=$SLOT_DIR/gradlib
MODEL=/data/llama2-70b-chat
#enable to use triton flash attention
export VLLM_USE_TRITON=1
export VLLM_USE_HIPGRAPH=1
export VLLM_USE_HIPGRAPH=
#export LD_LIBRARY_PATH=/var/lib/jenkins/rccl/build
#set Tensor Parallelism
for tp in 1;
41 changes: 11 additions & 30 deletions vllm/model_executor/models/llama.py
@@ -86,25 +86,14 @@ def __init__(

def forward(self, x, num_generation_tokens):
if num_generation_tokens>0:
gate_up, _ = self.gate_up_proj(x)#,call_hipblaslt=hacks.get('decode_mlpup_call_hipblaslt',0),
#splitm=hacks.get('decode_mlpup_splitm',1),
#splitk=hacks.get('decode_mlpup_splitk',1),
#splits=hacks.get('decode_mlpup_splits',None),
#call_rocsolidx=hacks.get('decode_mlpup_call_rocsolidx',0)
#)
gate_up, _ = self.gate_up_proj(x)
else:
gate_up, _ = self.gate_up_proj(x)#,call_hipblaslt=hacks.get('prefill_mlpup_call_hipblaslt',0),call_rocsolidx=hacks.get('prefill_mlpup_call_rocsolidx',0))
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
if num_generation_tokens>0:
x, _ = self.down_proj(x, graphx=self.graphx)#,call_hipblaslt=hacks.get('decode_mlpdown_call_hipblaslt',0),
#splitm=hacks.get('decode_mlpdown_splitm',1),
#splitk=hacks.get('decode_mlpdown_splitk',1),
#splits=hacks.get('decode_mlpdown_splits',None),
#call_rocsolidx=hacks.get('decode_mlpdown_call_rocsolidx',0),
#graphx=hacks.get('decode_mlpdown_graphx',0)
#)
x, _ = self.down_proj(x, graphx=self.graphx)
else:
x, _ = self.down_proj(x)#,call_hipblaslt=hacks.get('prefill_mlpdown_call_hipblaslt',0),call_rocsolidx=hacks.get('prefill_mlpdown_call_rocsolidx',0))
x, _ = self.down_proj(x)
return x


@@ -115,6 +104,7 @@ def __init__(
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position_embeddings: int=8192
):
super().__init__()
self.hidden_size = hidden_size
@@ -129,6 +119,7 @@ def __init__(
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
if os.environ.get('VLLM_USE_HIPGRAPH'):
self.graphx = 1
else:
@@ -165,28 +156,17 @@ def forward(
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if input_metadata.num_generation_tokens > 0:
qkv, _ = self.qkv_proj(hidden_states)#,call_hipblaslt=hacks.get('decode_qkv_call_hipblaslt',0),
#splitm=hacks.get('decode_qkv_splitm',1),
#splitk=hacks.get('decode_qkv_splitk',1),
#splits=hacks.get('decode_qkv_splits',None),
#call_rocsolidx=hacks.get('decode_qkv_call_rocsolidx',0)
#)
qkv, _ = self.qkv_proj(hidden_states)
else:
qkv, _ = self.qkv_proj(hidden_states)#,call_hipblaslt=hacks.get('prefill_qkv_call_hipblaslt',0),call_rocsolidx=hacks.get('prefill_qkv_call_rocsolidx',0))
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
if input_metadata.num_generation_tokens > 0:
output, _ = self.o_proj(attn_output, graphx=self.graphx)#,call_hipblaslt=hacks.get('decode_oproj_call_hipblaslt',0),
#splitm=hacks.get('decode_oproj_splitm',1),
#splitk=hacks.get('decode_oproj_splitk',1),
#splits=hacks.get('decode_oproj_splits',None),
#call_rocsolidx=hacks.get('decode_oproj_call_rocsolidx',0),
#graphx=hacks.get('decode_oproj_graphx',0)
#)
output, _ = self.o_proj(attn_output, graphx=self.graphx)
else:
output, _ = self.o_proj(attn_output)#,call_hipblaslt=hacks.get('prefill_oproj_call_hipblaslt',0),call_rocsolidx=hacks.get('prefill_oproj_call_rocsolidx',0))
output, _ = self.o_proj(attn_output)
return output


@@ -204,6 +184,7 @@ def __init__(self, config: LlamaConfig):
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
max_position_embeddings=config.max_position_embeddings
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
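
Both run scripts now export VLLM_USE_HIPGRAPH as an empty string rather than 1. llama.py gates its graphx paths on os.environ.get('VLLM_USE_HIPGRAPH'), and an empty string is falsy in Python, so HIP-graph mode stays disabled without the variable being unset. A small standalone sketch of that gate:

    import os

    # Mirrors the check in the attention module's __init__ above:
    # an unset variable yields None and `export VLLM_USE_HIPGRAPH=` yields "",
    # both falsy, so graphx remains 0 in either case.
    graphx = 1 if os.environ.get("VLLM_USE_HIPGRAPH") else 0
    print(f"graphx={graphx}")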
