From 61350c55478fba29ecf40940a629a3e7ce008a05 Mon Sep 17 00:00:00 2001
From: Jinghan Yao
Date: Wed, 7 Aug 2024 16:11:57 -0700
Subject: [PATCH] pass batch_dim_idx to deepspeed sequence parallel distributed
 attention for supporting batch size larger than 1 (#433)

Co-authored-by: Jinghan Yao
---
 megatron/model/transformer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 7467190582..036c11566a 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -817,12 +817,14 @@ def forward(self, hidden_states, attention_mask,
             # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
 
         if self.enable_ds_sequence_parallel:
+            batch_dim_idx = 1
             if self.use_flash_attn:
                 if not self.use_flash_attn_triton:
                     query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                         for x in (query_layer, key_layer, value_layer)]
+                    batch_dim_idx = 0
 
-                context_layer = self.dist_attn(query_layer, key_layer, value_layer)
+                context_layer = self.dist_attn(query_layer, key_layer, value_layer, batch_dim_idx)
 
                 if not self.use_flash_attn_triton:
                     context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
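
The reason batch_dim_idx differs between the two code paths: DeepSpeed's sequence-parallel DistributedAttention exchanges sequence and head shards via all-to-all, so it needs to know which tensor dimension holds the batch when batch size is larger than 1. In the default Megatron layout Q/K/V are [seq, batch, heads, head_dim] (batch is dim 1); when non-Triton flash attention is used they are rearranged to [batch, seq, heads, head_dim] (batch is dim 0). Below is a minimal sketch of the caller-side logic this patch introduces; call_dist_attn and its arguments are illustrative names, and it assumes dist_attn accepts batch_dim_idx as a fourth positional argument, as the patched call does.

    from einops import rearrange

    def call_dist_attn(dist_attn, query, key, value,
                       use_flash_attn, use_flash_attn_triton):
        # Default Megatron layout is [seq, batch, heads, head_dim]: batch is dim 1.
        batch_dim_idx = 1
        if use_flash_attn and not use_flash_attn_triton:
            # Non-Triton flash attention expects [batch, seq, heads, head_dim];
            # after this rearrange the batch dimension is dim 0.
            query, key, value = [rearrange(x, 's b ... -> b s ...').contiguous()
                                 for x in (query, key, value)]
            batch_dim_idx = 0
        # Tell the sequence-parallel attention which dim is batch so its
        # all-to-all reshapes stay correct when batch size > 1.
        return dist_attn(query, key, value, batch_dim_idx)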