From 0e93b6483d56927d7bf33ccc45244395235aae58 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Tue, 29 Oct 2024 18:21:31 -0500 Subject: [PATCH] temporary decompose for decode (#353) --- sharktank/sharktank/models/llama/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index ef3c4800d..656b4432b 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -269,6 +269,7 @@ def decode( for block_idx, block in enumerate(self.attn_blocks): if block_idx == 0: self.trace_tensor(f"llama.attn_block.{block_idx}.input", h) + block.attn.attention_kernel = "decomposed" h = block( h, start_positions=start_positions,