
Finetuning on a sequenceClassification task #34

Open
leannmlindsey opened this issue Mar 8, 2024 · 8 comments
Comments

@leannmlindsey

I am interested in using this model on some classification tasks, but when I try to set up the fine-tuning script using AutoModelForSequenceClassification, I get an error (see below). I assume this is because StripedHyena is a new model type. Do you have any suggestions?

Thank you.
LeAnn

Traceback (most recent call last):
File "/home/llindsey1/CHPC/evo_finetune.py", line 32, in
seq_classification_model = AutoModelForSequenceClassification.from_config(config)
File "/home/llindsey1/.conda/envs/EVO/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 437, in from_config
raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers_modules.togethercomputer.evo-1-131k-base.8eb9480ea22de5f86eeebc1199a76b63b42d7170.configuration_hyena.StripedHyenaConfig'> for this kind of AutoModel: AutoModelForSequenceClassification.
Model type should be one of AlbertConfig, BartConfig, BertConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BloomConfig, CamembertConfig, CanineConfig, LlamaConfig, ConvBertConfig, CTRLConfig, Data2VecTextConfig, DebertaConfig, DebertaV2Config, DistilBertConfig, ElectraConfig, ErnieConfig, ErnieMConfig, EsmConfig, FalconConfig, FlaubertConfig, FNetConfig, FunnelConfig, GemmaConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTJConfig, IBertConfig, LayoutLMConfig, LayoutLMv2Config, LayoutLMv3Config, LEDConfig, LiltConfig, LlamaConfig, LongformerConfig, LukeConfig, MarkupLMConfig, MBartConfig, MegaConfig, MegatronBertConfig, MistralConfig, MixtralConfig, MobileBertConfig, MPNetConfig, MptConfig, MraConfig, MT5Config, MvpConfig, NezhaConfig, NystromformerConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PerceiverConfig, PersimmonConfig, PhiConfig, PLBartConfig, QDQBertConfig, Qwen2Config, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, SqueezeBertConfig, StableLmConfig, T5Config, TapasConfig, TransfoXLConfig, UMT5Config, XLMConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, YosoConfig.
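For reference, a minimal sketch of the call pattern that triggers this error; the model name and the config load are inferred from the traceback above rather than copied from a complete script:

from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained(
    "togethercomputer/evo-1-131k-base",
    trust_remote_code=True,
)

# Raises the ValueError above: StripedHyenaConfig is not registered with
# AutoModelForSequenceClassification, so no classification head can be built this way.
seq_classification_model = AutoModelForSequenceClassification.from_config(config)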

@xliaoyi

xliaoyi commented Sep 4, 2024

Hi LeAnn,

Since Evo is designed for sequence generation tasks, you probably can't use AutoModelForSequenceClassification to load the config. I tried adding a classification head to fine-tune Evo for binary classification, but the accuracy is only 50%. I've attached my code below. Let me know if you have any other suggestions.

Best,
Liaoyi

import transformers
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig, DataCollatorWithPadding
import os
from sklearn.metrics import accuracy_score
import numpy as np
from typing import Optional

from datasets import Dataset
from transformers import TrainingArguments, Trainer

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

model_name = 'togethercomputer/evo-1-8k-base'

model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model_config.use_cache = False

custom_cache_dir = "./models/evo/cache"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=model_config,
    cache_dir=custom_cache_dir,
    trust_remote_code=True,
    device_map={"":0},
    torch_dtype=torch.float16
)

# Add classification head
# (512 matches the width of outputs.logits[:, -1, :], which custom_loss below
#  feeds into this head)
num_labels = 2
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(512, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, num_labels)
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = "X"  # Evo has no pad token, so I picked an arbitrary character
tokenizer.padding_side = "left"  # classification uses the last token, so pad on the left

for p in model.parameters():
    p.requires_grad = False

# Unfreeze the last layer and the classification head
for p in model.backbone.blocks[-1].parameters():
    p.requires_grad = True
for p in model.classifier.parameters():
    p.requires_grad = True

# Load custom datasets
train_df = pd.read_csv("./data/train_data_DNABERT_2_200k/train.csv")
dev_df = pd.read_csv("./data/train_data_DNABERT_2_200k/dev.csv")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def preprocess_function(sample):
    inputs = tokenizer(sample['sequence'], padding='max_length', truncation=True, max_length=128)
    inputs['labels'] = sample['label']  # Assuming 'label' is the classification target
    return inputs


train_dataset = Dataset.from_pandas(train_df)
dev_dataset = Dataset.from_pandas(dev_df)

tokenized_train_ds = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=12,
)

tokenized_dev_ds = dev_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=12,
)

training_args = TrainingArguments(
    output_dir="./evo_results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=256,
    warmup_steps=1,
    # max_steps=100, # only a demo
    num_train_epochs=5,
    logging_steps=10,
    eval_steps=10,
    logging_strategy="steps",
    bf16=True,
    save_strategy="epoch",
    save_total_limit=3
)

# Custom loss: classify from the last token's output
def custom_loss(outputs, labels):
    # decoder-only model, so the final token's output serves as the sequence representation
    last_hidden_state = outputs.logits[:, -1, :]
    logits = model.classifier(last_hidden_state)
    loss = torch.nn.functional.cross_entropy(logits, labels)
    
    return loss

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs, output_hidden_states=True)
        loss = custom_loss(outputs, labels)
        return (loss, outputs) if return_outputs else loss

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        model = self.model.eval()

        all_preds = []
        all_labels = []

        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**{k: v.to(model.device) for k, v in batch.items() if k != "labels"})
                last_hidden_state = outputs.logits[:, -1, :]
                logits = model.classifier(last_hidden_state)
                preds = torch.argmax(logits, dim=-1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(batch["labels"].cpu().numpy())

        accuracy = accuracy_score(all_labels, all_preds)
        results = {f"{metric_key_prefix}_accuracy": accuracy}
        
        self.log(results)
        return results

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Override the save_model method to include safe_serialization=False
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.model.save_pretrained(output_dir, safe_serialization=False)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_dev_ds,
    data_collator=data_collator,
    compute_metrics=lambda p: {"accuracy": accuracy_score(p.label_ids, p.predictions.argmax(-1))},
)

trainer.train()

trainer.save_model("./evo_results/final_model")

@kawabata-tomoko

kawabata-tomoko commented Sep 12, 2024

@xliaoyi, actually, you can use the information from before the unembed layer instead of the logits returned by the backbone. I used that representation and added two layers, a hidden layer (4096, 8192) and a classifier (8192, 2), to perform a binary classification task. It worked well, achieving a precision of 0.89, recall of 0.84, and AUC of 0.89 after one epoch of fine-tuning. (I froze all parameters in the backbone except the last block, as you did.)
My solution (the repo is not complete yet): here


Update: if I use float32 for the hidden layer and the classifier layer (bf16 before), it achieves a precision of 0.998 and a recall of 0.824 after about 0.7 epochs of fine-tuning.
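For readers who want the shape of that head before the full class further down in the thread, here is a minimal sketch of what is described above; it assumes a hidden size of 4096 and is not the complete SeqClsForEvo implementation:

import torch

# Sketch of the head described above: the 4096-d hidden state taken before the
# unembed layer feeds a 4096 -> 8192 hidden layer and an 8192 -> 2 classifier,
# both kept in float32 while the backbone stays in bfloat16.
hidden_size, num_labels = 4096, 2
hidden = torch.nn.Linear(hidden_size, hidden_size * 2, dtype=torch.float32)
ln_hidden = torch.nn.LayerNorm(hidden_size * 2, dtype=torch.float32)
classifier = torch.nn.Linear(hidden_size * 2, num_labels, dtype=torch.float32)

def classify(last_hidden_state: torch.Tensor) -> torch.Tensor:
    # last_hidden_state: (batch, hidden_size), e.g. the hidden state at the
    # [EOS] position of each sequence, cast up from bfloat16 before the head.
    x = last_hidden_state.to(torch.float32)
    return classifier(ln_hidden(torch.tanh(hidden(x))))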

@xliaoyi

xliaoyi commented Sep 29, 2024

Hi @kawabata-tomoko, thanks for sharing your code and results, but I am a little confused about your model structure. Do you mean you used the embedding weights from the pre-trained Evo and added a hidden layer and classifier?

Thanks,
Liaoyi

@leannmlindsey
Author

Thank you for your suggestions @xliaoyi and @kawabata-tomoko. What classification task were you testing it on? I will take a look at your notebook and see if that method will work for my task. Thank you!

@kawabata-tomoko

Hi @kawabata-tomoko, thanks for sharing your code and result, but I am a little confused about your model structure. Do you mean you used the embedding weight from pre-trained evo and added a hidden layer and classifier?

Thanks, Liaoyi

First, I rewrote part of the pretrained model (the StripedHyena class) so that I can retrieve the sequence embedding correctly.
For example (you can also subclass it to achieve the same goal):

class StripedHyena(nn.Module):
    ...
    def forward(self, x, inference_params_dict=None, padding_mask=None):
        L = x.shape[1]
        x = self.embedding_layer.embed(x)
        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(
                x, padding_mask=padding_mask
            )

        x = self.norm(x)
        # added branch: only project back to the vocab space when the config asks for it;
        # otherwise return the hidden states so a classification head can consume them
        if self.config.unembed:
            x = self.unembed.unembed(x)
        return x, inference_params_dict_out
     ...

After this, I created the SeqClsForEvo class like this:

from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import SequenceClassifierOutput

# StripedHyena, StripedHyenaPreTrainedModel and dotdict come from the Evo
# remote-code files; adjust the import paths to wherever you keep that code.

class SeqClsForEvo(StripedHyenaPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        model_config = dotdict(config.to_dict())
        self.backbone = StripedHyena(model_config)
        self.backbone.gradient_checkpointing = False
        self.config = config
        vocab_size = config.vocab_size
        if vocab_size % config.make_vocab_size_divisible_by != 0:
            vocab_size += config.make_vocab_size_divisible_by - (
                vocab_size % config.make_vocab_size_divisible_by
            )

        self.vocab_size = vocab_size
        self.num_labels = config.num_labels
        self.hidden = torch.nn.Linear(config.hidden_size,config.hidden_size*2,dtype=torch.float32)#.to(torch.bfloat16)
        self.classifier = torch.nn.Linear(config.hidden_size*2,self.num_labels,dtype=torch.float32)#.to(torch.bfloat16)#load as bf16
        self.ln_hidden = torch.nn.LayerNorm(config.hidden_size*2,dtype=torch.float32)
        self.post_init()
        self.force_dtype()
        
        
    def force_dtype(self):
        self.backbone.to_bfloat16_except_poles_residues() 
        
    def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
        self.backbone.gradient_checkpointing = enable

    def get_input_embeddings(self):
        return self.backbone.embedding_layer
    
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values=None,
        return_dict: Optional[bool] = None,
        eos_index : Optional[bool] = None 
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        eos_index = eos_index if eos_index is not None else torch.ones(input_ids.shape[0],1,dtype=int)*input_ids.shape[1]-1

        logits, past_key_values = self.backbone(
            input_ids,
            padding_mask=attention_mask,
            inference_params_dict=past_key_values if use_cache else None,
        )
        # use the hidden state at each sequence's [EOS] position (instead of a [CLS] token);
        # eos_index adapts the position to each sequence's true length
        eos_index = eos_index.to(logits.device)
        logits = logits.to(dtype=self.hidden.weight.dtype).gather(
            1, eos_index.unsqueeze(-1).expand(-1, -1, logits.size(-1))
        )

        logits = self.classifier(self.ln_hidden(torch.tanh(self.hidden(logits))))
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()#ignoring label:-100

            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1,self.num_labels), labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss = loss,
                logits = logits,
                hidden_states = None,
                attentions = None
                )
        else:
            return logits

    @classmethod
    def can_generate(cls) -> bool:
        return False

At this point, you can load the model and start training like this (with a PEFT/LoRA adapter added). By the way, you could use a max or average pooling layer over the sequence instead of appending an [EOS] token to summarize it; see the pooling sketch after the training script below.

import torch
from evo.srcs.Application.SeqClsForEvo import SeqClsForEvo  # make sure this module is importable in your environment
from transformers import AutoConfig,AutoTokenizer
from transformers import DefaultDataCollator,Trainer,TrainingArguments
from datasets import load_dataset
from peft import LoraConfig
import wandb 
import os

os.environ["WANDB_PROJECT"]="evo_SARS_cls"
torch.manual_seed(42)
configs=AutoConfig.from_pretrained(
    f"{CONFIG_DIR_PATH}",
    trust_remote_code=True,
    use_cache=False,
    num_labels=2
    )

model = SeqClsForEvo.from_pretrained(
    f"{MODEL_PATH}",
    config=configs,
    torch_dtype=torch.bfloat16
)

#######
# Adding lora adapter
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    modules_to_save=["hidden","classifier"],
    target_modules="all-linear",
    lora_dropout=0.05,
    bias="none",
)
model.add_adapter(lora_config,adapter_name="seq_cls")
#######
# model = model.to(dtype=torch.bfloat16)

tokenizer=AutoTokenizer.from_pretrained(
    f"{TOKENIZER_DIR_PATH}",
    trust_remote_code=True,
    cls_token="@",
    eos_token="&",
    bos_token="^",
    pad_token = 'N'
    )

datacollator = DefaultDataCollator()

training_args=TrainingArguments(
    output_dir=f"{OUTPUT_DIR}",
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=50,
    save_total_limit=20,
    learning_rate= 5e-6,#EVO use 0.00009698,
    lr_scheduler_type= "cosine",
    warmup_ratio = 0.1,
    weight_decay=0.05,#EVO use 0.1
    num_train_epochs=5,#EVO use 10
    gradient_accumulation_steps=4,#pretrained 8
    per_device_train_batch_size=1,#pretrained 4
    per_device_eval_batch_size=1,#pretrained 4
    neftune_noise_alpha=10.0,
    max_grad_norm=10,
    bf16=True,
    logging_steps =5,
    report_to="wandb"
)


##INITIAL DATASETS
def pack(_tokenizer,max_length,padding="max_length",pad_to_multiple_of=None,return_tensors="pt"):
    def padseq(line):
        inputs=_tokenizer(list(map(lambda x :x+_tokenizer.eos_token,line["seq"])))
        eos_index=list(map(lambda x:[len(x)-1],inputs["input_ids"]))
        input_ids_padded=_tokenizer.pad(
                inputs,
                padding=padding,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors
                )
        return dict(
            input_ids=input_ids_padded["input_ids"],
            attention_mask=input_ids_padded["attention_mask"],
            label=line["label"],
            eos_index=eos_index
        )
    return padseq

train_ds = load_dataset("json",data_files=f"{DATA_PATH}/{DATASET_FILE}")

func=pack(tokenizer,6001,padding="max_length")
train_ds_sp=train_ds.map(
    func,batched=True,num_proc=4)["train"]
train_ds_sp=train_ds_sp.remove_columns("seq")
tempset=train_ds_sp.train_test_split(test_size=0.2,seed=42)
trainset=tempset["train"]
tempset=tempset["test"]
tempset=tempset.train_test_split(test_size=0.5,seed=42)
evalset=tempset["train"]
testset=tempset["test"]
trainset.save_to_disk(f"{DATA_PATH}/trainset", num_proc=os.cpu_count())
evalset.save_to_disk( f"{DATA_PATH}/evalset", num_proc=1)
testset.save_to_disk( f"{DATA_PATH}/testset", num_proc=1)

# You can load the datasets from the saved files rather than rebuilding them for every training run.
# from datasets import load_dataset, load_from_disk
# trainset=load_from_disk("{DATA_PATH}/trainset")
# evalset=load_from_disk("{DATA_PATH}/evalset")

def p_count(m):
    ttp=0
    tp=0
    for p in m.parameters():
        c=p.numel()
        if p.requires_grad == True:
            ttp+=c
        tp+=c
    print(f"Total trainable parameters: {ttp}")
    print(f"Total parameters: {tp}")


p_count(model)
print(model.hidden.weight)

import numpy as np
from sklearn.metrics import roc_auc_score
def compute_metrics(p):
    logits, labels = p
    pred=np.argmax(logits, axis=2).T[0]
    TP=np.sum((pred==1)&(labels==1))
    FP=np.sum((pred==1)&(labels==0))
    FN=np.sum((pred==0)&(labels==1))
    TN=np.sum((pred==0)&(labels==0))
    precision=TP/(FP+TP)
    recall=TP/(FN+TP)
    roc=roc_auc_score(
        labels,
        logits[:,:,1].T[0]
        )
    return {"precision":precision,"recall":recall,"roc-auc":roc}
trainer= Trainer(
    model=model,
    args=training_args,
    train_dataset= trainset,
    eval_dataset= evalset,
    data_collator=datacollator,
    compute_metrics=compute_metrics
)
trainer.train()

With PEFT, I could train on 6k+ length sequences on a single A800 80GB device without any VRAM-optimization techniques (like FSDP/ZeRO-n...).
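As mentioned before the training script, pooling over the whole sequence is an alternative to appending an [EOS] token. A minimal sketch of masked mean pooling, assuming the backbone returns hidden states of shape (batch, seq_len, hidden_size) and an attention_mask of shape (batch, seq_len):

import torch

def mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden_size) taken before the unembed layer.
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # ignore padded positions
    counts = mask.sum(dim=1).clamp(min=1)                        # avoid division by zero
    return summed / counts                                       # (batch, hidden_size)

# The pooled vector could then replace the [EOS] feature that SeqClsForEvo.forward()
# feeds into the hidden/ln_hidden/classifier layers.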

@leannmlindsey
The scores I mentioned before describe a 16S rRNA classification task (2 kbp sequences, phylum-level binary classification).
NOTICE: THIS TASK IS VERY EASY.

@xliaoyi

xliaoyi commented Oct 16, 2024

Hi @kawabata-tomoko,

Thanks for your detailed explanation!

If I understand correctly, you isolated the embedding layer from the Evo model and added fully connected layers right after the embedding layer to perform the classification task.

If so, I feel like a model trained in this way loses the cross-attention feature. May I ask if you've compared your model with other encoder models pre-trained on DNA data, for example DNABERT?

Best,
Liaoyi

@yaqisu

yaqisu commented Nov 15, 2024

I'm also trying to fine-tune the model on a sequence-level classification task by simply adding a classifier on top of the pre-trained model. Instead of manually adding a [CLS]-style token, I'm just using the mean embedding as the sequence-level representation, but I also only got an AUC of ~0.5, similar to what @xliaoyi got.

@kawabata-tomoko

Hi,

Thanks for your detailed explanation!

If I understand correctly, you isolated the embedding layer from the Evo model and added fully connected layers right after the embedding layer to perform the classification task.

If so, I feel like a model trained this way loses the cross-attention feature. May I ask if you've compared your model with other encoder models pre-trained on DNA data (for example DNABERT)?

Best, Liaoyi

In fact, as I understand it, Evo is a decoder-only model: it is built from Hyena blocks, and Hyena uses causal convolution to achieve the same effect as the causal (upper-triangular) masking of the attention matrix in the transformer architecture (see Hyena Hierarchy, section 3.3; the figure attached to the original comment illustrates this).

This means that the concept of "cross-attention" does not exist in the Evo model (there is no encoder part). It is therefore perfectly reasonable to use the model's hidden state directly for classification: in the original model, the unembed step is just a bias-free linear layer that maps the hidden state into the vocab space for next-token prediction, which is essentially the same kind of projection as the classification head we are adding here.
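To make the analogy concrete, here is a schematic sketch; the dimensions (hidden size 4096, a 512-token byte-level vocabulary) are assumptions about Evo's configuration rather than values checked against its config:

import torch

hidden_size, vocab_size, num_labels = 4096, 512, 2  # assumed Evo-like dimensions

# The unembed step of the language model: a bias-free linear map from the
# hidden space to the vocab space, producing next-token logits.
unembed = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# A classification head is the same kind of map, just into a label space.
cls_head = torch.nn.Linear(hidden_size, num_labels)

x = torch.randn(1, hidden_size)     # a hidden state taken *before* unembed
next_token_logits = unembed(x)      # shape (1, vocab_size)
class_logits = cls_head(x)          # shape (1, num_labels)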

If you are concerned about this, you can refer to a simple example I implemented based on traditional multi-head attention (LongNet). I compared a decoder-only model using [EOS] classification with an encoder-only model using [CLS] classification at the same parameter count: there is no significant difference between the two, and the decoder-only model even performs slightly better (the figure attached to the original comment shows the comparison).

By the way, even if Evo were not a decoder-only or encoder-only model but a seq2seq model like the original Transformer, cross-attention would only be an extra sublayer inside each decoder layer that combines the keys and values produced by the encoder with the queries produced by the decoder. Once the decoder layers are finished, the cross-attention computation is over. The unembed step I am talking about is just the projection from the hidden space to the vocab space after all of that is done, so I am not quite sure what you mean by "isolate".

I hope this helps! If you have any other questions or need further assistance, feel free to ask.
