Skip to content

Commit

Permalink
In ShardedLlamaTest cange seq_lens type to torch.int64
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Oct 2, 2024
1 parent 0605ddf commit 5375807
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion sharktank/tests/models/llama/sharded_llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def setUp(self):
vocab_size=self.vocabulary_size,
)
self.prefill_seq_lens = torch.tensor(
[14, 9, self.block_seq_stride - 1], dtype=torch.int32
[14, 9, self.block_seq_stride - 1], dtype=torch.int64
)

def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
Expand Down

0 comments on commit 5375807

Please sign in to comment.