Skip to content

Commit

Permalink
Add falcon modeling file to the qeff
Browse files Browse the repository at this point in the history
Signed-off-by: vbaddi <[email protected]>
  • Loading branch information
vbaddi authored and vbaddi committed May 28, 2024
1 parent 369f453 commit 5229c66
Show file tree
Hide file tree
Showing 5 changed files with 693 additions and 13 deletions.
44 changes: 32 additions & 12 deletions QEfficient/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
CodeGenForCausalLM,
CodeGenModel,
)
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconDecoderLayer,
FalconForCausalLM,
FalconModel,
FalconRotaryEmbedding,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaAttention,
Expand All @@ -34,13 +41,13 @@
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralBLockSparseTop2MLP,
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MixtralDecoderLayer,
MixtralSparseMoeBlock,
MixtralBLockSparseTop2MLP,
MixtralRotaryEmbedding,
MixtralRMSNorm,
MixtralRotaryEmbedding,
MixtralSparseMoeBlock,
)
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel

Expand All @@ -66,6 +73,13 @@
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
from .models.falcon.modeling_falcon import (
QEffFalconAttention,
QEffFalconDecoderLayer,
QEffFalconForCausalLM,
QEffFalconModel,
QEffFalconRotaryEmbedding,
)
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
from .models.llama.modeling_llama import (
QEffLlamaAttention,
Expand All @@ -81,13 +95,13 @@
QEffMistralRotaryEmbedding,
)
from .models.mixtral_moe.modeling_mixtral import (
QEffMixtralModel,
QEffMixtralRotaryEmbedding,
QEffMixtralAttention,
QEffMixtralForCausalLM,
QEffMixtralBLockSparseTop2MLP,
QEffMixtralDecoderLayer,
QEffMixtralForCausalLM,
QEffMixtralModel,
QEffMixtralRotaryEmbedding,
QEffMixtralSparseMoeBlock,
QEffMixtralBLockSparseTop2MLP,
)
from .models.mpt.modeling_mpt import QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel

Expand All @@ -103,6 +117,7 @@
LlamaForCausalLM.__name__,
MistralForCausalLM.__name__,
MixtralForCausalLM.__name__,
FalconForCausalLM.__name__,
]
)

Expand Down Expand Up @@ -145,7 +160,13 @@
MixtralRotaryEmbedding: QEffMixtralRotaryEmbedding,
MixtralRMSNorm: CustomRMSNormAIC,
MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
MixtralBLockSparseTop2MLP:QEffMixtralBLockSparseTop2MLP,
MixtralBLockSparseTop2MLP: QEffMixtralBLockSparseTop2MLP,
# Falcon model layers
FalconAttention: QEffFalconAttention,
FalconDecoderLayer: QEffFalconDecoderLayer,
FalconForCausalLM: QEffFalconForCausalLM,
FalconModel: QEffFalconModel,
FalconRotaryEmbedding: QEffFalconRotaryEmbedding,
}


Expand Down Expand Up @@ -190,13 +211,12 @@ def transform(model: nn.Module, form_factor: str = "cloud") -> nn.Module:
Returns:
torch.nn.Module: PyTorch Module with replaced QEff layers.
"""

    # Introducing qeff_transformed attribute in model to check status of transform
if getattr(model, "qeff_transformed", False):
print("Model is already transformed")
return model


if form_factor == "cloud":
# Get Hash of all params for checking later
prior_params_hash = get_params_hash(model)
Expand Down Expand Up @@ -225,7 +245,7 @@ def transform(model: nn.Module, form_factor: str = "cloud") -> nn.Module:
transformers.modeling_attn_mask_utils._prepare_4d_attention_mask = _qeff_prepare_4d_attention_mask
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = _qeff_prepare_4d_causal_attention_mask

setattr(model,'qeff_transformed',True)
setattr(model, "qeff_transformed", True)
return model.eval()

elif form_factor == "edge":
Expand Down
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/falcon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

Loading

0 comments on commit 5229c66

Please sign in to comment.