diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index ef3c4800d..656b4432b 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -269,6 +269,7 @@ def decode( for block_idx, block in enumerate(self.attn_blocks): if block_idx == 0: self.trace_tensor(f"llama.attn_block.{block_idx}.input", h) + block.attn.attention_kernel = "decomposed" h = block( h, start_positions=start_positions,