From 5375807f808d1bb2987b37d15ce64b6128453262 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Wed, 2 Oct 2024 16:21:42 -0500
Subject: [PATCH] In ShardedLlamaTest change seq_lens type to torch.int64

---
 sharktank/tests/models/llama/sharded_llama_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py
index bec2e9775..f7757a6a0 100644
--- a/sharktank/tests/models/llama/sharded_llama_test.py
+++ b/sharktank/tests/models/llama/sharded_llama_test.py
@@ -58,7 +58,7 @@ def setUp(self):
             vocab_size=self.vocabulary_size,
         )
         self.prefill_seq_lens = torch.tensor(
-            [14, 9, self.block_seq_stride - 1], dtype=torch.int32
+            [14, 9, self.block_seq_stride - 1], dtype=torch.int64
         )
 
     def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]: