Finetuning on a sequence classification task #34
Hi LeAnn, since Evo is designed for sequence generation tasks, you probably can't use it directly. Best,
@xliaoyi, actually, you can use the information before the unembed layer instead of the logits returned by the backbone. I used that embedding information and added two layers, a hidden layer (4096, 8192) and a classifier (8192, 2), to perform a binary classification task. It worked well, achieving a precision of 0.89, recall of 0.84, and AUC of 0.89 after one epoch of fine-tuning (I froze all parameters in the backbone except the last block, as you did). Update: if I use float32 in the hidden and classifier layers (bf16 before), it achieves a precision of 0.998 and recall of 0.824 after about 0.7 epochs of fine-tuning.
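For illustration, here is a minimal sketch of the head described above (a reconstruction, not the exact code; the full implementation appears later in this thread, and the 4096 hidden size is assumed to match the (4096, 8192) / (8192, 2) shapes mentioned):

```python
import torch
import torch.nn as nn

class EvoClassificationHead(nn.Module):
    """Two-layer head applied to the backbone's hidden states taken before the unembed layer."""
    def __init__(self, hidden_size: int = 4096, num_labels: int = 2):
        super().__init__()
        self.hidden = nn.Linear(hidden_size, hidden_size * 2)     # (4096, 8192)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)  # (8192, 2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size); pool a single position
        # (e.g. the last / [EOS] token) before classifying.
        pooled = hidden_states[:, -1, :]
        return self.classifier(torch.tanh(self.hidden(pooled)))
```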
Hi @kawabata-tomoko, thanks for sharing your code and results, but I am a little confused about your model structure. Do you mean you used the embedding weights from pre-trained Evo and added a hidden layer and classifier? Thanks,
Thank you for your suggestions @xliaoyi and @kawabata-tomoko. What classification task were you testing it on? I will take a look at your notebook and see if that method will work for my task. Thank you!
Firstly, I rewrote part of the pretrained model (the `StripedHyena` class), adding a branch so the forward pass can optionally skip the unembed projection:

```python
class StripedHyena(nn.Module):
    ...
    def forward(self, x, inference_params_dict=None, padding_mask=None):
        L = x.shape[1]
        x = self.embedding_layer.embed(x)
        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(
                x, padding_mask=padding_mask
            )
        x = self.norm(x)
        # added a judgment branch here: only project to the vocab space when requested
        if self.config.unembed == True:
            x = self.unembed.unembed(x)
        return x, inference_params_dict_out
    ...
```
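One detail worth noting (my reading of the change above, not stated explicitly): the `unembed` attribute has to exist on the config for this branch to work, so it would be set when preparing the config, for example:

```python
from transformers import AutoConfig

# Assumed setup for the flag checked in StripedHyena.forward above;
# CONFIG_DIR_PATH is the same placeholder used in the training script below.
configs = AutoConfig.from_pretrained(CONFIG_DIR_PATH, trust_remote_code=True)
configs.unembed = False  # return the normalized hidden states instead of vocab-space logits
```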
After this, I created the `SeqClsForEvo` class:

```python
class SeqClsForEvo(StripedHyenaPreTrainedModel):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
model_config = dotdict(config.to_dict())
self.backbone = StripedHyena(model_config)
self.backbone.gradient_checkpointing = False
self.config = config
vocab_size = config.vocab_size
if vocab_size % config.make_vocab_size_divisible_by != 0:
vocab_size += config.make_vocab_size_divisible_by - (
vocab_size % config.make_vocab_size_divisible_by
)
self.vocab_size = vocab_size
self.num_labels = config.num_labels
self.hidden = torch.nn.Linear(config.hidden_size,config.hidden_size*2,dtype=torch.float32)#.to(torch.bfloat16)
self.classifier = torch.nn.Linear(config.hidden_size*2,self.num_labels,dtype=torch.float32)#.to(torch.bfloat16)#load as bf16
self.ln_hidden = torch.nn.LayerNorm(config.hidden_size*2,dtype=torch.float32)
self.post_init()
self.force_dtype()
def force_dtype(self):
self.backbone.to_bfloat16_except_poles_residues()
def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
self.backbone.gradient_checkpointing = enable
def get_input_embeddings(self):
return self.backbone.embedding_layer
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
past_key_values=None,
return_dict: Optional[bool] = None,
eos_index : Optional[bool] = None
) -> Union[Tuple, SequenceClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
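        # Default eos_index: assume the [EOS] token sits at the last position of every sequence.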
eos_index = eos_index if eos_index is not None else torch.ones(input_ids.shape[0],1,dtype=int)*input_ids.shape[1]-1
logits, past_key_values = self.backbone(
input_ids,
padding_mask=attention_mask,
inference_params_dict=past_key_values if use_cache else None,
)
# feature=logits[:,-1,:] #use [EOS] Instead [CLS]
eos_index=eos_index.to(logits.device)#dynamic-adaption [eos] position for each sequence.
logits = logits.to(dtype=self.hidden.weight.dtype).gather(1, eos_index.unsqueeze(-1).expand(-1, -1, logits.size(-1)))
# feature.to(self.hidden.weight.dtype)
logits = self.classifier(self.ln_hidden(torch.tanh(self.hidden(logits))))
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()#ignoring label:-100
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1,self.num_labels), labels)
if return_dict:
return SequenceClassifierOutput(
loss = loss,
logits = logits,
hidden_states = None,
attentions = None
)
else:
return logits
@classmethod
def can_generate(cls) -> bool:
        return False
```

At this moment, you can load the model and start training like this (with a PEFT adapter added). By the way, you could use a Max/AvgPool layer instead of adding an [EOS] token to aggregate information over the whole sequence.

```python
import torch
from evo.srcs.Application.SeqClsForEvo import SeqClsForEvo  # make sure you can import this module correctly
from transformers import AutoConfig,AutoTokenizer
from transformers import DefaultDataCollator,Trainer,TrainingArguments
from datasets import load_dataset
from peft import LoraConfig
import wandb
import os
os.environ["WANDB_PROJECT"]="evo_SARS_cls"
torch.manual_seed(42)
configs=AutoConfig.from_pretrained(
f"{CONFIG_DIR_PATH}",
trust_remote_code=True,
use_cache=False,
num_labels=2
)
model = SeqClsForEvo.from_pretrained(
f"{MODLE_PATH}",
config=configs,
torch_dtype=torch.bfloat16
)
#######
# Adding lora adapter
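# target_modules="all-linear" adds LoRA to the model's linear layers, while
# modules_to_save keeps the new hidden/classifier head fully trainable and saved with the adapter.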
lora_config = LoraConfig(
r=8,
lora_alpha=32,
modules_to_save=["hidden","classifier"],
target_modules="all-linear",
lora_dropout=0.05,
bias="none",
)
model.add_adapter(lora_config,adapter_name="seq_cls")
#######
# model = model.to(dtype=torch.bfloat16)
tokenizer=AutoTokenizer.from_pretrained(
f"{TOKENIZER_DIR_PATH}",
trust_remote_code=True,
cls_token="@",
eos_token="&",
bos_token="^",
pad_token = 'N'
)
datacollator = DefaultDataCollator()
training_args=TrainingArguments(
output_dir=f"{OUTPUT_DIR}",
evaluation_strategy="steps",
eval_steps=100,
save_steps=50,
save_total_limit=20,
learning_rate= 5e-6,#EVO use 0.00009698,
lr_scheduler_type= "cosine",
warmup_ratio = 0.1,
weight_decay=0.05,#EVO use 0.1
num_train_epochs=5,#EVO use 10
gradient_accumulation_steps=4,#pretrained 8
per_device_train_batch_size=1,#pretrained 4
per_device_eval_batch_size=1,#pretrained 4
neftune_noise_alpha=10.0,
max_grad_norm=10,
bf16=True,
logging_steps =5,
report_to="wandb"
)
##INITIAL DATASETS
def pack(_tokenizer,max_length,padding="max_length",pad_to_multiple_of=None,return_tensors="pt"):
def padseq(line):
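        # Append the EOS token to each sequence, record its index
        # (used later to pool the EOS hidden state), then pad everything to max_length.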
inputs=_tokenizer(list(map(lambda x :x+_tokenizer.eos_token,line["seq"])))
eos_index=list(map(lambda x:[len(x)-1],inputs["input_ids"]))
input_ids_padded=_tokenizer.pad(
inputs,
padding=padding,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors
)
return dict(
input_ids=input_ids_padded["input_ids"],
attention_mask=input_ids_padded["attention_mask"],
label=line["label"],
eos_index=eos_index
)
return padseq
train_ds = load_dataset("json",data_files=f"{DATA_PATH}/{DATASET_FILE}")
func=pack(tokenizer,6001,padding="max_length")
train_ds_sp=train_ds.map(
func,batched=True,num_proc=4)["train"]
train_ds_sp=train_ds_sp.remove_columns("seq")
tempset=train_ds_sp.train_test_split(test_size=0.2,seed=42)
trainset=tempset["train"]
tempset=tempset["test"]
tempset=tempset.train_test_split(test_size=0.5,seed=42)
evalset=tempset["train"]
testset=tempset["test"]
trainset.save_to_disk(f"{DATA_PATH}/trainset", num_proc=os.cpu_count())
evalset.save_to_disk( f"{DATA_PATH}/evalset", num_proc=1)
testset.save_to_disk( f"{DATA_PATH}/testset", num_proc=1)
# You can load the datasets from existing files rather than re-building them when training.
# from datasets import load_dataset, load_from_disk
# trainset=load_from_disk("{DATA_PATH}/trainset")
# evalset=load_from_disk("{DATA_PATH}/evalset")
def p_count(m):
ttp=0
tp=0
for p in m.parameters():
c=p.numel()
if p.requires_grad == True:
ttp+=c
tp+=c
print(f"Total trainable parameters: {ttp}")
print(f"Total parameters: {tp}")
p_count(model)
print(model.hidden.weight)
import numpy as np
from sklearn.metrics import roc_auc_score
def compute_metrics(p):
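    # logits have shape (num_examples, 1, num_labels) because only the EOS position is kept,
    # so argmax over axis=2 followed by .T[0] yields one predicted label per example.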
logits, labels = p
pred=np.argmax(logits, axis=2).T[0]
TP=np.sum((pred==1)&(labels==1))
FP=np.sum((pred==1)&(labels==0))
FN=np.sum((pred==0)&(labels==1))
TN=np.sum((pred==0)&(labels==0))
precision=TP/(FP+TP)
recall=TP/(FN+TP)
roc=roc_auc_score(
labels,
logits[:,:,1].T[0]
)
return {"precision":precision,"recall":recall,"roc-auc":roc}
trainer= Trainer(
model=model,
args=training_args,
train_dataset= trainset,
eval_dataset= evalset,
data_collator=datacollator,
compute_metrics=compute_metrics
)
trainer.train()
```

With PEFT, I could train on 6k+ length sequences on a single A800 80GB device without any video-memory optimization techniques (like FSDP/ZeRO-n...). @leannmlindsey
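For completeness, a rough sketch of how inference might look with the classes above (my own untested assumption, mirroring the tokenize-then-pad pattern from the training script; it assumes the model and inputs are on the same device):

```python
import torch

model.eval()
seq = "ACGTACGT"  # placeholder sequence
enc = tokenizer([seq + tokenizer.eos_token])                 # append the [EOS] token, as during training
eos_index = torch.tensor([[len(enc["input_ids"][0]) - 1]])   # position of the appended [EOS] token
batch = tokenizer.pad(enc, return_tensors="pt")

with torch.no_grad():
    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        eos_index=eos_index,
        return_dict=True,
    )
probs = torch.softmax(out.logits.float(), dim=-1).squeeze()  # per-class probabilities
print(probs)
```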
Hi @kawabata-tomoko, thanks for your detailed explanation! If I understand correctly, you isolated the embedding layer from the Evo model and added fully connected layers right after it to perform the classification task. If so, I feel like a model trained this way loses the cross-attention feature. May I ask whether you've compared your model with other encoder models pre-trained on DNA data, for example DNABERT? Best,
I'm also trying to fine-tune the model on a sequence-level classification task by simply adding a classifier on top of the pre-trained model. Instead of manually adding a [CLS]-style token, I'm just using the mean embedding as the sequence-level representation, but I also only got an AUC of ~0.5, similar to what @xliaoyi got.
In fact, as I understand it, Evo is a decoder-only model: it is built from Hyena blocks, and Hyena uses causal convolutions to achieve the equivalent of the upper-triangular masking of the attention matrix in the transformer architecture (see Hyena Hierarchy, section 3.3, as shown in the figure below). This means the concept of "cross-attention" does not exist in the Evo model, because there is no encoder part. Therefore, it is entirely reasonable to use the model's hidden state directly for classification: in the original model, "unembed" is just a bias-free linear layer that maps the hidden state to multi-class scores over the vocab space, which is essentially the same kind of classification we are doing here.

If you are concerned about this, you can refer to a simple example I implemented based on traditional multi-head attention (LongNet). I compared a decoder-only model using [EOS] classification with an encoder-only model using [CLS] classification, with the same parameters. There is no significant difference between the two, and the decoder-only model even performs slightly better (as shown in the figure below).

By the way, even if Evo were not a decoder-only/encoder-only model but a seq2seq model like the original Transformer, cross-attention would only be an extra sublayer in the decoder's attention blocks that combines the keys and values produced by the encoder with the queries produced by the decoder; once the decoder layers are done, the cross-attention computation has finished. The unembed step I handle is just the decoding from the hidden space to the vocab space after all of that is complete. So I am not quite sure what you mean by "isolate".

I hope this helps! If you have any other questions or need further assistance, feel free to ask.
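To make the pooling discussion concrete, here is a small illustration (my own, not taken from either experiment) of the two options over the backbone's hidden states. In a causal model only the last/[EOS] position has seen the whole sequence, which is why it is a natural pooling choice:

```python
import torch

def eos_pool(hidden_states: torch.Tensor, eos_index: torch.Tensor) -> torch.Tensor:
    # hidden_states: (B, L, H); eos_index: (B, 1) long tensor with each sequence's [EOS] position.
    idx = eos_index.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))  # (B, 1, H)
    return hidden_states.gather(1, idx).squeeze(1)                        # (B, H)

def mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Mask-aware mean over the sequence dimension; attention_mask: (B, L) of 0/1.
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)           # (B, L, 1)
    return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
```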
I am interested in using this model on some classification tasks, but when I tried to set up the fine-tuning script using AutoModelForSequenceClassification, I got an error (see below). I assume this is because StripedHyena is a new type of model. Do you have any suggestions?
Thank you.
LeAnn
```
Traceback (most recent call last):
  File "/home/llindsey1/CHPC/evo_finetune.py", line 32, in <module>
    seq_classification_model = AutoModelForSequenceClassification.from_config(config)
  File "/home/llindsey1/.conda/envs/EVO/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 437, in from_config
    raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers_modules.togethercomputer.evo-1-131k-base.8eb9480ea22de5f86eeebc1199a76b63b42d7170.configuration_hyena.StripedHyenaConfig'> for this kind of AutoModel: AutoModelForSequenceClassification.
Model type should be one of AlbertConfig, BartConfig, BertConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BloomConfig, CamembertConfig, CanineConfig, LlamaConfig, ConvBertConfig, CTRLConfig, Data2VecTextConfig, DebertaConfig, DebertaV2Config, DistilBertConfig, ElectraConfig, ErnieConfig, ErnieMConfig, EsmConfig, FalconConfig, FlaubertConfig, FNetConfig, FunnelConfig, GemmaConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTJConfig, IBertConfig, LayoutLMConfig, LayoutLMv2Config, LayoutLMv3Config, LEDConfig, LiltConfig, LlamaConfig, LongformerConfig, LukeConfig, MarkupLMConfig, MBartConfig, MegaConfig, MegatronBertConfig, MistralConfig, MixtralConfig, MobileBertConfig, MPNetConfig, MptConfig, MraConfig, MT5Config, MvpConfig, NezhaConfig, NystromformerConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PerceiverConfig, PersimmonConfig, PhiConfig, PLBartConfig, QDQBertConfig, Qwen2Config, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, SqueezeBertConfig, StableLmConfig, T5Config, TapasConfig, TransfoXLConfig, UMT5Config, XLMConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, YosoConfig.
```
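One possible workaround (my own suggestion, not verified against the Evo remote code) is to register a custom classification wrapper, such as the SeqClsForEvo class shown earlier in this thread, with the auto class so that from_config recognizes the StripedHyena configuration:

```python
from transformers import AutoConfig, AutoModelForSequenceClassification
# SeqClsForEvo is assumed to be the custom wrapper defined earlier in this thread.

config = AutoConfig.from_pretrained(
    "togethercomputer/evo-1-131k-base", trust_remote_code=True, num_labels=2
)
# Tell the auto class which model to build for this config type.
AutoModelForSequenceClassification.register(type(config), SeqClsForEvo)
model = AutoModelForSequenceClassification.from_config(config)
```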