diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index f93e58e61..797f31656 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -79,7 +79,7 @@ def main(): hp, tensor_parallelism_size=tensor_parallelism_size, use_hf=False, - kv_cache_type="direct" if args.bs == [1] else "paged", + kv_cache_type="paged", attention_kernel=args.attention_kernel, block_seq_stride=args.block_seq_stride, )