This is the official code for the paper "Booster: Tackling Harmful Fine-tuning for Large Language Models via Attenuating Harmful Perturbation" (ICLR2025).

git-disl/Booster

Booster: Tackling Harmful Fine-tuning for Large Language Models via Attenuating Harmful Perturbation

Booster is an alignment-stage safety alignment method. The idea is to strengthen the model's robustness during alignment by using both an alignment dataset and a harmful dataset.

The algorithm of Booster is as follows.
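
In symbols, and as a sketch read off the trainer code in the next section rather than a quotation from the paper, one training step forms the gradient below. Here $f$ denotes the alignment loss, $h$ the harmful loss, $\alpha$ corresponds to `self.args.alpha`, and $\lambda$ to `self.args.lamb`; the symbols themselves are chosen here for illustration.

```latex
\tilde{g}(w) = \nabla f(w)
  + \lambda \left( \nabla h(w)
  - \nabla h\!\left( w - \alpha \, \frac{\nabla h(w)}{\lVert \nabla h(w) \rVert} \right) \right)
```

The code adds a small constant (1e-7) to the norm in the denominator for numerical stability.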

Main code logic

We implement a customized trainer (BoosterAlignmentTrainer) on top of the original HuggingFace Trainer. To implement Booster, we add several forward/backward passes following the pseudo-algorithm. Specifically, in training_step(), we use the following logic:

# First backward pass: gradient of the harmful loss at the current weights
with self.compute_loss_context_manager():
    loss = self.compute_loss(model, harmful_inputs)
if self.use_apex:
    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss)
stored_grads = {name: param.grad.data.clone() for name, param in model.named_parameters() if param.requires_grad}
# Clear the gradients so the next backward pass does not accumulate onto them
model.zero_grad()

# Take a normalized step along the harmful gradient (the harmful perturbation)
with torch.no_grad():
    grad_norm = self._grad_norm(stored_grads) + 1e-7
    # perturb the weights
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data -= self.args.alpha * stored_grads[name] / grad_norm

# Second backward pass: harmful gradient at the perturbed weights
with self.compute_loss_context_manager():
    loss2 = self.compute_loss(model, harmful_inputs)
if self.use_apex:
    with amp.scale_loss(loss2, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss2)
perturb_grads = {name: param.grad.data.clone() for name, param in model.named_parameters() if param.requires_grad}
model.zero_grad()

# Restore the original weights before computing the alignment gradient
with torch.no_grad():
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data += self.args.alpha * stored_grads[name] / grad_norm

# Third backward pass: gradient of the alignment loss
with self.compute_loss_context_manager():
    loss3 = self.compute_loss(model, inputs)
if self.use_apex:
    with amp.scale_loss(loss3, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss3)
# Finally, combine the gradients: alignment + lamb * (harmful - harmful after perturbation)
for name, param in model.named_parameters():
    if param.requires_grad:
        param.grad.data = param.grad.data + self.args.lamb * stored_grads[name] - self.args.lamb * perturb_grads[name]
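
To make the three-pass structure concrete without a real model, here is a minimal sketch on a toy problem with analytic gradients. The quadratic losses behind `align_grad` and `harmful_grad`, the helper `grad_norm`, the function `booster_grad`, and the values of `alpha` and `lamb` are all illustrative assumptions and are not part of the repository; only the order of operations follows the snippet above.

```python
import math


def grad_norm(g):
    """Global L2 norm of a gradient stored as a list of floats."""
    return math.sqrt(sum(x * x for x in g))


def harmful_grad(w):
    # Hypothetical harmful loss h(w) = 0.5 * ||w - 1||^2, so grad h(w) = w - 1.
    return [x - 1.0 for x in w]


def align_grad(w):
    # Hypothetical alignment loss f(w) = 0.5 * ||w||^2, so grad f(w) = w.
    return [x for x in w]


def booster_grad(w, alpha, lamb):
    """Combine three gradients in the same order as the trainer snippet."""
    # Pass 1: harmful gradient at the current weights.
    stored = harmful_grad(w)
    norm = grad_norm(stored) + 1e-7
    # Normalized step along the harmful gradient (the harmful perturbation).
    w_pert = [x - alpha * g / norm for x, g in zip(w, stored)]
    # Pass 2: harmful gradient at the perturbed weights.
    perturbed = harmful_grad(w_pert)
    # Pass 3: alignment gradient at the original, unperturbed weights.
    align = align_grad(w)
    # Final gradient: align + lamb * (stored - perturbed).
    return [a + lamb * (s - p) for a, s, p in zip(align, stored, perturbed)]


w = [3.0, -1.0]
g = booster_grad(w, alpha=0.1, lamb=5.0)
```

In this sketch, setting `lamb=0.0` reduces the result to the plain alignment gradient, so `lamb` controls how strongly the harmful-loss term shapes the update.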

Package requirements

The required packages are listed in booster.yml and booster_pip.txt. Run the following commands to install them with Anaconda and pip.

conda env create -f booster.yml
pip install -r booster_pip.txt

Data preparation

For safety alignment, please download the safety alignment dataset from this link, and put the JSON file under the data directory.

For the finetuning tasks, we first need to run the following scripts to prepare the supervised finetuning data.

cd sst2
python build_dataset.py
cd ../gsm8k
python build_dataset.py
cd ../ag_news
python build_dataset.py
cd ..

Huggingface Llama2 access

Llama2-7B is a gated repo, so you need to submit a formal request to get access to the model. Check out https://huggingface.co/meta-llama/Llama-2-7b-hf. Once Meta grants permission, you should be able to access the model, but you first need to enter your token in the file huggingface_token.txt.
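
How the repository's scripts consume huggingface_token.txt is not shown here. As a minimal sketch, assuming the file holds the token on a single line, it could be read as follows; the function name `read_hf_token` is hypothetical.

```python
from pathlib import Path


def read_hf_token(path="huggingface_token.txt"):
    """Return the token stored in `path`, with surrounding whitespace removed."""
    token = Path(path).read_text(encoding="utf-8").strip()
    if not token:
        raise ValueError(f"{path} is empty; paste your Hugging Face access token into it")
    return token
```

The returned string could then be handed to, for example, `huggingface_hub.login(token=...)`. Stripping whitespace guards against a trailing newline left by a text editor.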

Example command to run

We provide scripts for reproducing all the experiments in the paper (check out the script directory). We recommend using Slurm to reproduce the results, as the log files will be automatically organized into the script directory (if you don't use Slurm, just replace sbatch with bash in our examples).

We first run SFT to produce the aligned model.

cd script/alignment
sbatch  smooth_align.sh

Then we finetune the model on SST2 with 1000 samples in total, 10% of which are harmful data.

cd ../finetune
sbatch  smooth_poison_ratio.sh 0.1
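
Read this way, a poison ratio of 0.1 with 1000 samples in total means 100 harmful samples mixed with 900 SST2 samples. The sketch below shows one way such a mixture could be built; `mix_samples`, its arguments, and the placeholder data are hypothetical and are not taken from the repository's data pipeline.

```python
import random


def mix_samples(benign, harmful, poison_ratio, total, seed=0):
    """Draw `total` samples, a `poison_ratio` fraction of them from `harmful`."""
    n_harmful = int(round(poison_ratio * total))
    n_benign = total - n_harmful
    rng = random.Random(seed)
    # Sample without replacement from each pool, then shuffle the mixture.
    mixed = rng.sample(harmful, n_harmful) + rng.sample(benign, n_benign)
    rng.shuffle(mixed)
    return mixed


# Placeholder records standing in for real SST2 and harmful examples.
benign = [("sst2", i) for i in range(5000)]
harmful = [("harmful", i) for i in range(500)]
mixed = mix_samples(benign, harmful, poison_ratio=0.1, total=1000)
```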
