Merge branch 'main' into res_add_rmsnorm
sanyalington authored Nov 16, 2023
2 parents 7362eca + 79ed694 commit cbc8d84
Showing 9 changed files with 247 additions and 18 deletions.
10 changes: 9 additions & 1 deletion benchmarks/benchmark_latency.py
@@ -26,6 +26,7 @@ def main(args: argparse.Namespace):
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
trust_remote_code=args.trust_remote_code,
use_cuda_graph=args.use_cuda_graph
)

sampling_params = SamplingParams(
@@ -68,7 +69,8 @@ def run_to_completion(profile: bool = False):
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile=False))
print(latencies)
print(f'Avg latency: {np.mean(latencies)} seconds')
if torch.distributed.get_rank() == 0:
print(f'Avg latency: {np.mean(latencies)} seconds')
del llm
ray.shutdown()

@@ -90,6 +92,12 @@ def run_to_completion(profile: bool = False):
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--use-cuda-graph',
action="store_true",
default=False,
help="Whether to use CUDA graph for inference",
)
args = parser.parse_args()
main(args)
#time.sleep(30) #add sleep for profiling
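
For reference, a minimal sketch of what the benchmark now constructs when the new flag is enabled (the model name and sizes are illustrative, not part of this commit):

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",        # illustrative choice
    max_num_seqs=8,                   # args.batch_size in the benchmark
    max_num_batched_tokens=8 * 32,    # args.batch_size * args.input_len
    use_cuda_graph=True,              # the flag added in this commit
)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.0, max_tokens=16))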
2 changes: 1 addition & 1 deletion docker/fa_arch.patch
@@ -6,7 +6,7 @@ index 7396633..2ca5d55 100644
[
"-O3",
"-std=c++20",
- "--offload-arch=gfx941",
- "--offload-arch=gfx942",
"-DNDEBUG",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
67 changes: 66 additions & 1 deletion vllm/config.py
@@ -30,7 +30,10 @@ class ModelConfig:
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
use_cuda_graph: Whether to capture and replay CUDA graphs.
seed: Random seed for reproducibility.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
"""

def __init__(
@@ -44,6 +47,8 @@ def __init__(
use_dummy_weights: bool,
dtype: str,
seed: int,
max_model_len: Optional[int] = None,
use_cuda_graph: Optional[bool] = False,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -53,10 +58,13 @@ def __init__(
self.use_np_weights = use_np_weights
self.use_dummy_weights = use_dummy_weights
self.seed = seed
self.use_cuda_graph = use_cuda_graph

self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode()
self.max_model_len = _get_and_verify_max_len(self.hf_config,
max_model_len)

def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
@@ -117,6 +125,8 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
return total_num_attention_heads // parallel_config.tensor_parallel_size

def get_max_model_len(self) -> int:
if self.max_model_len is not None:
return self.max_model_len
max_model_len = float("inf")
possible_keys = [
# OPT
@@ -214,7 +224,7 @@ def __init__(

self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
self.worker_use_ray = False
self._verify_args()

def _verify_args(self) -> None:
@@ -295,3 +305,58 @@ def _get_and_verify_dtype(
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
return torch_dtype

def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(hf_config, key, None)
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
return max_model_len

default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f"{possible_keys}. Assuming the model's maximum length is "
f"{default_max_len}.")
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
raise ValueError(
f"User-specified max_model_len ({max_model_len}) is greater than "
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
" in model's config.json). This may lead to incorrect model "
"outputs or CUDA errors. Make sure the value is correct and "
"within the model context size.")
return int(max_model_len)
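
A hedged worked example of the resolution above, using a duck-typed stand-in for the HuggingFace config (the helper is private, so the import path may change):

from types import SimpleNamespace
from vllm.config import _get_and_verify_max_len

# 4k base context combined with a 4x YaRN rope scaling.
cfg = SimpleNamespace(
    max_position_embeddings=4096,
    rope_scaling={"type": "yarn", "factor": 4.0,
                  "original_max_position_embeddings": 4096},
)

# yarn: start from original_max_position_embeddings (4096), then scale by 4.0.
print(_get_and_verify_max_len(cfg, None))   # 16384
print(_get_and_verify_max_len(cfg, 8192))   # 8192, a user value within the derived limit
# Requesting more than 16384 would raise the ValueError shown above.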
15 changes: 13 additions & 2 deletions vllm/engine/arg_utils.py
@@ -19,15 +19,17 @@ class EngineArgs:
use_dummy_weights: bool = False
dtype: str = 'auto'
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.9
gpu_memory_utilization: float = 0.6
max_num_batched_tokens: int = 2560
max_num_seqs: int = 256
disable_log_stats: bool = False
use_cuda_graph: Optional[bool] = False

def __post_init__(self):
if self.tokenizer is None:
@@ -84,6 +86,11 @@ def add_cli_args(
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
# Parallel arguments
parser.add_argument('--max-model-len',
type=int,
default=None,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
@@ -130,6 +137,10 @@ def add_cli_args(
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
parser.add_argument('--use-cuda-graph',
action="store_true",
default=False,
help="Whether to use CUDA graph for inference")
return parser

@classmethod
@@ -148,7 +159,7 @@ def create_engine_configs(
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.use_np_weights,
self.use_dummy_weights, self.dtype,
self.seed)
self.seed, self.max_model_len, self.use_cuda_graph)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space)
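
A brief sketch of the two new arguments flowing through EngineArgs into the engine configs (the model name is illustrative):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",   # illustrative
    max_model_len=4096,          # new: overrides the length derived from config.json
    use_cuda_graph=True,         # new: forwarded to ModelConfig
)
configs = engine_args.create_engine_configs()  # model/cache/parallel/scheduler configs in this version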
19 changes: 17 additions & 2 deletions vllm/engine/llm_engine.py
@@ -118,8 +118,8 @@ def _init_workers(self, distributed_init_method: str):
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel

assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
#assert self.parallel_config.world_size == 1, (
# "Ray is required if parallel_config.world_size > 1.")

self.workers: List[Worker] = []
worker = Worker(
@@ -223,6 +223,21 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
log_stats=not engine_args.disable_log_stats)
return engine

def warm_up_cuda_graph(self) -> None:
"""Compiles a CUDA graph with batch size of max sequence num."""
if not self.model_config.use_cuda_graph:
return

for i in range(self.scheduler_config.max_num_seqs):
self.add_request(str(i), "a",
SamplingParams(temperature=0.0, max_tokens=2))

self.step()
self.step()

for i in range(self.scheduler_config.max_num_seqs):
self.abort_request(str(i))

def add_request(
self,
request_id: str,
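
A hedged sketch of driving the warm_up_cuda_graph hook added above from user code (the model name is illustrative; the call is a no-op unless use_cuda_graph is enabled):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

engine_args = EngineArgs(model="facebook/opt-125m", use_cuda_graph=True)  # illustrative model
engine = LLMEngine.from_engine_args(engine_args)
# Enqueues max_num_seqs dummy requests, steps twice to capture/replay the
# CUDA graph, then aborts them, as implemented in the method above.
engine.warm_up_cuda_graph()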
2 changes: 2 additions & 0 deletions vllm/entrypoints/llm.py
@@ -49,6 +49,7 @@ def __init__(
tensor_parallel_size: int = 1,
dtype: str = "auto",
seed: int = 0,
use_cuda_graph: Optional[bool] = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
@@ -61,6 +62,7 @@ def __init__(
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
seed=seed,
use_cuda_graph=use_cuda_graph,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
7 changes: 5 additions & 2 deletions vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional

import torch
#from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ def __init__(
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
use_cuda_graph: Optional[bool] = False,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
@@ -37,6 +38,7 @@ def __init__(
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
@@ -67,4 +69,5 @@ def __repr__(self) -> str:
f'max_context_len={self.max_context_len}), '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}), '
f'slot_mapping={self.slot_mapping}')
f'slot_mapping={self.slot_mapping})'
f"use_cuda_graph={self.use_cuda_graph})")
(Diffs for the 2 remaining changed files were not loaded in this view.)
