diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index bec2e9775..f7757a6a0 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -58,7 +58,7 @@ def setUp(self): vocab_size=self.vocabulary_size, ) self.prefill_seq_lens = torch.tensor( - [14, 9, self.block_seq_stride - 1], dtype=torch.int32 + [14, 9, self.block_seq_stride - 1], dtype=torch.int64 ) def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]: