
Commit

Minor update
Signed-off-by: amitraj <[email protected]>
quic-amitraj committed Jan 15, 2025
1 parent 20b4b85 commit 962c865
Showing 1 changed file with 6 additions and 3 deletions.
QEfficient/transformers/custom_attention.py: 9 changes (6 additions & 3 deletions)
@@ -68,12 +68,13 @@ def forward(
         context_layers = torch.zeros(bsz, self.num_attention_heads, tgt_len, self.attention_head_size)
         for iteration in range(num_blocks):
             attention_score_current = torch.matmul(
-                query_layer[:, :, iteration * block_size : (iteration + 1) * block_size, :], key_layer.transpose(2, 3)
+                query_layer[:, :, iteration * block_size : (iteration + 1) * block_size, :],
+                key_layer[:, :, iteration * block_size : (iteration + 1) * block_size, :].transpose(2, 3),
             ) / math.sqrt(self.attention_head_size)
             if attention_mask is not None:  # no matter the length, we just slice it
                 attention_score_current = (
                     attention_score_current
-                    + attention_mask[:, :, iteration * block_size : (iteration + 1) * block_size, :]
+                    + attention_mask[:, :, :, iteration * block_size : (iteration + 1) * block_size]
                 )
             # upcast attention to fp32
             attention_probs = nn.functional.softmax(attention_score_current, dim=-1, dtype=torch.float32).to(
@@ -83,7 +84,9 @@ def forward(
             # Mask heads if we want to
             if head_mask is not None:
                 attention_probs = attention_probs * head_mask
-            context_layer = torch.matmul(attention_probs, value_layer)
+            context_layer = torch.matmul(
+                attention_probs, value_layer[:, :, iteration * block_size : (iteration + 1) * block_size, :]
+            )
             context_layers[:, :, iteration * block_size : (iteration + 1) * block_size, :] = context_layer
 
         context_layers = context_layers.permute(0, 2, 1, 3).contiguous()
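For context, the change makes every step of the per-block loop operate on the current block only: the key, the additive attention mask (sliced along its last, key-position dimension), and the value are all restricted to the same block_size window as the query, so each query block attends only to its matching key/value block. Below is a minimal standalone sketch of the resulting loop. The function name, the block_size default, the assumption that attention_mask is an additive mask broadcastable to [bsz, 1, 1, tgt_len], and the omission of head_mask are illustrative simplifications, not part of the diff.

import math

import torch
import torch.nn as nn


def blocked_attention(query_layer, key_layer, value_layer, attention_mask=None, block_size=128):
    # query_layer / key_layer / value_layer: [bsz, num_heads, tgt_len, head_size]
    # attention_mask (assumed): additive mask broadcastable to [bsz, 1, 1, tgt_len]
    bsz, num_heads, tgt_len, head_size = query_layer.shape
    num_blocks = tgt_len // block_size  # assumes tgt_len is a multiple of block_size
    context_layers = torch.zeros(
        bsz, num_heads, tgt_len, head_size, dtype=query_layer.dtype, device=query_layer.device
    )

    for i in range(num_blocks):
        blk = slice(i * block_size, (i + 1) * block_size)
        # Each query block attends only to its matching key block (block-diagonal attention).
        scores = torch.matmul(query_layer[:, :, blk, :], key_layer[:, :, blk, :].transpose(2, 3))
        scores = scores / math.sqrt(head_size)
        if attention_mask is not None:
            # Slice the mask along its last (key) dimension so it lines up with the key block.
            scores = scores + attention_mask[:, :, :, blk]
        # Upcast to fp32 for the softmax, then cast back to the input dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
        context_layers[:, :, blk, :] = torch.matmul(probs, value_layer[:, :, blk, :])

    # [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
    return context_layers.permute(0, 2, 1, 3).contiguous()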
