diff --git a/examples/vlbart/README.md b/examples/vlbart/README.md new file mode 100644 index 0000000..3dee3aa --- /dev/null +++ b/examples/vlbart/README.md @@ -0,0 +1,3 @@ +# ReFT + VLBart Experiment + +Try out ReFT with vision-language! Go to `ReftDora/image_video_text_understanding/VL-T5/src/Reft_Injection.ipynb` for instructions. diff --git a/examples/vlbart/ReftDora/.gitattributes b/examples/vlbart/ReftDora/.gitattributes new file mode 100644 index 0000000..7fe70d7 --- /dev/null +++ b/examples/vlbart/ReftDora/.gitattributes @@ -0,0 +1 @@ +*.json filter=lfs diff=lfs merge=lfs -text diff --git a/examples/vlbart/ReftDora/.gitignore b/examples/vlbart/ReftDora/.gitignore new file mode 100644 index 0000000..6f51512 --- /dev/null +++ b/examples/vlbart/ReftDora/.gitignore @@ -0,0 +1,6 @@ +instruction_tuning/instruct/* +instruction_tuning/answers/* +instruction_tuning/peft/* +instruction_tuning/COPYRIGHT.txt +instruction_tuning/get_avg_score.py +instruction_tuning/Software Evaluation License.pdf \ No newline at end of file diff --git a/examples/vlbart/ReftDora/LICENSE b/examples/vlbart/ReftDora/LICENSE new file mode 100644 index 0000000..43c5b76 --- /dev/null +++ b/examples/vlbart/ReftDora/LICENSE @@ -0,0 +1,83 @@ +Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +NVIDIA Source Code License for DoRA + +======================================================================= + +1. Definitions + +“Licensor” means any person or entity that distributes its Work. + +“Work” means (a) the original work of authorship made available under +this license, which may include software, documentation, or other files, +and (b) any additions to or derivative works thereof that are made +available under this license. + +The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” +have the meaning as provided under U.S. copyright law; provided, however, +that for the purposes of this license, derivative works shall not include works +that remain separable from, or merely link (or bind by name) to the +interfaces of, the Work. + +Works are “made available” under this license by including in or with the Work +either (a) a copyright notice referencing the applicability of +this license to the Work, or (b) a copy of this license. + +2. License Grant + +2.1 Copyright Grant. Subject to the terms and conditions of this license, each +Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, +copyright license to use, reproduce, prepare derivative works of, publicly display, +publicly perform, sublicense and distribute its Work and any resulting derivative +works in any form. + +3. Limitations + +3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under +this license, (b) you include a complete copy of this license with your distribution, +and (c) you retain without modification any copyright, patent, trademark, or +attribution notices that are present in the Work. + +3.2 Derivative Works. You may specify that additional or different terms apply to the use, +reproduction, and distribution of your derivative works of the Work (“Your Terms”) only +if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative +works, and (b) you identify the specific derivative works that are subject to Your Terms. +Notwithstanding Your Terms, this license (including the redistribution requirements in +Section 3.1) will continue to apply to the Work itself. + +3.3 Use Limitation. The Work and any derivative works thereof only may be used or +intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation +and its affiliates may use the Work and any derivative works commercially. +As used herein, “non-commercially” means for research or evaluation purposes only. + +3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor +(including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that +you allege are infringed by any Work, then your rights under this license from +such Licensor (including the grant in Section 2.1) will terminate immediately. + +3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its +affiliates’ names, logos, or trademarks, except as necessary to reproduce +the notices described in this license. + +3.6 Termination. If you violate any term of this license, then your rights under +this license (including the grant in Section 2.1) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. +YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, +WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR +BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, +OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR +INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS +INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY +OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGES. + +======================================================================= \ No newline at end of file diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/.gitignore b/examples/vlbart/ReftDora/image_video_text_understanding/.gitignore new file mode 100644 index 0000000..8184185 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/.gitignore @@ -0,0 +1,168 @@ +# Initially taken from Github's Python gitignore file + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# tests and logs +tests/fixtures/* +!tests/fixtures/sample_text_no_unicode.txt +logs/ +lightning_logs/ +lang_code_data/ +**/slurm* +**/wandb +**/snap +datasets + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# vscode +.vs +.vscode + +# Pycharm +.idea + +# TF code +tensorflow_code + +# Models +proc_data + +# examples +runs +/runs_old +/wandb +/examples/runs +/examples/**/*.args +/examples/rag/sweep + +# data +/data +serialization_dir + +# emacs +*.*~ +debug.env + +# vim +.*.swp + +#ctags +tags + +# pre-commit +.pre-commit* + +# .lock +*.lock diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/LICENSE b/examples/vlbart/ReftDora/image_video_text_understanding/LICENSE new file mode 100644 index 0000000..5c21cd0 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 YI-LIN SUNG + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/README.md b/examples/vlbart/ReftDora/image_video_text_understanding/README.md new file mode 100644 index 0000000..21659ae --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/README.md @@ -0,0 +1,99 @@ +# Finetuning VL-BART on image/video-text understaing tasks using DoRA + +This directory includes the DoRA implementation and guidelines for reproducing the results in our paper. +We evaluate DoRA in a unified multi-task +setup on both image-text and video-text benchmarks following the settings of VL-Adapter. For the image-text tasks, we use four diverse V&L datasets: VQAv2, GQA, NLVR2, and MSCOCO image captioning. For video-text tasks, we use TVQA, How2QA, TVC, and YC2C. + +## Setup +``` +# Create python environment +conda create -n vlt5 python=3.8 +source activate vlt5 + +# Install python dependencies +pip install -r requirements.txt + +# Download T5/BART backbone checkpoint +python download_backbones.py + +# For MSCOCO captioning evaluation (optional; for captioning only) +python -c "import language_evaluation; language_evaluation.download('coco')" +``` + +## Data +```bash +# Store images, features, and annotations +./datasets + COCO/ + images/ + clip_featuers/ + VG/ + images/ + clip_features/ + GQA/ + images/ + clip_features/ + nlvr/ + images/ + clip_features/ + vqa/ + lxmert/ + + video/ + ann/ + vis_features + +# Train VL-T5 with adapters +./VL-T5/ + src/ + multitask.py <= multitask learning on 7 downstream tasks + trainer_base.py <= DoRA implementation +``` + +### Image-text dataset +Please go to [link](https://drive.google.com/file/d/1O_RU1iFh_sbItZCTkOHUrbVIQQ_89Djj/view?usp=sharing) to download the processed CLIP features. We suggest to use [gdrive](https://github.com/prasmussen/gdrive) to download it. Unzip the downloaded file and arrange the folders according to the format demonstrated above. + +If you would like to use dgrive to download the data, please try the following command + +``` +gdrive download 1O_RU1iFh_sbItZCTkOHUrbVIQQ_89Djj +``` + +### Extract your own CLIP features (Not necessary) +Please refer to `feature_extraction` for more details. + +### Video-text dataset +Please go to [VALUE](https://github.com/VALUE-Leaderboard/DataRelease) to download the ViT processed data. + +## Finetuning and Evaluation +### Finetuning VL-BART on Image-text datasets with DoRA (Evaluation included) +``` +bash ./VL-T5/scripts/image/dora.sh 1 +``` +### Finetuning VL-BART on Video-text datasets with DoRA +``` +bash ./VL-T5/scripts/video/dora.sh 1 +``` +### Evaluation of video-text tasks +Submit the generated test submission file strictly following the submission format (including directory layout and file names) specified [here](https://github.com/VALUE-Leaderboard/EvaluationTools) to the [Value benchmark website](https://value-benchmark.github.io/#:~:text=What%20is%20VALUE%3F,understanding%20both%20video%20and%20subtitles.) for evaluation. + +## DoRA Result + +### The multi-task evaluation results on VQA, GQA, NVLR2 and COCO Caption with the VL-BART backbone +| Method | # Params (%) | VQAv2 | GQA | NVLR2 | COCO Cap | Avg | +|-----------------------|---------|--------|--------|-------------|--------------|---------| +| FT | 100 | 66.9 |56.7 | 73.7 |112.0| 77.3| +| LoRA | 5.93 |65.2 |53.6| 71.9| 115.3| 76.5| +| DoRA | 5.96 | 65.8 |54.7 |73.1 |115.9 | **77.4** | + + +### The multi-task evaluation results on TVQA, How2QA, TVC, and YC2C with the VL-BART backbone. +| Method | # Params (%) | TVQA | How2QA| TVC| YC2C | Avg | +|-----------------------|---------|--------|--------|-------------|--------------|---------| +| FT | 100 | 76.3 | 73.9| 45.7| 154.0 | 87.5| +| LoRA | 5.17 | 75.5 | 72.9 | 44.6 | 140.9 | 83.5| +| DoRA | 5.19 | 76.3 | 74.1 | 45.8 | 145.4 | **85.4** | + + +## Acknowledgement +We greatly appreciate the contributions of [VL-Adapter](https://github.com/ylsung/VL_adapter) which has significantly benefited our work. \ No newline at end of file diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/scripts/image/dora.sh b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/scripts/image/dora.sh new file mode 100644 index 0000000..b8459c1 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/scripts/image/dora.sh @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +export CUDA_VISIBLE_DEVICES=0 +task=multitask + +# or bart +model="bart" + +echo $model + +if [ $model == "t5" ] +then + folder_prefix="VLT5" + backbone="t5-base" + batch_size=400 +elif [ $model == "bart" ] +then + folder_prefix="VLBart" + backbone="facebook/bart-base" + batch_size=300 +fi + +echo $folder_prefix +echo $backbone + +feature=RN101 + +lr=1e-3 + +lora_dim=128 + +project_name=${feature}_LMsingle_dora_${lora_dim}_bs${batch_size}_image224_lora_settings +run_name=tune+lr${lr}_plzplz2 +output=snap/${folder_prefix}_${task}/$run_name + +TOKENIZERS_PARALLELISM=True PYTHONPATH=$PYTHONPATH:./src \ +python -m torch.distributed.launch \ + --nproc_per_node=$1 \ + --master_port=26464 \ + src/${task}.py \ + --distributed --multiGPU \ + --optim adamw \ + --warmup_ratio 0.1 \ + --clip_grad_norm 5 \ + --lr ${lr} \ + --epochs 20 \ + --num_workers 4 \ + --backbone ${backbone} \ + --output $output ${@:2} \ + --num_beams 5 \ + --use_tasks_prompts \ + --batch_size ${batch_size} \ + --valid_batch_size ${batch_size} \ + --use_dora \ + --unfreeze_bias \ + --unfreeze_layer_norms \ + --lora_settings \ + --lora_dim ${lora_dim} \ + --tasks "vqa" \ + --feature ${feature} --n_boxes 36 --downsample \ + --image_size "(224,224)" \ + --project_name $project_name \ + --run_name $run_name diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/scripts/video/dora.sh b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/scripts/video/dora.sh new file mode 100644 index 0000000..037c9f6 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/scripts/video/dora.sh @@ -0,0 +1,100 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +task=multitask_video + +# or bart +model="bart" + +echo $model + +if [ $model == "t5" ] +then + folder_prefix="VLT5" + backbone="t5-base" + batch_size=300 +elif [ $model == "bart" ] +then + folder_prefix="VLBart" + backbone="facebook/bart-base" + batch_size=50 +fi + +echo $folder_prefix +echo $backbone + +feature=ViT + +batch_size=40 +#50 +lr=2.4e-4 +#3e-4 + +lora_dim=128 + +project_name=${feature}_LMsingle_dora${lora_dim}_bs${batch_size}_image224_video +run_name=dora_lora_setting_${lr}_${lora_dim} +output=snap/${folder_prefix}_${task}/$run_name + +TOKENIZERS_PARALLELISM=True PYTHONPATH=$PYTHONPATH:./src \ +python -m torch.distributed.launch \ + --nproc_per_node=$1 \ + --master_port=26465 \ + src/${task}.py \ + --distributed --multiGPU \ + --optim adamw \ + --warmup_ratio 0.1 \ + --clip_grad_norm 5 \ + --lr ${lr} \ + --epochs 7 \ + --num_workers 4 \ + --backbone ${backbone} \ + --output $output ${@:2} \ + --num_beams 5 \ + --use_dora \ + --unfreeze_bias \ + --lora_settings \ + --lora_dim ${lora_dim} \ + --batch_size ${batch_size} \ + --valid_batch_size ${batch_size} \ + --use_tasks_prompts \ + --tasks "tvqa,how2qa,tvc,yc2c" \ + --feature ${feature} --n_boxes 64 --downsample \ + --image_size "(224,224)" \ + --project_name $project_name \ + --run_name $run_name + +## this is for generating the output for submitting to https://value-benchmark.github.io/#:~:text=What%20is%20VALUE%3F,understanding%20both%20video%20and%20subtitles. +python -m torch.distributed.launch \ + --nproc_per_node=$1 \ + --master_port=26465 \ + src/${task}.py \ + --distributed --multiGPU \ + --optim adamw \ + --warmup_ratio 0.1 \ + --clip_grad_norm 5 \ + --lr ${lr} \ + --epochs 0 \ + --num_workers 4 \ + --backbone ${backbone} \ + --output $output ${@:2} \ + --num_beams 5 \ + --use_dora \ + --load snap/${folder_prefix}_${task}/$run_name/LAST.pth \ + --unfreeze_bias \ + --lora_settings \ + --lora_dim ${lora_dim} \ + --batch_size ${batch_size} \ + --valid_batch_size ${batch_size} \ + --use_tasks_prompts \ + --tasks "tvqa,how2qa,tvc,yc2c" \ + --feature ${feature} --n_boxes 64 --downsample \ + --image_size "(224,224)" \ + --project_name $project_name \ + --run_name $run_name + diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/Reft_Injection.ipynb b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/Reft_Injection.ipynb new file mode 100644 index 0000000..957b4a5 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/Reft_Injection.ipynb @@ -0,0 +1,2837 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cfd53f5e-8059-4338-bbb4-88abc6202208", + "metadata": {}, + "source": [ + "## Reft + Vision (VLBart) experiments\n", + "\n", + "Does ReFT works with Vision-language? Let's find out with the VQA Task." + ] + }, + { + "cell_type": "markdown", + "id": "3002ccd0-1095-4a01-8395-70f059d425c6", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9bb906f7-b8ba-42ee-a114-bc1dee0378ea", + "metadata": {}, + "outputs": [], + "source": [ + "from trainer_base import TrainerBase\n", + "import torch.backends.cudnn as cudnn\n", + "import torch.multiprocessing as mp\n", + "import os\n", + "import collections\n", + "from pathlib import Path\n", + "from packaging import version\n", + "\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "import torch\n", + "import torch.nn as nn\n", + "import shutil\n", + "from copy import deepcopy\n", + "\n", + "from param import parse_args\n", + "\n", + "from utils import LossMeter\n", + "from dist_utils import reduce_dict\n", + "import wandb\n", + "\n", + "proj_dir = os.path.dirname(os.path.dirname(os.getcwd()))\n", + "\n", + "\n", + "_use_native_amp = False\n", + "_use_apex = False\n", + "\n", + "# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex\n", + "if version.parse(torch.__version__) < version.parse(\"1.6\"):\n", + " from transormers.file_utils import is_apex_available\n", + " if is_apex_available():\n", + " from apex import amp\n", + " _use_apex = True\n", + "else:\n", + " _use_native_amp = True\n", + " from torch.cuda.amp import autocast" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "27b0012c-e0e5-47b1-8f0f-1d0c7debf32d", + "metadata": {}, + "outputs": [], + "source": [ + "import transformers\n", + "\n", + "from transformers.models.bart.modeling_bart import (\n", + " BartConfig,\n", + " ACT2FN,\n", + " shift_tokens_right, _make_causal_mask, _expand_mask\n", + ")\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.nn import CrossEntropyLoss\n", + "\n", + "from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple\n", + "import copy\n", + "\n", + "from transformers.modeling_outputs import Seq2SeqLMOutput" + ] + }, + { + "cell_type": "markdown", + "id": "c57a9fd4-0059-422a-8d8d-511cf821a8fd", + "metadata": {}, + "source": [ + "### 0. Prerequisites\n", + "- Install Pyvene (peterwz-llava branch) and Pyreft with Python 3.8. Main branch of Pyvene does not work, because to support transformers 4.39, it fails to support the Bart Models used by this notebook.\n", + "- Download all the training datasets using the following command:\n", + " ```\n", + " gdrive download 1O_RU1iFh_sbItZCTkOHUrbVIQQ_89Djj\n", + " ```\n", + " Here is the [google drive link](https://drive.google.com/file/d/1O_RU1iFh_sbItZCTkOHUrbVIQQ_89Djj/view), we suggest you use [gdrive](https://github.com/prasmussen/gdrive) to download it. Put the `datasets` folder directly under `image_video_text_understanding`.\n", + "- Download the image annotations as well by running `wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip`. Unzip and put the files in `image_video_text_understanding/datasets/COCO/images/`." + ] + }, + { + "cell_type": "markdown", + "id": "0089a652-7948-4cea-94ff-6e74b974ac2f", + "metadata": {}, + "source": [ + "### 1. Reft Data (intervention locations)\n", + "\n", + "Reft needs to populate `intervention_locations` together with the text input IDs and image embeddings. Here we calcluate the intervention locations.\n", + "\n", + "The function `get_image_intervention_locations()` creates separate interventions for texts and image inputs. As a reminder, Vision-language Bart concatenates these location tokens together. However, image tokens are not visible to the VLBart model at the time of input (only text `input_ids` are visible). So the image interventions start at the end of `input_ids` and last for `args.n_boxes` long (image features are of length `args.n_boxes`). The number of image and text interventions match the number of interventions defined in the `ReftConfig` in section 1.3.1. We intervene on all encoder layers, and Bart encoder has 6 layers, so there will be 6 text and 6 image interventions if we share weights between the first and the last tokens. If we do not share weights, there will be 12 text and 12 image interventions." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3e99be92-5b36-451f-947f-a4b6f67b184f", + "metadata": {}, + "outputs": [], + "source": [ + "IGNORE_INDEX = -100\n", + "from transformers import DataCollatorForSeq2Seq\n", + "import torch\n", + "\n", + "def parse_positions(positions: str):\n", + " # parse position\n", + " first_n, last_n = 0, 0\n", + " if \"+\" in positions:\n", + " first_n = int(positions.split(\"+\")[0].strip(\"f\"))\n", + " last_n = int(positions.split(\"+\")[1].strip(\"l\"))\n", + " else:\n", + " if \"f\" in positions:\n", + " first_n = int(positions.strip(\"f\"))\n", + " elif \"l\" in positions:\n", + " last_n = int(positions.strip(\"l\"))\n", + " return first_n, last_n\n", + "\n", + "def get_image_intervention_locations(**kwargs):\n", + " \"\"\"\n", + " This function generates separate intervention locations for images and texts.\n", + " \"\"\"\n", + " # parse kwargs\n", + " share_weights = kwargs[\"share_weights\"] if \"share_weights\" in kwargs else False\n", + " last_text_position = kwargs[\"last_position\"]\n", + " assert \"image_positions\" in kwargs, \"Image positions must be provided\"\n", + " assert \"positions\" in kwargs, \"Text positions must be provided\"\n", + " first_n, last_n = parse_positions(kwargs[\"positions\"])\n", + " first_image_n, last_image_n = parse_positions(kwargs[\"image_positions\"])\n", + "\n", + " num_interventions = kwargs[\"num_interventions\"]\n", + " # `last_offset` is the length of the images (n_boxes).\n", + " # Image tokens are concatenated to the end of the text tokens, i.e. after `last_position` tokens.\n", + " # The true last position of the input is `last_position + last_offset`\n", + " image_offset = kwargs[\"last_offset\"] if \"last_offset\" in kwargs else 0\n", + "\n", + " pad_mode = kwargs[\"pad_mode\"] if \"pad_mode\" in kwargs else \"first\"\n", + " pad_position = -1 if pad_mode == \"first\" else last_text_position + image_offset\n", + " if pad_mode != \"first\" and \"nlvr\" in kwargs[\"tasks\"]:\n", + " pad_position = last_text_position + 2 * image_offset\n", + "\n", + " if share_weights or ((first_n == 0 or last_n == 0) and (first_image_n == 0 or last_image_n == 0)):\n", + " position_list = [i for i in range(first_n)] + \\\n", + " [i for i in range(last_text_position - last_n, last_text_position)]\n", + " image_position_list = [i for i in range(last_text_position, last_text_position + first_image_n)] + \\\n", + " [i for i in range(last_text_position + image_offset - last_image_n, last_text_position + image_offset)]\n", + " # There are 2 images in nlvr, so performing special treatment\n", + " # For this notebook however, we only use vqa\n", + " if \"nlvr\" in kwargs[\"tasks\"]:\n", + " image_position_list += [i for i in range(last_text_position + image_offset, last_text_position + image_offset + first_image_n)] + \\\n", + " [i for i in range(last_text_position + 2 * image_offset - last_image_n, last_text_position + 2 * image_offset)]\n", + " text_len = len(position_list)\n", + " image_len = len(image_position_list)\n", + " if text_len > image_len:\n", + " image_position_list += [pad_position for _ in range(text_len-image_len)]\n", + " else:\n", + " position_list += [pad_position for _ in range(image_len-text_len)]\n", + " intervention_locations = [position_list]*(num_interventions//2) + \\\n", + " [image_position_list]*(num_interventions//2)\n", + " else:\n", + " assert first_n == last_n, \"For now, we only support same first and last positions\"\n", + " left_intervention_locations = [i for i in range(first_n)]\n", + " right_intervention_locations = [i for i in range(last_text_position - last_n, last_text_position)]\n", + " left_image_intervention_locations = [i for i in range(last_text_position, last_text_position + first_image_n)]\n", + " right_image_intervention_locations = [i for i in range(last_text_position + image_offset - last_image_n, last_text_position + image_offset)]\n", + " if \"nlvr\" in kwargs[\"tasks\"]:\n", + " left_image_intervention_locations += [i for i in range(last_text_position + image_offset, last_text_position + image_offset + first_image_n)]\n", + " right_image_intervention_locations += [i for i in range(last_text_position + 2 * image_offset - last_image_n, last_text_position + 2 * image_offset)]\n", + " text_len = len(left_intervention_locations)\n", + " image_len = len(left_image_intervention_locations)\n", + " if text_len > image_len:\n", + " left_image_intervention_locations += [pad_position for _ in range(text_len-image_len)]\n", + " right_image_intervention_locations += [pad_position for _ in range(text_len-image_len)]\n", + " else:\n", + " left_intervention_locations += [pad_position for _ in range(image_len-text_len)]\n", + " right_intervention_locations += [pad_position for _ in range(image_len-text_len)]\n", + "\n", + " intervention_locations = [left_intervention_locations]*(num_interventions//4) + \\\n", + " [right_intervention_locations]*(num_interventions//4) + \\\n", + " [left_image_intervention_locations]*(num_interventions//4) + \\\n", + " [right_image_intervention_locations]*(num_interventions//4)\n", + " return intervention_locations\n" + ] + }, + { + "cell_type": "markdown", + "id": "f1626ef3-000d-48d9-b9f6-c7f5958fda7e", + "metadata": {}, + "source": [ + "Here we also process intervention padding. To collate multiple interventions of different lengths, we create padding interventions that only intervene on padded locations. So these interventions do not impact Reft output. We only use `pad_mode = first`, so you can see that (1) a `1` is prepended to the input IDs at position 0 (2) padding interventions intervene on the position 0." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "94da9380-74b5-4eb4-af3a-67bcd86e4464", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_intervention(\n", + " id: int, \n", + " result: dict, \n", + " tokenizer,\n", + " fields_to_pad = [],\n", + " fields_to_mask = [],\n", + " **kwargs):\n", + " pad_mode = kwargs[\"pad_mode\"]\n", + " # compute intervention locs\n", + " assert \"positions\" in kwargs and \"image_positions\" in kwargs\n", + " intervention_locations = get_image_intervention_locations(**kwargs)\n", + " result[\"intervention_locations\"] = intervention_locations\n", + " result[\"id\"] = id\n", + "\n", + " # add a single padding token BEFORE input_ids and fix everything\n", + " if fields_to_pad is not None:\n", + " if pad_mode == \"first\":\n", + " for field in fields_to_pad:\n", + " if field not in result:\n", + " continue\n", + " if field == \"labels\":\n", + " result[field] = torch.cat((torch.tensor([IGNORE_INDEX,]), result[field]))\n", + " else:\n", + " result[field] = torch.cat((torch.tensor([tokenizer.pad_token_id,]), result[field]))\n", + " result[\"intervention_locations\"] = (torch.IntTensor(result[\"intervention_locations\"]) + 1).tolist()\n", + " result[\"input_length\"] += 1\n", + " elif pad_mode == \"last\":\n", + " for field in fields_to_pad:\n", + " if field not in result:\n", + " continue\n", + " if field == \"labels\":\n", + " result[field] = torch.cat((result[field], torch.tensor([IGNORE_INDEX,])))\n", + " else:\n", + " result[field] = torch.cat((result[field], torch.tensor([tokenizer.pad_token_id,])))\n", + " result[\"input_length\"] += 1\n", + " \n", + " # attention masks\n", + " if len(fields_to_mask) == 1:\n", + " result[\"attention_mask\"] = (result[fields_to_mask[0]] != tokenizer.pad_token_id).int()\n", + " else:\n", + " for field in fields_to_mask:\n", + " result[f\"{field}_mask\"] = (result[field] != tokenizer.pad_token_id).int()\n", + "\n", + " # does not handle subspaces for now\n", + " # print(\"Intervention Locations\", result[\"intervention_locations\"])\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "18de0d1b-7bac-4cb1-99ad-3ee470158d84", + "metadata": {}, + "outputs": [], + "source": [ + "def reft_post_process(\n", + " out_dict,\n", + " tokenizer,\n", + " idx: int, \n", + " last_position: int, \n", + " args = None,\n", + " pad_mode = \"none\",\n", + " fields_to_pad = [],\n", + " fields_to_mask = []\n", + "):\n", + " # print(\"Out_dict keys:\", out_dict.keys())\n", + " out_dict[\"instruction\"] = tokenizer.decode(\n", + " out_dict[\"input_ids\"], skip_special_tokens=True)\n", + " kwargs = {}\n", + " if args is not None:\n", + " if args.reft_rank != -1:\n", + " kwargs[\"positions\"] = args.positions\n", + " if args.reft_image_rank != -1:\n", + " kwargs[\"image_positions\"] = args.image_positions\n", + " kwargs[\"share_weights\"] = args.share_weights\n", + " layers = [int(l) for l in args.layers.split(\";\")]\n", + " kwargs[\"num_interventions\"] = len(layers) if args.share_weights else 2 * len(layers)\n", + " # Double interventions if creating separate interventions for texts and images\n", + " if args.reft_image_rank != -1 and args.reft_rank != -1:\n", + " kwargs[\"num_interventions\"] *= 2\n", + " # `n_boxes` is the seq length of the image embeddings\n", + " kwargs[\"last_offset\"] = args.n_boxes\n", + " # Only tested `first` \n", + " kwargs[\"pad_mode\"] = pad_mode\n", + " kwargs[\"last_position\"] = last_position\n", + " kwargs[\"tasks\"] = args.prompt\n", + " # print(kwargs)\n", + "\n", + " # print(\"BEFORE:\", out_dict[\"input_ids\"].shape, kwargs[\"last_position\"])\n", + " tokenized = compute_intervention(\n", + " idx, \n", + " out_dict, \n", + " tokenizer,\n", + " fields_to_pad,\n", + " fields_to_mask,\n", + " **kwargs)\n", + " # print(\"AFTER:\", tokenized[\"input_ids\"].shape, tokenized[\"intervention_locations\"])\n", + " return tokenized" + ] + }, + { + "cell_type": "markdown", + "id": "77f4ad91-4798-4f8f-abcf-590b5fb680db", + "metadata": {}, + "source": [ + "Here we collate the `intervention_locations` together with other collating fields, such as the image features, image positions (boxes - for some reason they are all zero tensors in both ReFT and DoRA), and labels. " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c60d24ea-aac4-431b-bce4-3c215045831d", + "metadata": {}, + "outputs": [], + "source": [ + "def keep_intervention_locations(datum):\n", + " new_data = {}\n", + " new_data[\"input_ids\"] = datum[\"input_ids\"]\n", + " new_data[\"intervention_locations\"] = datum[\"intervention_locations\"]\n", + " new_data[\"attention_mask\"] = datum[\"attention_mask\"]\n", + " return new_data\n", + "\n", + "\n", + "def reft_supplemental_data_collator(batch, tokenizer):\n", + " # Create padded `intervention_locations`\n", + " intervene_batch = [keep_intervention_locations(item) for item in batch]\n", + " # The normal data collator for collating other VLBart fields\n", + " intervention_loc_collate_fn = DataCollatorForSeq2Seq(\n", + " tokenizer=tokenizer,\n", + " model=None,\n", + " label_pad_token_id=-100,\n", + " padding=\"longest\"\n", + " )\n", + " \n", + " intervene_batch_entry = intervention_loc_collate_fn(intervene_batch)\n", + "\n", + " batch_entry = {}\n", + " id = []\n", + " instructions = []\n", + " # Collate `instruction` and `id`\n", + " for i, entry in enumerate(batch):\n", + " if 'instruction' in entry:\n", + " instructions.append(entry['instruction'])\n", + " if 'id' in entry:\n", + " id.append(entry['id'])\n", + " import numpy as np\n", + " batch_entry['id'] = np.array(id)\n", + " batch_entry['instruction'] = instructions\n", + " \n", + " # Pad `intervention_locations` with other stuff in the batch\n", + " if \"intervention_locations\" in batch[0]:\n", + " batch_entry[\"intervention_locations\"] = intervene_batch_entry[\"intervention_locations\"]\n", + " return batch_entry\n" + ] + }, + { + "cell_type": "markdown", + "id": "d7f6fdd0-cf51-4f6c-9122-fc115a3e2940", + "metadata": {}, + "source": [ + "ReFT data adds the `intervention_locations` to the VQA fine-tune dataset. Here we add the ReFT specifics to the original VQA dataset (coming from DoRA/VLAdapter). \n", + "\n", + "`VQAFineTuneDataset` integrates with the CLIP features of COCO images. The VQA task contains the mapping from the answers to their labels in the `datasets/vqa/v2_mscoco_train2014_annotations.json` file, and the mapping from question IDs to answers in the `datasets/vqa/karpathy_train.json` file. The VQA images are from the COCO dataset, which should be placed separately in `datasets/COCO/clip_features/data_clip_RN101_att` and `datasets/COCO/clip_features/data_clip_RN101_fc`. You can download all the datasets using the following command:\n", + "\n", + "```\n", + "gdrive download 1O_RU1iFh_sbItZCTkOHUrbVIQQ_89Djj\n", + "```\n", + "\n", + "Here is the [google drive link](https://drive.google.com/file/d/1O_RU1iFh_sbItZCTkOHUrbVIQQ_89Djj/view), we suggest you use [gdrive](https://github.com/prasmussen/gdrive) to download it. Put the `datasets` folder directly under `image_video_text_understanding`." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "32011cda-462e-4c73-9ff6-0569b7a338aa", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader, Dataset, Sampler\n", + "import vqa_clip_data as vqa_data\n", + "\n", + "class ReftVQAFineTuneDataset(vqa_data.VQAFineTuneDataset):\n", + " def __init__(self, split='train', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train'):\n", + " super().__init__(split, raw_dataset, rank, topk, verbose, args, mode)\n", + " self.split = split\n", + " \n", + " def __getitem__(self, idx):\n", + "\n", + " out_dict = super().__getitem__(idx)\n", + "\n", + " out_dict[\"instruction\"] = self.tokenizer.decode(\n", + " out_dict['input_ids'], \n", + " skip_special_tokens=True\n", + " )\n", + " last_position = len(out_dict['input_ids']) - 1\n", + " out_dict = reft_post_process(\n", + " out_dict,\n", + " self.tokenizer,\n", + " idx,\n", + " last_position,\n", + " self.args,\n", + " pad_mode=\"first\",\n", + " fields_to_pad=[\"input_ids\"],\n", + " fields_to_mask=[\"input_ids\"]\n", + " )\n", + "\n", + " return out_dict\n", + "\n", + "\n", + " def collate_fn(self, batch):\n", + " batch_entry = super().collate_fn(batch)\n", + " # BEGIN ADD\n", + " extra_batch = reft_supplemental_data_collator(batch, self.tokenizer)\n", + " for k, v in extra_batch.items():\n", + " batch_entry[k] = v\n", + " # END ADD\n", + " # print(\"LOGITS:\", batch_entry[\"logits\"])\n", + " # print(\"LABELS:\", batch_entry[\"labels\"])\n", + " if self.split == \"karpathy_val\" or self.split == \"karpathy_test\":\n", + " print(\"In dataset:\", batch_entry[\"instruction\"], \" \", batch_entry[\"question_ids\"])\n", + "\n", + " return batch_entry\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2cb3d460-8857-4368-9932-101919493bb7", + "metadata": {}, + "outputs": [], + "source": [ + "def get_loader(args, split='karpathy_train', mode='train',\n", + " batch_size=32, workers=4, distributed=False, gpu=0, topk=-1):\n", + "\n", + " verbose = (gpu == 0)\n", + "\n", + " _dset = vqa_data.VQADataset(split, verbose)\n", + " # print(\"Batch size:\", batch_size, \"Num workers:\", workers, \"Topk:\", topk)\n", + "\n", + " dataset = ReftVQAFineTuneDataset(\n", + " split,\n", + " raw_dataset=_dset,\n", + " rank=gpu,\n", + " topk=topk,\n", + " verbose=verbose,\n", + " args=args,\n", + " mode=mode)\n", + " sampler = None\n", + "\n", + " if mode == 'train':\n", + " loader = DataLoader(\n", + " dataset, batch_size=batch_size, shuffle=(sampler is None),\n", + " num_workers=workers, pin_memory=True, sampler=sampler,\n", + " collate_fn=dataset.collate_fn)\n", + " else:\n", + " loader = DataLoader(\n", + " dataset,\n", + " batch_size=batch_size,\n", + " num_workers=workers, pin_memory=True,\n", + " sampler=sampler,\n", + " shuffle=None if (sampler is not None) else False,\n", + " collate_fn=dataset.collate_fn,\n", + " drop_last=False)\n", + "\n", + " if verbose:\n", + " loader.evaluator = vqa_data.VQAEvaluator(_dset)\n", + "\n", + " loader.task = 'vqa'\n", + "\n", + " return loader\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "ccbfcc23-040f-43ba-8cf3-c3cb971cf3f4", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "### 2. Reft Model Replica for VLBart" + ] + }, + { + "cell_type": "markdown", + "id": "790d9a60-89b4-4ea2-802f-db9c23e73bac", + "metadata": {}, + "source": [ + "The forward pass API of ReFT is different from normal - we also need to pass in `intervention_locations`. `VLBartReft` wraps up the `VLBart` model implementation in VLAdapter/DoRA with `intervention_locations` passed in.\n", + "\n", + "In addition, we need to integrate ReFT's trainable parameters into VLBart Model's, so that gradient will propagate." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9d574fd3-aa93-4dc9-a7e8-8a4f23e15da7", + "metadata": {}, + "outputs": [], + "source": [ + "from modeling_bart import VLBart\n", + "class VLBartReft(VLBart):\n", + " def __init__(self, config: BartConfig):\n", + " super().__init__(config)\n", + " from pyreft import get_reft_model\n", + " self.intervenable = get_reft_model(self.model, config.reft_config)\n", + " # print(\"Reft parameters:\", self.intervenable.interventions)\n", + " # self.intervenable.unfreeze_intervention_parameters()\n", + " self.intervenable.print_trainable_parameters()\n", + " # print(\"INTERVENABLE:\", self.intervenable.model)\n", + "\n", + " # Unfreeze the PyVene intervention parameters\n", + " for k, v in self.intervenable.unfreeze_intervention_parameters().items():\n", + " n = k.replace(\".\", \"#\")\n", + " print(n)\n", + " self.register_parameter(n, v)\n", + " \n", + " def forward(\n", + " self,\n", + " input_ids=None,\n", + " attention_mask=None,\n", + "\n", + " vis_inputs=None,\n", + " vis_attention_mask=None,\n", + "\n", + " decoder_input_ids=None,\n", + " decoder_attention_mask=None,\n", + " encoder_outputs=None,\n", + " past_key_values=None,\n", + " inputs_embeds=None,\n", + " decoder_inputs_embeds=None,\n", + " labels=None,\n", + " use_cache=None,\n", + " output_attentions=None,\n", + " output_hidden_states=None,\n", + " return_dict=None,\n", + " task=None,\n", + "\n", + " reduce_loss=False,\n", + " intervention_locations = None,\n", + " **kwargs,\n", + " ):\n", + " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", + "\n", + " if labels is not None:\n", + " if decoder_input_ids is None:\n", + " decoder_input_ids = shift_tokens_right(\n", + " labels, self.config.pad_token_id, self.config.decoder_start_token_id\n", + " )\n", + " \n", + " if intervention_locations is not None:\n", + " # print(\"Intervention locs not None\")\n", + " # Pyvene forward pass\n", + " intervention_locations = intervention_locations.clone().detach().permute(1, 0, 2)\n", + " _, outputs = self.intervenable(\n", + " {\n", + " \"input_ids\": input_ids,\n", + " \"attention_mask\": attention_mask,\n", + " \"vis_inputs\": vis_inputs,\n", + " \"vis_attention_mask\": vis_attention_mask,\n", + " \"decoder_input_ids\": decoder_input_ids,\n", + " \"decoder_attention_mask\": decoder_attention_mask,\n", + " \"encoder_outputs\": encoder_outputs,\n", + " \"past_key_values\": past_key_values,\n", + " \"inputs_embeds\": inputs_embeds,\n", + " \"decoder_inputs_embeds\": decoder_inputs_embeds,\n", + " \"output_attentions\": output_attentions,\n", + " \"output_hidden_states\": output_hidden_states,\n", + " \"task\": task,\n", + " \"return_dict\": return_dict,\n", + " },\n", + " unit_locations={\"sources->base\": (\n", + " None,\n", + " intervention_locations\n", + " )},\n", + " labels=labels,\n", + " return_dict=False,\n", + " subspaces=None,\n", + " use_cache=use_cache,\n", + " )\n", + " else:\n", + " # print(\"Intervention locs None\")\n", + " outputs = self.model(\n", + " input_ids,\n", + " attention_mask=attention_mask,\n", + "\n", + " vis_inputs=vis_inputs,\n", + " vis_attention_mask=vis_attention_mask,\n", + "\n", + " decoder_input_ids=decoder_input_ids,\n", + " encoder_outputs=encoder_outputs,\n", + " decoder_attention_mask=decoder_attention_mask,\n", + " past_key_values=past_key_values,\n", + " inputs_embeds=inputs_embeds,\n", + " decoder_inputs_embeds=decoder_inputs_embeds,\n", + " use_cache=use_cache,\n", + " output_attentions=output_attentions,\n", + " output_hidden_states=output_hidden_states,\n", + " return_dict=return_dict,\n", + " task=task,\n", + " )\n", + "\n", + " # print(\"Outputs:\", outputs)\n", + " lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias\n", + " \n", + " masked_lm_loss = None\n", + " # print(\"LOGITS:\", lm_logits)\n", + " # print(\"LABELS\", labels)\n", + " if labels is not None:\n", + " # loss_fct = CrossEntropyLoss()\n", + " if reduce_loss:\n", + " loss_fct = CrossEntropyLoss(ignore_index=-100)\n", + " else:\n", + " loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')\n", + " masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))\n", + " \n", + " if not return_dict:\n", + " output = (lm_logits,) + outputs[1:]\n", + " return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output\n", + "\n", + " # if masked_lm_loss is not None and len(masked_lm_loss) > 1:\n", + " # masked_lm_loss = masked_lm_loss[0]\n", + " # print(\"LOSS 0:\", masked_lm_loss)\n", + "\n", + " return Seq2SeqLMOutput(\n", + " loss=masked_lm_loss,\n", + " logits=lm_logits,\n", + " past_key_values=outputs.past_key_values,\n", + " decoder_hidden_states=outputs.decoder_hidden_states,\n", + " decoder_attentions=outputs.decoder_attentions,\n", + " cross_attentions=outputs.cross_attentions,\n", + " encoder_last_hidden_state=outputs.encoder_last_hidden_state,\n", + " encoder_hidden_states=outputs.encoder_hidden_states,\n", + " encoder_attentions=outputs.encoder_attentions,\n", + " )\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "609328e2-5618-4e2e-a48d-c36829291bc5", + "metadata": {}, + "source": [ + "Here we calculate the VQA training loss (whether the model output matches the label, weighted by the score). We also specify the correct VQA generation parameters here. " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8e885ce0-d164-44b9-883c-3cedbbbd0e04", + "metadata": {}, + "outputs": [], + "source": [ + "class VLBartVQA(VLBartReft):\n", + " def __init__(self, config, num_answers=None, label2ans=None):\n", + " super().__init__(config)\n", + "\n", + " self.num_answers = num_answers\n", + " self.label2ans = label2ans\n", + " self.bce_loss = nn.BCEWithLogitsLoss()\n", + "\n", + " def train_step(self, batch):\n", + "\n", + " device = next(self.parameters()).device\n", + "\n", + " batch = self.vis_forward(batch, device)\n", + " task = batch[\"task\"]\n", + "\n", + " vis_feats = batch['vis_feats'].to(device)\n", + " input_ids = batch['input_ids'].to(device)\n", + " vis_pos = batch['boxes'].to(device)\n", + " intervention_locations = batch['intervention_locations'].to(device)\n", + "\n", + " lm_labels = batch[\"target_ids\"].to(device)\n", + "\n", + " output = self(\n", + " input_ids=input_ids,\n", + " vis_inputs=(vis_feats, vis_pos),\n", + " labels=lm_labels,\n", + " return_dict=True,\n", + " task=task,\n", + " intervention_locations=intervention_locations\n", + " )\n", + " assert 'loss' in output\n", + "\n", + " lm_mask = (lm_labels != -100).float()\n", + " B, L = lm_labels.size()\n", + "\n", + " loss = output['loss']\n", + "\n", + " loss = loss.view(B, L) * lm_mask\n", + "\n", + " loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1) # B\n", + "\n", + " loss = loss * batch['scores'].to(device=device)\n", + "\n", + " loss = loss.mean()\n", + " \n", + " # print(\"LOSS 1:\", batch[\"scores\"], loss.item())\n", + " result = {\n", + " 'loss': loss\n", + " }\n", + "\n", + " return result\n", + "\n", + " @torch.no_grad()\n", + " def test_step(self, batch, **kwargs):\n", + " self.eval()\n", + " device = next(self.parameters()).device\n", + "\n", + " batch = self.vis_forward(batch, device)\n", + "\n", + " vis_feats = batch['vis_feats'].to(device)\n", + " input_ids = batch['input_ids'].to(device)\n", + " vis_pos = batch['boxes'].to(device)\n", + " task = batch[\"task\"]\n", + " intervention_locations = batch['intervention_locations'].to(device)\n", + "\n", + " result = {}\n", + " generation_args = {\n", + " \"base\": {\n", + " \"input_ids\":input_ids,\n", + " \"vis_inputs\":(vis_feats, vis_pos),\n", + " \"task\":task,\n", + " **kwargs\n", + " },\n", + " \"unit_locations\": {\"sources->base\": (None, \n", + " intervention_locations.permute(1, 0, 2))},\n", + " \"intervene_on_prompt\": True,\n", + " \"eos_token_id\": self.tokenizer.eos_token_id,\n", + " \"early_stopping\": True,\n", + " \"model\": self,\n", + " }\n", + " # print(\"Generating...\", input_ids.shape, intervention_locations)\n", + " # TODO: temperature, top_p, top_k\n", + " # print(\"GENERATE MODEL:\", self.intervenable.model)\n", + " _, output = self.intervenable.generate(**generation_args)\n", + " generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)\n", + " result['token_ids'] = output\n", + " result['pred_ans'] = generated_sents\n", + "\n", + " return result\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "4684f989-0756-47f1-a82b-305d4a3d38c4", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "### 3. Multitask VLBart Trainer" + ] + }, + { + "cell_type": "markdown", + "id": "0ba702d3-ab30-4e98-acc7-54a2e288414b", + "metadata": {}, + "source": [ + "#### 1.3.1 ReftConfig\n", + "\n", + "ReFT model needs `ReftConfig` to properly initialize. Here we specify the text and image intervention specs. They are all separate.\n", + "\n", + "Also, we change the weight decay of ReFT parameters to 0. ReFT does not like wd, but other parameters in the training of VLBart do (visual embeddings, layer norms, for example)." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e3a9aa16-891b-4b60-a536-c3dae3fbc573", + "metadata": {}, + "outputs": [], + "source": [ + "from pyreft import ReftConfig, LoreftIntervention, TaskType\n", + "\n", + "class ReftTrainer(TrainerBase):\n", + " def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True):\n", + " super().__init__(\n", + " args,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", + " test_loader=test_loader,\n", + " train=train)\n", + "\n", + " def create_reft_config(self, config):\n", + " args = self.args\n", + " layers = args.layers\n", + " # ReFT layers - right now only \"all\" works properly\n", + " # TODO: properly process \"layers\" when it is not \"all\"\n", + " if layers != \"all\":\n", + " layers = [int(l) for l in layers.split(\";\")]\n", + " else:\n", + " layers = [l for l in range(config.num_hidden_layers)]\n", + " if '+' in self.args.positions and not args.share_weights:\n", + " layers += layers\n", + " \n", + " image_rank = args.reft_image_rank\n", + " text_rank = args.reft_rank\n", + " embed_dim = args.mid_dim\n", + "\n", + " # print(\"REFT PARAMS:\",embed_dim, rank, args.dropout)\n", + " representations = []\n", + " # Text interventions\n", + " if text_rank != -1:\n", + " representations += [{\n", + " \"layer\": l, \"component\": \"block_output\",\n", + " \"low_rank_dimension\": text_rank,\n", + " \"intervention\": LoreftIntervention(\n", + " embed_dim=embed_dim, low_rank_dimension=text_rank,\n", + " dropout=args.reft_dropout, dtype=torch.float32, act_fn=None, device=\"cuda\",\n", + " add_bias=True\n", + " )\n", + " } for l in layers]\n", + " # Image interventions\n", + " if image_rank != -1:\n", + " representations += [{\n", + " \"layer\": l, \"component\": \"block_output\",\n", + " \"low_rank_dimension\": image_rank,\n", + " \"intervention\": LoreftIntervention(\n", + " embed_dim=embed_dim, low_rank_dimension=image_rank,\n", + " dropout=args.reft_image_dropout, dtype=torch.float32, act_fn=None, device=\"cuda\",\n", + " add_bias=True\n", + " )\n", + " } for l in layers]\n", + " reft_config = ReftConfig(representations=representations)\n", + " print(reft_config)\n", + " return reft_config\n", + "\n", + " def create_config(self):\n", + " config = super().create_config()\n", + " setattr(config, \"reft_config\", self.create_reft_config(config))\n", + " return config\n", + "\n", + " def create_optimizer_and_scheduler(self):\n", + " if self.verbose:\n", + " print('Building Optimizer')\n", + "\n", + " lr_scheduler = None\n", + "\n", + " from transformers.optimization import AdamW, get_linear_schedule_with_warmup\n", + "\n", + " # Added \"#unit#pos\" to `no_decay` to keep ReFT intervention's weight decay to 0\n", + " # Bart's bias and layer norm's weight decay is 0, others are not zero \n", + " no_decay = [\"bias\", \"LayerNorm.weight\", \"#unit#pos\"]\n", + "\n", + " if 'adamw' in self.args.optim:\n", + " optimizer_grouped_parameters = [\n", + " {\n", + " \"params\": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": self.args.weight_decay,\n", + " },\n", + " {\n", + " \"params\": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": 0.0,\n", + " },\n", + " ]\n", + " optim = AdamW(optimizer_grouped_parameters,\n", + " lr=self.args.lr, eps=self.args.adam_eps)\n", + "\n", + " else:\n", + " # print(\"Parameters:\", self.model.named_parameters())\n", + " optimizer_grouped_parameters = [\n", + " {\n", + " \"params\": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": self.args.weight_decay,\n", + " },\n", + " {\n", + " \"params\": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": 0.0,\n", + " },\n", + " ]\n", + " optim = self.args.optimizer(optimizer_grouped_parameters, self.args.lr)\n", + "\n", + " batch_per_epoch = len(self.train_loader)\n", + " t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs\n", + " warmup_ratio = self.args.warmup_ratio\n", + " warmup_iters = int(t_total * warmup_ratio)\n", + " if self.verbose:\n", + " print(\"Batch per epoch: %d\" % batch_per_epoch)\n", + " print(\"Total Iters: %d\" % t_total)\n", + " print('Warmup ratio:', warmup_ratio)\n", + " print(\"Warm up Iters: %d\" % warmup_iters)\n", + "\n", + " lr_scheduler = get_linear_schedule_with_warmup(optim, warmup_iters, t_total)\n", + "\n", + " return optim, lr_scheduler" + ] + }, + { + "cell_type": "markdown", + "id": "4e5708c2-6eed-4162-b5a3-bab4d1288020", + "metadata": {}, + "source": [ + "#### 1.3.2 Reft images trainer\n", + "\n", + "This class is a complete VLBart trainer, integrated with ReFT. We\n", + "\n", + "- Unfreezed ReFT's parameters\n", + "- Updated the model's config with ReFT\n", + "- Performed gradient clipping (needed for visual embeddings)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "495a45bf-dfdb-4554-ada9-9393ac69406d", + "metadata": {}, + "outputs": [], + "source": [ + "class Trainer(ReftTrainer):\n", + " def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True):\n", + " super().__init__(\n", + " args,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", + " test_loader=test_loader,\n", + " train=train)\n", + "\n", + " model_kwargs = {}\n", + " if 'bart' in args.backbone:\n", + " model_class = VLBartVQA\n", + "\n", + " config = self.create_config()\n", + " self.tokenizer = self.create_tokenizer()\n", + "\n", + " if 'bart' in self.args.tokenizer:\n", + " num_added_toks = 0\n", + " if config.use_vis_order_embedding:\n", + " additional_special_tokens = [f'' for i in range(100-1, -1, -1)] + \\\n", + " [f'' for i in range(100-1, -1, -1)]\n", + " special_tokens_dict = {'additional_special_tokens': additional_special_tokens}\n", + " num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict)\n", + "\n", + " config.default_obj_order_ids = self.tokenizer.convert_tokens_to_ids([f'' for i in range(100)])\n", + "\n", + " self.model = self.create_model(model_class, config, **model_kwargs)\n", + "\n", + " if 't5' in self.args.tokenizer:\n", + " self.model.resize_token_embeddings(self.tokenizer.vocab_size)\n", + " elif 'bart' in self.args.tokenizer:\n", + " self.model.resize_token_embeddings(self.model.model.shared.num_embeddings + num_added_toks)\n", + "\n", + " self.model.tokenizer = self.tokenizer\n", + " if 't5' in self.args.tokenizer or 'bart' in self.args.tokenizer:\n", + " self.model.true_id = self.tokenizer('true', add_special_tokens=False).input_ids[0]\n", + " self.model.false_id = self.tokenizer('false', add_special_tokens=False).input_ids[0]\n", + "\n", + " # Load Checkpoint\n", + " self.start_epoch = None\n", + " if args.load is not None:\n", + " ckpt_path = args.load\n", + " self.load_checkpoint(ckpt_path)\n", + " if self.args.from_scratch:\n", + " self.init_weights()\n", + "\n", + " # GPU Options\n", + " print(f'Model Launching at GPU {self.args.gpu}')\n", + " if self.verbose:\n", + " from time import time\n", + " start = time()\n", + " self.model = self.model.to(args.gpu)\n", + " \n", + " # Only thing changed: set device to cuda, and unfreeze ReFT params\n", + "\n", + " self.model.intervenable.set_device(self.model.model.device)\n", + "\n", + " self.freeze_whole_model() # freeze whole parameters first\n", + " self.unfreeze_parameters() # unfreeze selected parameters\n", + " self.model.intervenable.unfreeze_intervention_parameters()\n", + " # print(self.model)\n", + " self.percent_updated_parameters = self.print_trainable_params_percentage(self.model)\n", + "\n", + " # Optimizer\n", + " if train:\n", + " self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler()\n", + "\n", + " if self.args.fp16 and _use_native_amp:\n", + " self.scaler = torch.cuda.amp.GradScaler()\n", + " elif _use_apex:\n", + " self.model, self.optim = amp.initialize(\n", + " self.model, self.optim, opt_level='O1', verbosity=self.verbose)\n", + "\n", + " if self.verbose:\n", + " print(f'It took {time() - start:.1f}s')\n", + "\n", + " def train(self):\n", + " if self.verbose:\n", + " vqa_loss_meter = LossMeter()\n", + " # best_eval_loss = 9595.\n", + " quesid2ans = {}\n", + " best_vqa_valid = 0.\n", + " best_vqa_epoch = 0\n", + "\n", + " wandb.init(project=self.args.project_name)\n", + " wandb.run.name = self.args.run_name\n", + " wandb.config.update(self.args)\n", + " wandb.watch(self.model)\n", + " wandb.log(\n", + " {\"percent of updated parameters (%)\": self.percent_updated_parameters}\n", + " )\n", + "\n", + " src_dir = os.path.dirname(os.getcwd())\n", + " base_path = os.path.dirname(src_dir)\n", + " src_dir = str(src_dir)\n", + " wandb.save(os.path.join(src_dir + \"/*.py\"), base_path=base_path)\n", + "\n", + " global_step = 0\n", + " for epoch in range(self.args.epochs):\n", + " if self.start_epoch is not None:\n", + " epoch += self.start_epoch\n", + " self.model.train()\n", + " self.partial_eval()\n", + "\n", + " if self.verbose:\n", + " pbar = tqdm(total=len(self.train_loader), ncols=250)\n", + "\n", + " epoch_results = {\n", + " 'loss': 0.,\n", + " }\n", + "\n", + " task_counter = {\n", + " 'vqa': 0,\n", + " }\n", + "\n", + " # vqa\n", + " quesid2ans = {}\n", + " train_acc = 0.\n", + " # train_acc_steps = int(len(self.train_loader) * 0.05)\n", + " # last_acc_step = 0\n", + "\n", + " for step_i, batch in enumerate(self.train_loader):\n", + "\n", + " # print(f'GPU{self.args.gpu} inside training loop')\n", + " # print(batch)\n", + " task = batch['task']\n", + " # if self.verbose:\n", + " # print('task', task)\n", + " task_counter[task] += 1\n", + "\n", + " batch['log_train_accuracy'] = self.args.log_train_accuracy\n", + "\n", + " # self.optim.zero_grad()\n", + " if self.args.fp16 and _use_native_amp:\n", + " with autocast():\n", + " results = self.model.train_step(batch)\n", + " else:\n", + " results = self.model.train_step(batch)\n", + "\n", + " loss = results['loss']\n", + "\n", + " if self.args.fp16 and _use_native_amp:\n", + " self.scaler.scale(loss).backward()\n", + " elif self.args.fp16 and _use_apex:\n", + " with amp.scale_loss(loss, self.optim) as scaled_loss:\n", + " scaled_loss.backward()\n", + " else:\n", + " loss.backward()\n", + "\n", + " # print(f'GPU{self.args.gpu} after backward')\n", + "\n", + " loss = loss.detach()\n", + "\n", + " # Update Parameters\n", + " if self.args.clip_grad_norm > 0:\n", + " if self.args.fp16 and _use_native_amp:\n", + " self.scaler.unscale_(self.optim)\n", + " torch.nn.utils.clip_grad_norm_(\n", + " self.model.parameters(), self.args.clip_grad_norm)\n", + " elif self.args.fp16 and _use_apex:\n", + " torch.nn.utils.clip_grad_norm_(amp.master_params(\n", + " self.optim), self.args.clip_grad_norm)\n", + " else:\n", + " torch.nn.utils.clip_grad_norm_(\n", + " self.model.parameters(), self.args.clip_grad_norm)\n", + "\n", + " if self.args.fp16 and _use_native_amp:\n", + " self.scaler.step(self.optim)\n", + " self.scaler.update()\n", + " else:\n", + " self.optim.step()\n", + "\n", + " if self.lr_scheduler:\n", + " self.lr_scheduler.step()\n", + " for param in self.model.parameters():\n", + " param.grad = None\n", + "\n", + " global_step += 1\n", + "\n", + " for k, v in results.items():\n", + " if k in epoch_results:\n", + " epoch_results[k] += v.item()\n", + "\n", + " if self.lr_scheduler:\n", + " if version.parse(torch.__version__) >= version.parse(\"1.4\"):\n", + " lr = self.lr_scheduler.get_last_lr()[0]\n", + " else:\n", + " lr = self.lr_scheduler.get_lr()[0]\n", + " else:\n", + " try:\n", + " lr = self.optim.get_lr()[0]\n", + " except AttributeError:\n", + " lr = self.args.lr\n", + "\n", + " if self.verbose:\n", + " if task == 'vqa':\n", + " vqa_loss_meter.update(loss.item())\n", + "\n", + " desc_str = f'Epoch {epoch} | LR {lr:.6f}'\n", + "\n", + " desc_str += f\" |\"\n", + " if 'vqa' in self.args.tasks:\n", + " desc_str += f\" VQA {task_counter['vqa']}\"\n", + " if len(vqa_loss_meter) > 0:\n", + " desc_str += f' | VQA Loss {vqa_loss_meter.val:4f}'\n", + "\n", + " pbar.set_description(desc_str)\n", + " pbar.update(1)\n", + "\n", + " if self.verbose:\n", + " pbar.close()\n", + "\n", + " if self.args.log_train_accuracy:\n", + " train_score_dict = {\n", + " 'n_correct': n_correct,\n", + " 'n_total': n_total\n", + " }\n", + " train_score_dict = reduce_dict(train_score_dict, self.args.gpu)\n", + "\n", + " if self.verbose:\n", + " # Validation\n", + " log_str = ''\n", + " wandb_log_dict = {}\n", + "\n", + " if 'vqa' in self.args.tasks:\n", + " # VQA\n", + " vqa_val_loader = self.val_loader['vqa']\n", + " score_dict = self.vqa_evaluate(vqa_val_loader)\n", + " valid_score = score_dict['topk_score'] * 100.\n", + " valid_score_raw = score_dict['overall']\n", + " if valid_score_raw > best_vqa_valid or epoch == 0:\n", + " best_vqa_valid = valid_score_raw\n", + " best_vqa_epoch = epoch\n", + " # self.save(\"VQA_BEST\")\n", + " log_str += f\"VQA\"\n", + " log_str += \"\\nEpoch %d: Valid Raw %0.2f Topk %0.2f\" % (epoch, valid_score_raw, valid_score)\n", + " log_str += \"\\nEpoch %d: Best Raw %0.2f\\n\" % (best_vqa_epoch, best_vqa_valid)\n", + " wandb_log_dict['VQA/Valid/score'] = valid_score\n", + " wandb_log_dict['VQA/Valid/raw_score'] = score_dict['overall']\n", + " \n", + " wandb.log(wandb_log_dict, step=epoch)\n", + "\n", + " print(log_str)\n", + "\n", + " # Test Set\n", + " if self.verbose:\n", + " self.save(\"LAST\")\n", + "\n", + " log_str = ''\n", + " wandb_log_dict = {}\n", + "\n", + " if 'vqa' in self.args.tasks:\n", + " # VQA\n", + " vqa_test_loader = self.test_loader['vqa']\n", + " evaluator = vqa_test_loader.evaluator\n", + " dump_path = os.path.join(self.args.output, 'karpathy_test_predict.json')\n", + " quesid2ans = self.vqa_predict(vqa_test_loader, dump_path)\n", + " wandb.save(dump_path, base_path=self.args.output)\n", + "\n", + " acc_dict_all = evaluator.evaluate_raw(quesid2ans)\n", + " acc_dict_answerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=True)\n", + " acc_dict_unanswerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=False)\n", + "\n", + " wandb_log_dict['VQA/Test/overall'] = acc_dict_all['overall']\n", + " wandb_log_dict['VQA/Test/topk_optimal'] = acc_dict_answerable['overall']\n", + " wandb_log_dict['VQA/Test/topk_not_optimal'] = acc_dict_unanswerable['overall']\n", + "\n", + " if self.test_loader.get(\"vqa_submit\", None):\n", + " vqa_submit_test_loader = self.test_loader['vqa_submit']\n", + " dump_path = os.path.join(self.args.output, 'vqa_submit.json')\n", + " self.vqa_predict(vqa_submit_test_loader, dump_path=dump_path)\n", + " wandb.save(dump_path, base_path=self.args.output)\n", + "\n", + " print(log_str)\n", + " wandb.log(wandb_log_dict, step=self.args.epochs)\n", + "\n", + " wandb.log({'finished': True})\n", + "\n", + " def vqa_predict(self, loader, dump_path=None):\n", + " self.model.eval()\n", + " with torch.no_grad():\n", + " quesid2ans = {}\n", + "\n", + " gen_kwargs = {}\n", + " gen_kwargs['num_beams'] = 1\n", + "\n", + " for i, batch in enumerate(tqdm(loader, ncols=150, desc=\"VQA Validation\")):\n", + "\n", + " if self.args.distributed:\n", + " results = self.model.module.test_step(batch, **gen_kwargs)\n", + " else:\n", + " results = self.model.test_step(batch, **gen_kwargs)\n", + "\n", + " pred_ans = results['pred_ans']\n", + " ques_ids = batch['question_ids']\n", + "\n", + " for qid, ans in zip(ques_ids, pred_ans):\n", + " quesid2ans[qid] = ans\n", + "\n", + " if dump_path is not None:\n", + " loader.evaluator.dump_result(quesid2ans, dump_path)\n", + " return quesid2ans\n", + "\n", + " def vqa_evaluate(self, loader, dump_path=None):\n", + " evaluator = loader.evaluator\n", + " quesid2ans = self.vqa_predict(loader, dump_path)\n", + "\n", + " acc_dict = evaluator.evaluate_raw(quesid2ans)\n", + "\n", + " topk_score = evaluator.evaluate(quesid2ans)\n", + " acc_dict['topk_score'] = topk_score\n", + "\n", + " return acc_dict" + ] + }, + { + "cell_type": "markdown", + "id": "f33b4ae2-8de0-4362-a759-bfb047854f03", + "metadata": {}, + "source": [ + "### 4. Main Worker, Params, and Run" + ] + }, + { + "cell_type": "markdown", + "id": "28e82e41-d794-4f8d-b2a0-30cffc1dc694", + "metadata": {}, + "source": [ + "Here we specify all the parameters for VLBart and ReFT. Note that we unfroze the bias and the layer norms, as well as the visual embeddings. For the detailed modules that we tune, see `trainer_base.py`." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "aef8517f-aa1f-41d3-a01d-0f5c8edb7c8e", + "metadata": {}, + "outputs": [], + "source": [ + "cudnn.benchmark = True\n", + "args = parse_args(False)\n", + "ngpus_per_node = torch.cuda.device_count()\n", + "args.world_size = ngpus_per_node" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b1c6a607-bc07-4236-be0a-ed6af23ac2b3", + "metadata": {}, + "outputs": [], + "source": [ + "args.distributed = False\n", + "args.nproc_per_node = 1\n", + "args.master_port = 26464\n", + "args.multiGPU = True\n", + "args.optim = \"adamw\"\n", + "args.warmup_ratio = 0.1\n", + "args.clip_grad_norm = 5\n", + "args.weight_decay = 0.01\n", + "args.lr = 1e-3\n", + "args.epochs = 3\n", + "args.num_workers = 4\n", + "args.backbone = \"facebook/bart-base\"\n", + "args.output = \"snap/VLBart_dora_reft/test/\"\n", + "args.num_beams = 5\n", + "args.use_tasks_prompts = True\n", + "args.train_topk = 10000\n", + "args.valid_topk = 4\n", + "args.batch_size = 100\n", + "args.valid_batch_size = 1\n", + "# args.use_dora = True\n", + "args.unfreeze_bias = True\n", + "args.unfreeze_layer_norms = True\n", + "# args.lora_settings = True\n", + "# args.lora_dim = 128\n", + "args.tasks = \"vqa\"\n", + "args.dropout = 0.00\n", + "args.reft_dropout = 0.00\n", + "args.reft_image_dropout = 0.00\n", + "args.reft_rank = 4\n", + "args.reft_image_rank = 64\n", + "args.positions = \"f3+l3\"\n", + "args.image_positions = \"f3+l3\"\n", + "\n", + "args.feature = \"RN101\"\n", + "args.n_boxes = 36\n", + "args.downsample = True\n", + "args.image_size = \"(224,224)\"\n", + "args.project_name = \"Test\"\n", + "args.run_name = \"tune+lr1e-3\"\n", + "args.local_rank = 0\n", + "args.feature_type = \"RN101\"\n", + "os.environ['RANK'] = '0'\n", + "os.environ['WORLD_SIZE'] = '1'\n", + "os.environ['MASTER_ADDR'] = '127.0.0.1'\n", + "os.environ['MASTER_PORT'] = '26464'" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "b0accb8d-5a6f-4766-933b-50141249e0eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Configurations\n", + "{'adam_beta1': 0.9,\n", + " 'adam_beta2': 0.999,\n", + " 'adam_eps': 1e-06,\n", + " 'add_adapter_cross_attn': True,\n", + " 'add_layer_norm_after_adapter': False,\n", + " 'add_layer_norm_before_adapter': False,\n", + " 'additional_visual_embedding_layers': 0,\n", + " 'answer_normalize': False,\n", + " 'backbone': 'facebook/bart-base',\n", + " 'batch_size': 100,\n", + " 'caption_cocoonly': True,\n", + " 'caption_only': False,\n", + " 'classifier': False,\n", + " 'clip_grad_norm': 5,\n", + " 'cls_task': 'tinyimagenet',\n", + " 'coco_only': False,\n", + " 'comment': '',\n", + " 'decoder_prompt_len': 0,\n", + " 'deepspeed': None,\n", + " 'distributed': False,\n", + " 'do_lower_case': False,\n", + " 'dora_simple': False,\n", + " 'downsample': True,\n", + " 'dropout': 0.0,\n", + " 'dry': False,\n", + " 'efficient_unique_hyper_net': False,\n", + " 'encoder_prompt_len': 0,\n", + " 'epochs': 3,\n", + " 'expand_vis_embedding': False,\n", + " 'factorized_phm': True,\n", + " 'feat_dim': 2048,\n", + " 'feature': 'RN101',\n", + " 'feature_type': 'RN101',\n", + " 'fp16': False,\n", + " 'freeze_bn_statistics': False,\n", + " 'freeze_ln_statistics': False,\n", + " 'from_scratch': False,\n", + " 'gen_max_length': 20,\n", + " 'gradient_accumulation_steps': 1,\n", + " 'hypercomplex_division': 4,\n", + " 'image_positions': 'f3+l3',\n", + " 'image_size': '(224,224)',\n", + " 'individual_vis_layer_norm': True,\n", + " 'lambda_z': 0.001,\n", + " 'layers': '0;1;2;3;4;5',\n", + " 'load': None,\n", + " 'load_lxmert_qa': None,\n", + " 'local_rank': 0,\n", + " 'log_train_accuracy': False,\n", + " 'lora_alpha': 32,\n", + " 'lora_dim': 4,\n", + " 'lora_settings': False,\n", + " 'losses': 'lm,obj,attr,feat',\n", + " 'low_rank_rank': 1,\n", + " 'lr': 0.001,\n", + " 'master_port': 26464,\n", + " 'max_n_boxes': 36,\n", + " 'max_text_length': 20,\n", + " 'mid_dim': 768,\n", + " 'multiGPU': True,\n", + " 'multitask_sampling': 'roundrobin',\n", + " 'n_boxes': 36,\n", + " 'n_ground': 1,\n", + " 'n_image_tokens': 4,\n", + " 'nproc_per_node': 1,\n", + " 'num_beams': 5,\n", + " 'num_workers': 4,\n", + " 'obj_mask_rate': 0.15,\n", + " 'oneddownsample': False,\n", + " 'optim': 'adamw',\n", + " 'optimizer': 'adamw',\n", + " 'oscar_tags': False,\n", + " 'output': 'snap/VLBart_dora_reft/test/',\n", + " 'phm_init_range': 0.01,\n", + " 'phm_rank': 1,\n", + " 'pos_dim': 4,\n", + " 'positions': 'f3+l3',\n", + " 'post_prompt': '',\n", + " 'prefix': None,\n", + " 'project_name': 'Test',\n", + " 'projected_task_embedding_dim': -1,\n", + " 'prompt': 'vqa: ',\n", + " 'raw_label': False,\n", + " 'reduction_factor': 16,\n", + " 'reft_dropout': 0.0,\n", + " 'reft_image_dropout': 0.0,\n", + " 'reft_image_rank': 64,\n", + " 'reft_rank': 4,\n", + " 'remove_bn_vis_adapter': False,\n", + " 'run_name': 'tune+lr1e-3',\n", + " 'seed': 9595,\n", + " 'share_down_sampler': False,\n", + " 'share_up_sampler': False,\n", + " 'share_vis_lang_layer_norm': False,\n", + " 'share_weights': False,\n", + " 'shared_phm_rule': True,\n", + " 'shared_phm_rule_over_tasks': False,\n", + " 'sparse_sample': False,\n", + " 'submit': False,\n", + " 'tasks': 'vqa',\n", + " 'test': None,\n", + " 'test_answerable': False,\n", + " 'test_only': False,\n", + " 'testing': False,\n", + " 'tokenizer': None,\n", + " 'track_z': False,\n", + " 'train': 'train',\n", + " 'train_topk': 10000,\n", + " 'unfreeze_batch_norms': False,\n", + " 'unfreeze_bias': True,\n", + " 'unfreeze_decoder_layer_norms': False,\n", + " 'unfreeze_encoder_layer_norms': False,\n", + " 'unfreeze_language_model': False,\n", + " 'unfreeze_layer_norms': True,\n", + " 'unfreeze_lm_head': False,\n", + " 'unfreeze_vis_encoder': False,\n", + " 'unfreeze_vis_last_layer': False,\n", + " 'unique_hyper_net': False,\n", + " 'use_adam_for_visual': False,\n", + " 'use_adapter': False,\n", + " 'use_attn_prefix': False,\n", + " 'use_compacter': False,\n", + " 'use_data_augmentation': False,\n", + " 'use_dora': False,\n", + " 'use_hyperformer': False,\n", + " 'use_lm_head_adapter': False,\n", + " 'use_lora': False,\n", + " 'use_lradapter': False,\n", + " 'use_separate_optimizer_for_visual': False,\n", + " 'use_single_adapter': False,\n", + " 'use_single_lora': False,\n", + " 'use_single_prompt': False,\n", + " 'use_tasks_prompts': True,\n", + " 'use_vis_adapter': False,\n", + " 'use_vis_layer_norm': True,\n", + " 'use_vis_order_embedding': True,\n", + " 'use_vision': True,\n", + " 'valid': 'valid',\n", + " 'valid_batch_size': 1,\n", + " 'valid_topk': 4,\n", + " 'vis_adapter_type': 'middle-bottleneck',\n", + " 'vis_lr': 0.0001,\n", + " 'vis_pooling_output': False,\n", + " 'vis_reduction_factor': 2,\n", + " 'vis_weight_decay': 0.01,\n", + " 'warmup_ratio': 0.1,\n", + " 'weight_decay': 0.01,\n", + " 'word_mask_rate': 0.15,\n", + " 'world_size': 1}\n" + ] + } + ], + "source": [ + "# cudnn.benchmark = True\n", + "# args = parse_args(False)\n", + "ngpus_per_node = torch.cuda.device_count()\n", + "args.world_size = ngpus_per_node\n", + "if args.local_rank in [0, -1]:\n", + " print(args)\n", + "\n", + " comments = []\n", + " if args.load is not None:\n", + " ckpt_str = \"_\".join(args.load.split('/')[-3:])\n", + " comments.append(ckpt_str)\n", + " if args.comment != '':\n", + " comments.append(args.comment)\n", + " comment = '_'.join(comments)\n", + "\n", + " from datetime import datetime\n", + " current_time = datetime.now().strftime('%b%d_%H-%M')\n", + " run_name = f'{current_time}_GPU{args.world_size}'\n", + " if len(comments) > 0:\n", + " run_name += f'_{comment}'\n", + "\n", + " if args.run_name == \"\":\n", + " args.run_name = run_name" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "71a508de-c65b-45ce-964e-ae7f138ca8f5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Launching at GPU 0\n", + "args.feature_type RN101\n" + ] + } + ], + "source": [ + "# GPU is assigned\n", + "gpu = args.local_rank\n", + "args.gpu = gpu\n", + "args.rank = gpu\n", + "print(f'Launching at GPU {gpu}')\n", + "\n", + "print(f\"args.feature_type {args.feature_type}\")\n", + "feat_dim_dict = {\n", + " \"RN50\": 2048,\n", + " \"RN101\": 2048,\n", + " \"RN50x4\": 2560,\n", + " \"ViT\": 768\n", + "}\n", + "args.feat_dim = feat_dim_dict[args.feature_type]\n", + "\n", + "vqa_args = deepcopy(args)\n", + "vqa_args.max_text_length = 20\n", + "\n", + "\n", + "if args.use_tasks_prompts:\n", + " vqa_args.prompt = \"vqa: \"\n", + "else:\n", + " vqa_args.prompt = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "5aff4312-8146-45d5-a00e-f668d4df84f2", + "metadata": {}, + "source": [ + "Now we create the data loaders." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "85c17a83-f8ff-4568-99d2-2701ee41387e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Load 605102 data from split(s) karpathy_train.\n", + "# Answers: 3129\n", + "Data sources: ['karpathy_train']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/nlp/scr/peterwz/miniconda3/envs/peterwz-dora/lib/python3.8/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 605102 data from karpathy_train\n", + "Use only 10000 data\n", + "# all sentences: 10000\n", + "Building VQA val loader at GPU 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/nlp/scr/peterwz/miniconda3/envs/peterwz-dora/lib/python3.8/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", + " warnings.warn(_create_warning_msg(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Load 26729 data from split(s) karpathy_val.\n", + "# Answers: 3129\n", + "Data sources: ['karpathy_val']\n", + "Loaded 26729 data from karpathy_val\n", + "Use only 4 data\n", + "# all sentences: 4\n", + "Building VQA test loader at GPU 0\n", + "Load 26280 data from split(s) karpathy_test.\n", + "# Answers: 3129\n", + "Data sources: ['karpathy_test']\n", + "Loaded 26280 data from karpathy_test\n", + "Use only 4 data\n", + "# all sentences: 4\n" + ] + } + ], + "source": [ + "\n", + "train_loaders = []\n", + "\n", + "train_loader = get_loader(\n", + " vqa_args,\n", + " split='karpathy_train', mode='train', batch_size=vqa_args.batch_size,\n", + " distributed=False, gpu=args.gpu,\n", + " workers=args.num_workers,\n", + " topk=args.train_topk,\n", + ")\n", + "\n", + "val_num_workers = 4\n", + "# Validation set\n", + "if gpu == 0:\n", + " val_loader = {}\n", + " if args.epochs > 0:\n", + " if 'vqa' in args.tasks:\n", + " print(f'Building VQA val loader at GPU {gpu}')\n", + " vqa_val_loader = get_loader(\n", + " vqa_args,\n", + " split='karpathy_val', mode='val', batch_size=vqa_args.valid_batch_size,\n", + " distributed=False, gpu=args.gpu,\n", + " workers=val_num_workers,\n", + " topk=args.valid_topk,\n", + " )\n", + " val_loader['vqa'] = vqa_val_loader\n", + "\n", + " # Test set\n", + " test_loader = {}\n", + " if 'vqa' in args.tasks:\n", + " print(f'Building VQA test loader at GPU {gpu}')\n", + " vqa_test_loader = get_loader(\n", + " vqa_args,\n", + " split='karpathy_test', mode='val', batch_size=vqa_args.valid_batch_size,\n", + " distributed=False, gpu=args.gpu,\n", + " workers=val_num_workers,\n", + " topk=args.valid_topk,\n", + " )\n", + " test_loader['vqa'] = vqa_test_loader\n", + "\n", + " if args.testing:\n", + " vqa_submit_test_loader = get_loader(\n", + " vqa_args,\n", + " split='test_4', mode='val', batch_size=vqa_args.valid_batch_size,\n", + " distributed=False, gpu=args.gpu,\n", + " workers=val_num_workers,\n", + " topk=args.valid_topk,\n", + " )\n", + " test_loader['vqa_submit'] = vqa_submit_test_loader\n", + "else:\n", + " val_loader = None\n", + " test_loader = None\n" + ] + }, + { + "cell_type": "markdown", + "id": "40ce4484-c07b-44e9-afe4-2b32065e05e1", + "metadata": {}, + "source": [ + "Now we start training!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb661683-666b-432d-b77e-9849fc726a6e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IntervenableConfig\n", + "{\n", + " \"model_type\": \"None\",\n", + " \"representations\": [\n", + " {\n", + " \"layer\": 0,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 1,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 2,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 3,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 4,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 5,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 0,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 1,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 2,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 3,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 4,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 5,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 4,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 0,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 1,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 2,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 3,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 4,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 5,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 0,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 1,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 2,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 3,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 4,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " },\n", + " {\n", + " \"layer\": 5,\n", + " \"component\": \"block_output\",\n", + " \"unit\": \"pos\",\n", + " \"max_number_of_units\": 1,\n", + " \"low_rank_dimension\": 64,\n", + " \"intervention_type\": null,\n", + " \"intervention\": \"PLACEHOLDER\",\n", + " \"subspace_partition\": null,\n", + " \"group_key\": null,\n", + " \"intervention_link_key\": null,\n", + " \"moe_key\": null,\n", + " \"source_representation\": null,\n", + " \"hidden_source_representation\": null\n", + " }\n", + " ],\n", + " \"intervention_types\": \"[, , , , , , , , , , , , , , , , , , , , , , , ]\",\n", + " \"mode\": \"parallel\",\n", + " \"sorted_keys\": \"None\",\n", + " \"intervention_dimensions\": \"None\"\n", + "}\n", + "Building Model at GPU 0\n", + "trainable intervention params: 1,254,192 || trainable model params: 0\n", + "model params: 140,995,584 || trainable%: 0.8895257315292938\n", + "layer#0#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original\n", + "layer#0#comp#block_output#unit#pos#nunit#1#0#learned_source#weight\n", + "layer#0#comp#block_output#unit#pos#nunit#1#0#learned_source#bias\n", + "layer#1#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original\n", + "layer#1#comp#block_output#unit#pos#nunit#1#0#learned_source#weight\n", + "layer#1#comp#block_output#unit#pos#nunit#1#0#learned_source#bias\n", + "layer#2#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original\n", + "layer#2#comp#block_output#unit#pos#nunit#1#0#learned_source#weight\n", + "layer#2#comp#block_output#unit#pos#nunit#1#0#learned_source#bias\n", + "layer#3#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original\n", + "layer#3#comp#block_output#unit#pos#nunit#1#0#learned_source#weight\n", + "layer#3#comp#block_output#unit#pos#nunit#1#0#learned_source#bias\n", + "layer#4#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original\n", + "layer#4#comp#block_output#unit#pos#nunit#1#0#learned_source#weight\n", + "layer#4#comp#block_output#unit#pos#nunit#1#0#learned_source#bias\n", + "layer#5#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original\n", + "layer#5#comp#block_output#unit#pos#nunit#1#0#learned_source#weight\n", + "layer#5#comp#block_output#unit#pos#nunit#1#0#learned_source#bias\n", + "layer#0#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original\n", + "layer#0#comp#block_output#unit#pos#nunit#1#1#learned_source#weight\n", + "layer#0#comp#block_output#unit#pos#nunit#1#1#learned_source#bias\n", + "layer#1#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original\n", + "layer#1#comp#block_output#unit#pos#nunit#1#1#learned_source#weight\n", + "layer#1#comp#block_output#unit#pos#nunit#1#1#learned_source#bias\n", + "layer#2#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original\n", + "layer#2#comp#block_output#unit#pos#nunit#1#1#learned_source#weight\n", + "layer#2#comp#block_output#unit#pos#nunit#1#1#learned_source#bias\n", + "layer#3#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original\n", + "layer#3#comp#block_output#unit#pos#nunit#1#1#learned_source#weight\n", + "layer#3#comp#block_output#unit#pos#nunit#1#1#learned_source#bias\n", + "layer#4#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original\n", + "layer#4#comp#block_output#unit#pos#nunit#1#1#learned_source#weight\n", + "layer#4#comp#block_output#unit#pos#nunit#1#1#learned_source#bias\n", + "layer#5#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original\n", + "layer#5#comp#block_output#unit#pos#nunit#1#1#learned_source#weight\n", + "layer#5#comp#block_output#unit#pos#nunit#1#1#learned_source#bias\n", + "layer#0#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original\n", + "layer#0#comp#block_output#unit#pos#nunit#1#2#learned_source#weight\n", + "layer#0#comp#block_output#unit#pos#nunit#1#2#learned_source#bias\n", + "layer#1#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original\n", + "layer#1#comp#block_output#unit#pos#nunit#1#2#learned_source#weight\n", + "layer#1#comp#block_output#unit#pos#nunit#1#2#learned_source#bias\n", + "layer#2#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original\n", + "layer#2#comp#block_output#unit#pos#nunit#1#2#learned_source#weight\n", + "layer#2#comp#block_output#unit#pos#nunit#1#2#learned_source#bias\n", + "layer#3#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original\n", + "layer#3#comp#block_output#unit#pos#nunit#1#2#learned_source#weight\n", + "layer#3#comp#block_output#unit#pos#nunit#1#2#learned_source#bias\n", + "layer#4#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original\n", + "layer#4#comp#block_output#unit#pos#nunit#1#2#learned_source#weight\n", + "layer#4#comp#block_output#unit#pos#nunit#1#2#learned_source#bias\n", + "layer#5#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original\n", + "layer#5#comp#block_output#unit#pos#nunit#1#2#learned_source#weight\n", + "layer#5#comp#block_output#unit#pos#nunit#1#2#learned_source#bias\n", + "layer#0#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original\n", + "layer#0#comp#block_output#unit#pos#nunit#1#3#learned_source#weight\n", + "layer#0#comp#block_output#unit#pos#nunit#1#3#learned_source#bias\n", + "layer#1#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original\n", + "layer#1#comp#block_output#unit#pos#nunit#1#3#learned_source#weight\n", + "layer#1#comp#block_output#unit#pos#nunit#1#3#learned_source#bias\n", + "layer#2#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original\n", + "layer#2#comp#block_output#unit#pos#nunit#1#3#learned_source#weight\n", + "layer#2#comp#block_output#unit#pos#nunit#1#3#learned_source#bias\n", + "layer#3#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original\n", + "layer#3#comp#block_output#unit#pos#nunit#1#3#learned_source#weight\n", + "layer#3#comp#block_output#unit#pos#nunit#1#3#learned_source#bias\n", + "layer#4#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original\n", + "layer#4#comp#block_output#unit#pos#nunit#1#3#learned_source#weight\n", + "layer#4#comp#block_output#unit#pos#nunit#1#3#learned_source#bias\n", + "layer#5#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original\n", + "layer#5#comp#block_output#unit#pos#nunit#1#3#learned_source#weight\n", + "layer#5#comp#block_output#unit#pos#nunit#1#3#learned_source#bias\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of VLBartVQA were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['layer#5#comp#block_output#unit#pos#nunit#1#3#learned_source#weight', 'layer#3#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original', 'layer#1#comp#block_output#unit#pos#nunit#1#3#learned_source#bias', 'layer#3#comp#block_output#unit#pos#nunit#1#3#learned_source#weight', 'encoder.visual_embedding.feat_embedding.1.weight', 'layer#1#comp#block_output#unit#pos#nunit#1#2#learned_source#weight', 'layer#2#comp#block_output#unit#pos#nunit#1#1#learned_source#bias', 'layer#5#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original', 'layer#4#comp#block_output#unit#pos#nunit#1#3#learned_source#bias', 'layer#3#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original', 'layer#0#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original', 'layer#5#comp#block_output#unit#pos#nunit#1#1#learned_source#weight', 'layer#4#comp#block_output#unit#pos#nunit#1#2#learned_source#weight', 'layer#1#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original', 'layer#3#comp#block_output#unit#pos#nunit#1#1#learned_source#bias', 'layer#0#comp#block_output#unit#pos#nunit#1#1#learned_source#weight', 'layer#0#comp#block_output#unit#pos#nunit#1#2#learned_source#bias', 'layer#2#comp#block_output#unit#pos#nunit#1#0#learned_source#bias', 'layer#2#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original', 'layer#5#comp#block_output#unit#pos#nunit#1#0#learned_source#weight', 'layer#5#comp#block_output#unit#pos#nunit#1#1#learned_source#bias', 'layer#1#comp#block_output#unit#pos#nunit#1#1#learned_source#bias', 'layer#5#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original', 'layer#5#comp#block_output#unit#pos#nunit#1#0#learned_source#bias', 'layer#0#comp#block_output#unit#pos#nunit#1#3#learned_source#weight', 'layer#3#comp#block_output#unit#pos#nunit#1#1#learned_source#weight', 'layer#1#comp#block_output#unit#pos#nunit#1#0#learned_source#weight', 'layer#4#comp#block_output#unit#pos#nunit#1#0#learned_source#bias', 'layer#2#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original', 'layer#5#comp#block_output#unit#pos#nunit#1#2#learned_source#weight', 'layer#0#comp#block_output#unit#pos#nunit#1#0#learned_source#weight', 'layer#1#comp#block_output#unit#pos#nunit#1#2#learned_source#bias', 'layer#0#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original', 'layer#3#comp#block_output#unit#pos#nunit#1#0#learned_source#weight', 'layer#0#comp#block_output#unit#pos#nunit#1#0#learned_source#bias', 'layer#2#comp#block_output#unit#pos#nunit#1#1#learned_source#weight', 'intervenable.model.encoder.visual_embedding.feat_embedding.1.bias', 'layer#3#comp#block_output#unit#pos#nunit#1#2#learned_source#weight', 'layer#3#comp#block_output#unit#pos#nunit#1#2#learned_source#bias', 'intervenable.model.encoder.visual_embedding.feat_embedding.0.bias', 'layer#4#comp#block_output#unit#pos#nunit#1#3#learned_source#weight', 'layer#2#comp#block_output#unit#pos#nunit#1#2#learned_source#bias', 'layer#1#comp#block_output#unit#pos#nunit#1#1#learned_source#weight', 'layer#1#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original', 'layer#5#comp#block_output#unit#pos#nunit#1#3#learned_source#bias', 'layer#3#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original', 'layer#3#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original', 'layer#2#comp#block_output#unit#pos#nunit#1#0#learned_source#weight', 'intervenable.model.encoder.visual_embedding.feat_embedding.0.weight', 'layer#1#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original', 'encoder.visual_embedding.feat_embedding.0.bias', 'layer#4#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original', 'layer#3#comp#block_output#unit#pos#nunit#1#0#learned_source#bias', 'layer#4#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original', 'intervenable.model.encoder.visual_embedding.feat_embedding.1.weight', 'layer#2#comp#block_output#unit#pos#nunit#1#3#rotate_layer#parametrizations#weight#original', 'layer#2#comp#block_output#unit#pos#nunit#1#2#learned_source#weight', 'layer#3#comp#block_output#unit#pos#nunit#1#3#learned_source#bias', 'layer#4#comp#block_output#unit#pos#nunit#1#2#learned_source#bias', 'layer#2#comp#block_output#unit#pos#nunit#1#3#learned_source#weight', 'layer#4#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original', 'layer#4#comp#block_output#unit#pos#nunit#1#1#learned_source#bias', 'layer#1#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original', 'layer#2#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original', 'layer#4#comp#block_output#unit#pos#nunit#1#1#learned_source#weight', 'layer#0#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original', 'layer#5#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original', 'layer#4#comp#block_output#unit#pos#nunit#1#0#rotate_layer#parametrizations#weight#original', 'layer#0#comp#block_output#unit#pos#nunit#1#1#learned_source#bias', 'layer#4#comp#block_output#unit#pos#nunit#1#0#learned_source#weight', 'layer#1#comp#block_output#unit#pos#nunit#1#3#learned_source#weight', 'layer#0#comp#block_output#unit#pos#nunit#1#2#learned_source#weight', 'encoder.visual_embedding.feat_embedding.1.bias', 'layer#5#comp#block_output#unit#pos#nunit#1#2#learned_source#bias', 'layer#2#comp#block_output#unit#pos#nunit#1#3#learned_source#bias', 'layer#5#comp#block_output#unit#pos#nunit#1#1#rotate_layer#parametrizations#weight#original', 'layer#0#comp#block_output#unit#pos#nunit#1#3#learned_source#bias', 'encoder.visual_embedding.feat_embedding.0.weight', 'layer#0#comp#block_output#unit#pos#nunit#1#2#rotate_layer#parametrizations#weight#original', 'layer#1#comp#block_output#unit#pos#nunit#1#0#learned_source#bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 50465. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model Launching at GPU 0\n", + "model.encoder.visual_embedding.feat_embedding.0.weight is trainable...\n", + "model.encoder.visual_embedding.feat_embedding.0.bias is trainable...\n", + "model.encoder.visual_embedding.feat_embedding.1.weight is trainable...\n", + "model.encoder.visual_embedding.feat_embedding.1.bias is trainable...\n", + "layer#0#comp#block_output#unit#pos#nunit#1#0#learned_source#bias is trainable...(4)\n", + "layer#1#comp#block_output#unit#pos#nunit#1#0#learned_source#bias is trainable...(4)\n", + "layer#2#comp#block_output#unit#pos#nunit#1#0#learned_source#bias is trainable...(4)\n", + "layer#3#comp#block_output#unit#pos#nunit#1#0#learned_source#bias is trainable...(4)\n", + "layer#4#comp#block_output#unit#pos#nunit#1#0#learned_source#bias is trainable...(4)\n", + "layer#5#comp#block_output#unit#pos#nunit#1#0#learned_source#bias is trainable...(4)\n", + "layer#0#comp#block_output#unit#pos#nunit#1#1#learned_source#bias is trainable...(4)\n", + "layer#1#comp#block_output#unit#pos#nunit#1#1#learned_source#bias is trainable...(4)\n", + "layer#2#comp#block_output#unit#pos#nunit#1#1#learned_source#bias is trainable...(4)\n", + "layer#3#comp#block_output#unit#pos#nunit#1#1#learned_source#bias is trainable...(4)\n", + "layer#4#comp#block_output#unit#pos#nunit#1#1#learned_source#bias is trainable...(4)\n", + "layer#5#comp#block_output#unit#pos#nunit#1#1#learned_source#bias is trainable...(4)\n", + "layer#0#comp#block_output#unit#pos#nunit#1#2#learned_source#bias is trainable...(64)\n", + "layer#1#comp#block_output#unit#pos#nunit#1#2#learned_source#bias is trainable...(64)\n", + "layer#2#comp#block_output#unit#pos#nunit#1#2#learned_source#bias is trainable...(64)\n", + "layer#3#comp#block_output#unit#pos#nunit#1#2#learned_source#bias is trainable...(64)\n", + "layer#4#comp#block_output#unit#pos#nunit#1#2#learned_source#bias is trainable...(64)\n", + "layer#5#comp#block_output#unit#pos#nunit#1#2#learned_source#bias is trainable...(64)\n", + "layer#0#comp#block_output#unit#pos#nunit#1#3#learned_source#bias is trainable...(64)\n", + "layer#1#comp#block_output#unit#pos#nunit#1#3#learned_source#bias is trainable...(64)\n", + "layer#2#comp#block_output#unit#pos#nunit#1#3#learned_source#bias is trainable...(64)\n", + "layer#3#comp#block_output#unit#pos#nunit#1#3#learned_source#bias is trainable...(64)\n", + "layer#4#comp#block_output#unit#pos#nunit#1#3#learned_source#bias is trainable...(64)\n", + "layer#5#comp#block_output#unit#pos#nunit#1#3#learned_source#bias is trainable...(64)\n", + "model.encoder.layers.0.self_attn.k_proj.bias is trainable...(768)\n", + "model.encoder.layers.0.self_attn.v_proj.bias is trainable...(768)\n", + "model.encoder.layers.0.self_attn.q_proj.bias is trainable...(768)\n", + "model.encoder.layers.0.self_attn.out_proj.bias is trainable...(768)\n", + "model.encoder.layers.0.self_attn_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.0.fc1.bias is trainable...(3072)\n", + "model.encoder.layers.0.fc2.bias is trainable...(768)\n", + "model.encoder.layers.0.final_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.1.self_attn.k_proj.bias is trainable...(768)\n", + "model.encoder.layers.1.self_attn.v_proj.bias is trainable...(768)\n", + "model.encoder.layers.1.self_attn.q_proj.bias is trainable...(768)\n", + "model.encoder.layers.1.self_attn.out_proj.bias is trainable...(768)\n", + "model.encoder.layers.1.self_attn_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.1.fc1.bias is trainable...(3072)\n", + "model.encoder.layers.1.fc2.bias is trainable...(768)\n", + "model.encoder.layers.1.final_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.2.self_attn.k_proj.bias is trainable...(768)\n", + "model.encoder.layers.2.self_attn.v_proj.bias is trainable...(768)\n", + "model.encoder.layers.2.self_attn.q_proj.bias is trainable...(768)\n", + "model.encoder.layers.2.self_attn.out_proj.bias is trainable...(768)\n", + "model.encoder.layers.2.self_attn_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.2.fc1.bias is trainable...(3072)\n", + "model.encoder.layers.2.fc2.bias is trainable...(768)\n", + "model.encoder.layers.2.final_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.3.self_attn.k_proj.bias is trainable...(768)\n", + "model.encoder.layers.3.self_attn.v_proj.bias is trainable...(768)\n", + "model.encoder.layers.3.self_attn.q_proj.bias is trainable...(768)\n", + "model.encoder.layers.3.self_attn.out_proj.bias is trainable...(768)\n", + "model.encoder.layers.3.self_attn_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.3.fc1.bias is trainable...(3072)\n", + "model.encoder.layers.3.fc2.bias is trainable...(768)\n", + "model.encoder.layers.3.final_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.4.self_attn.k_proj.bias is trainable...(768)\n", + "model.encoder.layers.4.self_attn.v_proj.bias is trainable...(768)\n", + "model.encoder.layers.4.self_attn.q_proj.bias is trainable...(768)\n", + "model.encoder.layers.4.self_attn.out_proj.bias is trainable...(768)\n", + "model.encoder.layers.4.self_attn_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.4.fc1.bias is trainable...(3072)\n", + "model.encoder.layers.4.fc2.bias is trainable...(768)\n", + "model.encoder.layers.4.final_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.5.self_attn.k_proj.bias is trainable...(768)\n", + "model.encoder.layers.5.self_attn.v_proj.bias is trainable...(768)\n", + "model.encoder.layers.5.self_attn.q_proj.bias is trainable...(768)\n", + "model.encoder.layers.5.self_attn.out_proj.bias is trainable...(768)\n", + "model.encoder.layers.5.self_attn_layer_norm.bias is trainable...(768)\n", + "model.encoder.layers.5.fc1.bias is trainable...(3072)\n", + "model.encoder.layers.5.fc2.bias is trainable...(768)\n", + "model.encoder.layers.5.final_layer_norm.bias is trainable...(768)\n", + "model.encoder.layernorm_embedding.bias is trainable...(768)\n", + "model.encoder.visual_embedding.feat_embedding.0.bias is trainable...(768)\n", + "model.encoder.visual_embedding.feat_embedding.1.bias is trainable...(768)\n", + "model.decoder.layers.0.self_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.self_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.self_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.self_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.self_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.0.encoder_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.encoder_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.encoder_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.encoder_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.0.encoder_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.0.fc1.bias is trainable...(3072)\n", + "model.decoder.layers.0.fc2.bias is trainable...(768)\n", + "model.decoder.layers.0.final_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.1.self_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.self_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.self_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.self_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.self_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.1.encoder_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.encoder_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.encoder_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.encoder_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.1.encoder_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.1.fc1.bias is trainable...(3072)\n", + "model.decoder.layers.1.fc2.bias is trainable...(768)\n", + "model.decoder.layers.1.final_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.2.self_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.self_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.self_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.self_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.self_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.2.encoder_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.encoder_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.encoder_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.encoder_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.2.encoder_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.2.fc1.bias is trainable...(3072)\n", + "model.decoder.layers.2.fc2.bias is trainable...(768)\n", + "model.decoder.layers.2.final_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.3.self_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.self_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.self_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.self_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.self_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.3.encoder_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.encoder_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.encoder_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.encoder_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.3.encoder_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.3.fc1.bias is trainable...(3072)\n", + "model.decoder.layers.3.fc2.bias is trainable...(768)\n", + "model.decoder.layers.3.final_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.4.self_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.self_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.self_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.self_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.self_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.4.encoder_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.encoder_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.encoder_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.encoder_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.4.encoder_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.4.fc1.bias is trainable...(3072)\n", + "model.decoder.layers.4.fc2.bias is trainable...(768)\n", + "model.decoder.layers.4.final_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.5.self_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.self_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.self_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.self_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.self_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.5.encoder_attn.k_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.encoder_attn.v_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.encoder_attn.q_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.encoder_attn.out_proj.bias is trainable...(768)\n", + "model.decoder.layers.5.encoder_attn_layer_norm.bias is trainable...(768)\n", + "model.decoder.layers.5.fc1.bias is trainable...(3072)\n", + "model.decoder.layers.5.fc2.bias is trainable...(768)\n", + "model.decoder.layers.5.final_layer_norm.bias is trainable...(768)\n", + "model.decoder.layernorm_embedding.bias is trainable...(768)\n", + "model.encoder.layers.0.self_attn_layer_norm is trainable...\n", + "model.encoder.layers.0.final_layer_norm is trainable...\n", + "model.encoder.layers.1.self_attn_layer_norm is trainable...\n", + "model.encoder.layers.1.final_layer_norm is trainable...\n", + "model.encoder.layers.2.self_attn_layer_norm is trainable...\n", + "model.encoder.layers.2.final_layer_norm is trainable...\n", + "model.encoder.layers.3.self_attn_layer_norm is trainable...\n", + "model.encoder.layers.3.final_layer_norm is trainable...\n", + "model.encoder.layers.4.self_attn_layer_norm is trainable...\n", + "model.encoder.layers.4.final_layer_norm is trainable...\n", + "model.encoder.layers.5.self_attn_layer_norm is trainable...\n", + "model.encoder.layers.5.final_layer_norm is trainable...\n", + "model.encoder.layernorm_embedding is trainable...\n", + "model.encoder.visual_embedding.feat_embedding.1 is trainable...\n", + "model.decoder.layers.0.self_attn_layer_norm is trainable...\n", + "model.decoder.layers.0.encoder_attn_layer_norm is trainable...\n", + "model.decoder.layers.0.final_layer_norm is trainable...\n", + "model.decoder.layers.1.self_attn_layer_norm is trainable...\n", + "model.decoder.layers.1.encoder_attn_layer_norm is trainable...\n", + "model.decoder.layers.1.final_layer_norm is trainable...\n", + "model.decoder.layers.2.self_attn_layer_norm is trainable...\n", + "model.decoder.layers.2.encoder_attn_layer_norm is trainable...\n", + "model.decoder.layers.2.final_layer_norm is trainable...\n", + "model.decoder.layers.3.self_attn_layer_norm is trainable...\n", + "model.decoder.layers.3.encoder_attn_layer_norm is trainable...\n", + "model.decoder.layers.3.final_layer_norm is trainable...\n", + "model.decoder.layers.4.self_attn_layer_norm is trainable...\n", + "model.decoder.layers.4.encoder_attn_layer_norm is trainable...\n", + "model.decoder.layers.4.final_layer_norm is trainable...\n", + "model.decoder.layers.5.self_attn_layer_norm is trainable...\n", + "model.decoder.layers.5.encoder_attn_layer_norm is trainable...\n", + "model.decoder.layers.5.final_layer_norm is trainable...\n", + "model.decoder.layernorm_embedding is trainable...\n", + "VLBartVQA(\n", + " (model): VLBartModel(\n", + " (shared): Embedding(50465, 768)\n", + " (encoder): JointEncoder(\n", + " (embed_tokens): Embedding(50465, 768)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0-5): 6 x BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (activation_fn): GELUActivation()\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (visual_embedding): VisualEmbedding(\n", + " (feat_embedding): Sequential(\n", + " (0): Linear(in_features=2048, out_features=768, bias=True)\n", + " (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (downsample): Downsample(\n", + " (pool): AdaptiveMaxPool2d(output_size=(6, 6))\n", + " )\n", + " )\n", + " (decoder): BartDecoder(\n", + " (embed_tokens): Embedding(50465, 768)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0-5): 6 x BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (activation_fn): GELUActivation()\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=50465, bias=False)\n", + " (intervenable): ReftModel(\n", + " (model): VLBartModel(\n", + " (shared): Embedding(50465, 768)\n", + " (encoder): JointEncoder(\n", + " (embed_tokens): Embedding(50465, 768)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0-5): 6 x BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (activation_fn): GELUActivation()\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (visual_embedding): VisualEmbedding(\n", + " (feat_embedding): Sequential(\n", + " (0): Linear(in_features=2048, out_features=768, bias=True)\n", + " (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (downsample): Downsample(\n", + " (pool): AdaptiveMaxPool2d(output_size=(6, 6))\n", + " )\n", + " )\n", + " (decoder): BartDecoder(\n", + " (embed_tokens): Embedding(50465, 768)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0-5): 6 x BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (activation_fn): GELUActivation()\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " (bce_loss): BCEWithLogitsLoss()\n", + ")\n", + "Trainable param percentage: 2.09% (2979888/142403376)\n", + "Building Optimizer\n", + "Batch per epoch: 100\n", + "Total Iters: 300\n", + "Warmup ratio: 0.1\n", + "Warm up Iters: 30\n", + "It took 0.8s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/nlp/scr/peterwz/miniconda3/envs/peterwz-dora/lib/python3.8/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " warnings.warn(\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mpeterzw494\u001b[0m (\u001b[33mpeterwz\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.4" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /sailhome/peterwz/workspace/pyreft/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/wandb/run-20240708_171816-ld0th829" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run misunderstood-planet-109 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/peterwz/Test" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/peterwz/Test/runs/ld0th829" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Symlinked 0 file into the W&B run directory, call wandb.save again to sync new files.\n", + " 0%| | 0/100 [00:001` or unset `early_stopping`.\n", + " warnings.warn(\n", + "/nlp/scr/peterwz/miniconda3/envs/peterwz-dora/lib/python3.8/site-packages/transformers/generation/utils.py:1260: UserWarning: Using the model-agnostic default `max_length` (=20) to control thegeneration length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n", + " warnings.warn(\n", + "VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.43it/s]\n", + "100%|███████████████████████████████████████████| 4/4 [00:00<00:00, 5751.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "VQA\n", + "Epoch 0: Valid Raw 47.50 Topk 47.50\n", + "Epoch 0: Best Raw 47.50\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/100 [00:00 int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor( + [tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) + for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros( + (max_size - local_size,), dtype=torch.uint8, device=tensor.device + ) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group=group) == 1: + return [data] + rank = dist.get_rank(group=group) + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving Tensor from all ranks + if rank == dst: + max_size = max(size_list) + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list + ] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2 ** 31) + all_ints = all_gather(ints) + return all_ints[0] + + +# def reduce_dict(input_dict, average=True): +# """ +# Reduce the values in the dictionary from all processes so that process with rank +# 0 has the reduced results. +# Args: +# input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. +# average (bool): whether to do average or sum +# Returns: +# a dict with the same keys as input_dict, after reduction. +# """ +# world_size = get_world_size() +# if world_size < 2: +# return input_dict +# with torch.no_grad(): +# names = [] +# values = [] +# # sort the keys so that they are consistent across processes +# for k in sorted(input_dict.keys()): +# names.append(k) +# values.append(input_dict[k]) +# values = torch.stack(values, dim=0) +# dist.reduce(values, dst=0) +# if dist.get_rank() == 0 and average: +# # only main process gets accumulated, so only divide by +# # world_size in this case +# values /= world_size +# reduced_dict = {k: v for k, v in zip(names, values)} +# return reduced_dict + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + Args: + input_dict (dict): inputs to be reduced. (values not necessarily tensors). + average (bool): whether to do average or sum + Returns: + a dict with the same keys as input_dict, after reduction. + """ + + world_size = get_world_size() + if world_size < 2: + return input_dict + + with torch.no_grad(): + + # Convert to CUDA Tensor for dist.reduce() + input_dict_cuda_vals = {} + for k, v in input_dict.items(): + if type(v) == torch.Tensor: + input_dict_cuda_vals[k] = v.to('cuda') + else: + input_dict_cuda_vals[k] = torch.tensor(v, device='cuda') + + names = [] + values = [] + for k, v in sorted(input_dict_cuda_vals.items()): + names.append(k) + values.append(v) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) # reduce to gpu 0 + + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/modeling_bart.py b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/modeling_bart.py new file mode 100644 index 0000000..6b56a09 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/modeling_bart.py @@ -0,0 +1,542 @@ + +import math +import random +from dataclasses import dataclass + +import transformers + +from transformers.models.bart.modeling_bart import ( + BartConfig, + shift_tokens_right, + _expand_mask +) + +from my_transformers.modeling_bart import BartModel, BartForConditionalGeneration, BartDecoder, BartEncoder + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +import copy + +from transformers.modeling_outputs import ModelOutput, BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +class VisualEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + feat_dim = config.feat_dim + pos_dim = config.pos_dim + # n_objs = config.n_objs + n_images = config.n_images + + # Object feature encoding + feat_embedding = [nn.Linear(feat_dim, config.d_model)] + + # use custom layer norm + if self.config.use_vis_layer_norm and self.config.individual_vis_layer_norm: + feat_embedding.append(nn.LayerNorm(config.d_model)) + + self.feat_embedding = nn.Sequential(*feat_embedding) + + # use one layer norm + if self.config.use_vis_layer_norm and not self.config.individual_vis_layer_norm: + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward(self, feats, pos): + """ + Args + feats: [B, N, feat_dim] + pos: [B, N, 4] + (x1, x2, y1, y2) + Return + relative_vis_pos_embedding: [B, N, N, n_heads] + """ + + B, N, _ = feats.size() + assert pos.size() == (B, N, 4) + + feat_embedding = self.feat_embedding(feats) + + device = feats.device + dtype = feats.dtype + vis_embedding = feat_embedding + + if not self.config.individual_vis_layer_norm: + if self.config.use_vis_layer_norm: + vis_embedding = self.layer_norm(vis_embedding) + + return vis_embedding + + +class Downsample(nn.Module): + def __init__(self, output_size): + super().__init__() + """ + output size: list of 1-D size, such as (6, 6), (16, 16) + """ + self.output_size = output_size + self.pool = nn.AdaptiveMaxPool2d(output_size) + + def downsample_inputs(self, inputs): + B, L, dim = inputs.shape + + inputs = inputs.permute(0, 2, 1) # (2B, dim, L/2) + + # restriction: L**0.5 must to be integer + sqrt_L = int(L ** 0.5) + + inputs = inputs.reshape(B, dim, sqrt_L, sqrt_L) + + inputs = self.pool(inputs) # (B, dim, self.output_size[0], self.output_size[1]) + inputs = inputs.reshape(B, dim, -1) + + inputs = inputs.permute(0, 2, 1) # (2B, self.output_size[0]**2, dim) + + return inputs + + def forward(self, inputs_tuple): + # inputs (B, L, dim) + inputs, boxes = inputs_tuple + + inputs = self.downsample_inputs(inputs) + boxes = boxes[:, :inputs.shape[1]] # Get the first few data because the element are all zeros + + outputs_tuple = (inputs, boxes) + + return outputs_tuple + + +class JointEncoder(BartEncoder): + """ + BartEncoder + visual embedding + """ + def __init__(self, config, embed_tokens=None, task_embed=None): + super().__init__(config, embed_tokens, task_embed) + + self.config = config + + self.visual_embedding = VisualEmbedding(config) + + self.downsample = None + sqrt_size = int(config.n_boxes ** 0.5) + output_size = (sqrt_size, sqrt_size) + self.downsample = Downsample(output_size) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + + vis_inputs=None, + vis_attention_mask=None, + + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + past_key_values=None, + return_dict=None, + task=None + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + + inputs_embeds = inputs_embeds + embed_pos + + B, L = inputs_embeds.size()[:-1] + + if vis_inputs is not None: + + if self.downsample is not None: + vis_inputs = self.downsample(vis_inputs) + + vis_feats = vis_inputs[0] + boxes = vis_inputs[1] + + vis_embeds = self.visual_embedding(vis_feats, boxes) + + V_L = vis_embeds.size(1) + # print("Input IDs", input_ids.shape, "Embed len", inputs_embeds.shape, "Vis embed len", vis_embeds.shape) + if self.config.share_vis_lang_layer_norm: + inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1) + + inputs_embeds = self.layernorm_embedding(inputs_embeds) + else: + inputs_embeds = self.layernorm_embedding(inputs_embeds) + inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1) + + if vis_attention_mask is None: + vis_attention_mask = torch.ones(B, V_L, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + else: + inputs_embeds = self.layernorm_embedding(inputs_embeds) + + if attention_mask is None: + attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + hidden_states = F.dropout(inputs_embeds, p=self.dropout, training=self.training) + + if vis_attention_mask is not None: + attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + task_embedding = None + if task is not None and self.task_embed is not None: + task_embedding = self.task_embed(task) + + # print('ext_attention_mask, ', attention_mask.size()) + # print('attention_mask') + # print(attention_mask.size()) + # print(attention_mask) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + + # for prefix + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer(hidden_states, attention_mask, past_key_value, None, task=task, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + +class VLBartModel(BartModel): + def __init__(self, config: BartConfig): + super(BartModel, self).__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + if config.use_hyperformer: + self.shared_task_embed = TaskEmbeddingController(config.adapter_config) + else: + self.shared_task_embed = None + + #----- Modified-----# + # self.encoder = BartEncoder(config, self.shared) + + self.encoder = JointEncoder(config, self.shared, self.shared_task_embed) + #-------------------# + self.decoder = BartDecoder(config, self.shared, self.shared_task_embed) + + self.config = config + + self.init_weights() + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def forward( + self, + input_ids=None, + attention_mask=None, + + vis_inputs=None, + vis_attention_mask=None, + + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task=None, + **kwargs, + ): + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + + vis_inputs=vis_inputs, + vis_attention_mask=vis_attention_mask, + + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task=task, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=False + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + if attention_mask is None: + attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=torch.float, device=input_ids.device) + + if vis_attention_mask is None: + B, L = attention_mask.size() + V_L = encoder_outputs[0].size(1) - L + vis_attention_mask = attention_mask.new_ones(B, V_L) + + encoder_attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + # encoder_attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task=task, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class VLBart(BartForConditionalGeneration): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + r"lm_head\.weight", + ] + + def __init__(self, config: BartConfig): + super(BartForConditionalGeneration, self).__init__(config) + self.model = VLBartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + + vis_inputs=None, + vis_attention_mask=None, + + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task=None, + + reduce_loss=False, + + **kwargs, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + + vis_inputs=vis_inputs, + vis_attention_mask=vis_attention_mask, + + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task=task, + ) + + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + # loss_fct = CrossEntropyLoss() + if reduce_loss: + loss_fct = CrossEntropyLoss(ignore_index=-100) + else: + loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none') + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + # if masked_lm_loss is not None and len(masked_lm_loss) > 1: + # masked_lm_loss = masked_lm_loss[0] + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def vis_forward(self, batch, device): + return batch + + def prepare_inputs_for_generation( + self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + output = { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + if "vis_attention_mask" in kwargs: + output["vis_attention_mask"] = kwargs['vis_attention_mask'] + + if "task" in kwargs: + output["task"] = kwargs["task"] + + return output + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: torch.LongTensor = None, + encoder_outputs: ModelOutput = None, + **model_kwargs + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, + expand_size).view(-1).to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select( + 0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select( + 0, expanded_return_idx) + + if model_kwargs.get("vis_attention_mask", None) is not None: + model_kwargs['vis_attention_mask'] = model_kwargs['vis_attention_mask'].index_select( + 0, expanded_return_idx) + + if is_encoder_decoder: + assert encoder_outputs is not None + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx + ) + model_kwargs["encoder_outputs"] = encoder_outputs + + return input_ids, model_kwargs diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/my_transformers/__init__.py b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/my_transformers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/my_transformers/modeling_bart.py b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/my_transformers/modeling_bart.py new file mode 100644 index 0000000..4c47779 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/my_transformers/modeling_bart.py @@ -0,0 +1,1509 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BART model. """ + + +import math +import random +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +# from transformers.file_utils import ( +# add_code_sample_docstrings, +# add_end_docstrings, +# add_start_docstrings, +# add_start_docstrings_to_model_forward, +# replace_return_docstrings, +# ) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.models.bart.configuration_bart import BartConfig + +import os +import sys +sys.path.insert(1, os.path.join(sys.path[0], '..')) + + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BartConfig" +_TOKENIZER_FOR_DOC = "BartTokenizer" + + +BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/bart-large", + # See all BART models at https://huggingface.co/models?filter=bart +] + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +class BartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): + assert padding_idx is not None, "`padding_idx` should not be None, but of type int" + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models dont have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx) + + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions + self.offset) + + +class BartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # import pdb + # pdb.set_trace() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + # print(key_states.shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + assert attn_weights.size() == ( + bsz * self.num_heads, + tgt_len, + src_len, + ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + + if attention_mask is not None: + assert attention_mask.size() == ( + bsz, + 1, + tgt_len, + src_len, + ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + assert attn_output.size() == ( + bsz * self.num_heads, + tgt_len, + self.head_dim, + ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + + attn_output = ( + attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + + +class BartEncoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, past_key_value: tuple = None, block_adapters=None, task=None, output_attentions: bool = False): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, past_key_value=past_key_value, output_attentions=output_attentions + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class BartDecoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.encoder_attn = BartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + block_adapters=None, + task=None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = None + if past_key_value is not None and len(past_key_value) == 4: # len(past_key_value) is for prefix + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class BartPretrainedModel(PreTrainedModel): + config_class = BartConfig + base_model_prefix = "model" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class PretrainedBartModel(BartPretrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.", + FutureWarning, + ) + + +# BART_START_DOCSTRING = r""" +# This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic +# methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, +# pruning heads etc.) +# This model is also a PyTorch `torch.nn.Module `__ +# subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to +# general usage and behavior. +# Parameters: +# config (:class:`~transformers.BartConfig`): +# Model configuration class with all the parameters of the model. Initializing with a config file does not +# load the weights associated with the model, only the configuration. Check out the +# :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +# """ + +# BART_GENERATION_EXAMPLE = r""" +# Summarization example:: +# >>> from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig +# >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') +# >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') +# >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." +# >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') +# >>> # Generate Summary +# >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) +# >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) +# Mask filling example:: +# >>> from transformers import BartTokenizer, BartForConditionalGeneration +# >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') +# >>> TXT = "My friends are but they eat too many carbs." +# >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') +# >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] +# >>> logits = model(input_ids).logits +# >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() +# >>> probs = logits[0, masked_index].softmax(dim=0) +# >>> values, predictions = probs.topk(5) +# >>> tokenizer.decode(predictions).split() +# """ + +# BART_INPUTS_DOCSTRING = r""" +# Args: +# input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): +# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide +# it. +# Indices can be obtained using :class:`~transformers.BartTokenizer`. See +# :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for +# details. +# `What are input IDs? <../glossary.html#input-ids>`__ +# attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): +# Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: +# - 1 for tokens that are **not masked**, +# - 0 for tokens that are **masked**. +# `What are attention masks? <../glossary.html#attention-mask>`__ +# decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): +# Indices of decoder input sequence tokens in the vocabulary. +# Indices can be obtained using :class:`~transformers.BartTokenizer`. See +# :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for +# details. +# `What are input IDs? <../glossary.html#input-ids>`__ +# Bart uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If +# :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see +# :obj:`past_key_values`). +# For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no +# :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to +# the right for denoising pre-training following the paper. +# decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): +# Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will +# also be used by default. +# If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and +# modify to your needs. See diagram 1 in `the paper `__ for more +# information on the default strategy. +# encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): +# Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: +# :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, +# `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the +# cross-attention of the decoder. +# past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): +# Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. +# If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` +# (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` +# instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. +# inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): +# Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. +# This is useful if you want more control over how to convert :obj:`input_ids` indices into associated +# vectors than the model's internal embedding lookup matrix. +# decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): +# Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded +# representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` +# have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert +# :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. +# If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` +# takes the value of :obj:`inputs_embeds`. +# use_cache (:obj:`bool`, `optional`): +# If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up +# decoding (see :obj:`past_key_values`). +# output_attentions (:obj:`bool`, `optional`): +# Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned +# tensors for more detail. +# output_hidden_states (:obj:`bool`, `optional`): +# Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for +# more detail. +# return_dict (:obj:`bool`, `optional`): +# Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +# """ + + +class BartEncoder(BartPretrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + :class:`BartEncoderLayer`. + Args: + config: BartConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None, task_embed=None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.task_embed = task_embed + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + self.padding_idx, + ) + self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + past_key_values=None, + return_dict=None, + task=None + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + Indices can be obtained using :class:`~transformers.BartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + task_embedding = None + if task is not None and self.task_embed is not None: + task_embedding = self.task_embed(task) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + past_key_value, + None, # block_adapters + task, + ) + else: + layer_outputs = encoder_layer(hidden_states, attention_mask, past_key_value, None, task, output_attentions=output_attentions) # block_adapters + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class BartDecoder(BartPretrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`BartDecoderLayer` + Args: + config: BartConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None, task_embed=None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.task_embed = task_embed + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + ) + self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + Indices can be obtained using :class:`~transformers.BartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None and combined_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input_shape, past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + task_embedding = None + if task is not None and self.task_embed is not None: + task_embedding = self.task_embed(task) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + next_decoder_cache = () if use_cache else None + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + block_adapters = None + if getattr(self.config, "gradient_checkpointing", False): + if use_cache: + raise ValueError( + "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." + ) + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + combined_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + block_adapters, + task, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + block_adapters=block_adapters, + task=task, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +# @add_start_docstrings( +# "The bare BART Model outputting raw hidden-states without any specific head on top.", +# BART_START_DOCSTRING, +# ) +class BartModel(BartPretrainedModel): + def __init__(self, config: BartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.shared_task_embed = None + + self.encoder = BartEncoder(config, self.shared, self.shared_task_embed) + self.decoder = BartDecoder(config, self.shared, self.shared_task_embed) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + # @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + # checkpoint="facebook/bart-large", + # output_type=Seq2SeqModelOutput, + # # config_class=_CONFIG_FOR_DOC, + # ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +# @add_start_docstrings( +# "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING +# ) +class BartForConditionalGeneration(BartPretrainedModel): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + r"lm_head\.weight", + ] + + def __init__(self, config: BartConfig): + super().__init__(config) + self.model = BartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens) + self._resize_final_logits_bias(new_num_tokens) + + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + # @replace_return_docstrings( + # output_type=Seq2SeqLMOutput, + # # config_class=_CONFIG_FOR_DOC, + # ) + # @add_end_docstrings(BART_GENERATION_EXAMPLE) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + output = { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + if "task" in kwargs: + output["task"] = kwargs["task"] + + return output + + def adjust_logits_during_generation(self, logits, cur_len, max_length): + if cur_len == 1 and self.config.force_bos_token_to_be_generated: + self._force_token_id_to_be_generated(logits, self.config.bos_token_id) + elif cur_len == max_length - 1 and self.config.eos_token_id is not None: + self._force_token_id_to_be_generated(logits, self.config.eos_token_id) + return logits + + @staticmethod + def _force_token_id_to_be_generated(scores, token_id) -> None: + """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" + scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +# @add_start_docstrings( +# """ +# Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE +# tasks. +# """, +# BART_START_DOCSTRING, +# ) +class BartForSequenceClassification(BartPretrainedModel): + def __init__(self, config: BartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = BartModel(config) + self.classification_head = BartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + self.model._init_weights(self.classification_head.dense) + self.model._init_weights(self.classification_head.out_proj) + + # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + # @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + # checkpoint="facebook/bart-large", + # output_type=Seq2SeqSequenceClassifierOutput, + # # config_class=_CONFIG_FOR_DOC, + # ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id) + + if len(torch.unique(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +# @add_start_docstrings( +# """ +# BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear +# layer on top of the hidden-states output to compute `span start logits` and `span end logits`). +# """, +# BART_START_DOCSTRING, +# ) +class BartForQuestionAnswering(BartPretrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = BartModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.model._init_weights(self.qa_outputs) + + # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + # @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + # checkpoint="facebook/bart-large", + # output_type=Seq2SeqQuestionAnsweringModelOutput, + # # config_class=_CONFIG_FOR_DOC, + # ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + start_positions=None, + end_positions=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +if __name__ == "__main__": + import transformers + + model = transformers.BartForConditionalGeneration.from_pretrained("facebook/bart-base") + tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-base") + + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + print(count_parameters(model)) + + inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt") + generation_output = model.generate(**inputs) + + print(generation_output) + + print(tokenizer.batch_decode(generation_output, skip_special_tokens=True)) \ No newline at end of file diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/param.py b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/param.py new file mode 100644 index 0000000..3b4c4ee --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/param.py @@ -0,0 +1,327 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import argparse +import random + +import numpy as np +import torch + +import pprint +import yaml + + +feature_types = { + "RN50", "RN101", "RN50x4", "ViT", "butd", "raw_RN50", "raw_RN101", "raw_RN50x4", "raw_ViT" +} + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def is_interactive(): + import __main__ as main + return not hasattr(main, '__file__') + + +def get_optimizer(optim, verbose=False): + # Bind the optimizer + if optim == 'rms': + if verbose: + print("Optimizer: Using RMSProp") + optimizer = torch.optim.RMSprop + elif optim == 'adam': + if verbose: + print("Optimizer: Using Adam") + optimizer = torch.optim.Adam + elif optim == 'adamw': + if verbose: + print("Optimizer: Using AdamW") + # optimizer = torch.optim.AdamW + optimizer = 'adamw' + elif optim == 'adamax': + if verbose: + print("Optimizer: Using Adamax") + optimizer = torch.optim.Adamax + elif optim == 'sgd': + if verbose: + print("Optimizer: SGD") + optimizer = torch.optim.SGD + else: + assert False, "Please add your optimizer %s in the list." % optim + + return optimizer + + +def parse_args(parse=True, **optional_kwargs): + parser = argparse.ArgumentParser() + + parser.add_argument('--seed', type=int, default=9595, help='random seed') + + # Data Splits + parser.add_argument("--train", default='train') + parser.add_argument("--valid", default='valid') + parser.add_argument("--test", default=None) + parser.add_argument('--test_only', action='store_true') + + parser.add_argument('--submit', action='store_true') + + # Quick experiments + parser.add_argument('--train_topk', type=float, default=-1) + parser.add_argument('--valid_topk', type=float, default=-1) + + # Checkpoint + parser.add_argument('--output', type=str, default='snap/test') + parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).') + parser.add_argument('--load_lxmert_qa', type=str, default=None) + parser.add_argument('--from_scratch', action='store_true') + parser.add_argument('--project_name', type=str, default="") + parser.add_argument('--run_name', type=str, default="") + + # CPU/GPU + parser.add_argument("--multiGPU", action='store_const', default=False, const=True) + parser.add_argument('--fp16', action='store_true') + parser.add_argument("--distributed", action='store_true') + parser.add_argument("--num_workers", default=0, type=int) + parser.add_argument('--local-rank', type=int, default=-1) + + # Model Config + parser.add_argument('--backbone', type=str, default='t5-base') + parser.add_argument('--tokenizer', type=str, default=None) + + parser.add_argument('--feat_dim', type=float, default=2048) + parser.add_argument('--pos_dim', type=float, default=4) + parser.add_argument('--image_size', type=str, default="(448,448)") + + parser.add_argument('--use_vision', default=True, type=str2bool) + parser.add_argument('--use_vis_order_embedding', default=True, type=str2bool) + parser.add_argument('--use_vis_layer_norm', default=True, type=str2bool) + parser.add_argument('--individual_vis_layer_norm', default=True, type=str2bool) + parser.add_argument('--share_vis_lang_layer_norm', action='store_true') + + parser.add_argument('--n_boxes', type=int, default=36) + parser.add_argument('--max_n_boxes', type=int, default=36) + parser.add_argument('--max_text_length', type=int, default=20) + + parser.add_argument('--additional_visual_embedding_layers', type=int, default=0) + + parser.add_argument('--downsample', action="store_true") + parser.add_argument('--oneddownsample', action="store_true") + parser.add_argument('--expand_vis_embedding', action="store_true") + parser.add_argument('--n_image_tokens', type=int, default=4) + + # Training + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--valid_batch_size', type=int, default=None) + parser.add_argument('--optim', default='adamw') + parser.add_argument('--warmup_ratio', type=float, default=0.05) + parser.add_argument('--weight_decay', type=float, default=0.01) + parser.add_argument('--clip_grad_norm', type=float, default=-1.0) + parser.add_argument('--gradient_accumulation_steps', type=int, default=1) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--vis_lr', type=float, default=1e-4) + parser.add_argument('--vis_weight_decay', type=float, default=0.01) + parser.add_argument('--adam_eps', type=float, default=1e-6) + parser.add_argument('--adam_beta1', type=float, default=0.9) + parser.add_argument('--adam_beta2', type=float, default=0.999) + parser.add_argument('--epochs', type=int, default=12) + parser.add_argument('--dropout', type=float, default=0.1) + + parser.add_argument("--losses", default='lm,obj,attr,feat', type=str) + + parser.add_argument('--log_train_accuracy', action='store_true') + + parser.add_argument('--n_ground', type=int, default=1) + parser.add_argument("--wordMaskRate", dest='word_mask_rate', default=0.15, type=float) + parser.add_argument("--objMaskRate", dest='obj_mask_rate',default=0.15, type=float) + + parser.add_argument('--encoder_prompt_len', type=int, default=0) + parser.add_argument('--decoder_prompt_len', type=int, default=0) + parser.add_argument('--use_single_prompt', action="store_true") + parser.add_argument('--unfreeze_language_model', action="store_true") + parser.add_argument('--unfreeze_layer_norms', action="store_true") + parser.add_argument('--use_attn_prefix', action="store_true") + parser.add_argument('--mid_dim', type=int, default=768) + parser.add_argument('--use_adapter', action="store_true") + parser.add_argument('--use_hyperformer', action="store_true") + parser.add_argument('--use_compacter', action="store_true") + parser.add_argument('--use_lradapter', action="store_true") + parser.add_argument('--use_single_adapter', action="store_true") + parser.add_argument('--efficient_unique_hyper_net', action="store_true") + parser.add_argument('--unique_hyper_net', action="store_true") + parser.add_argument('--unfreeze_vis_encoder', action="store_true") + parser.add_argument('--unfreeze_vis_last_layer', action="store_true") + parser.add_argument('--unfreeze_batch_norms', action="store_true") + parser.add_argument('--projected_task_embedding_dim', default=-1, type=int, + help="projected_task_embedding_dim for hyperformer, -1 means using the default value in the config" + ) + + parser.add_argument('--share_down_sampler', action="store_true") + parser.add_argument('--share_up_sampler', action="store_true") + # Compacter + parser.add_argument('--hypercomplex_division', type=int, default=4) + parser.add_argument('--phm_rank', type=int, default=1) + parser.add_argument('--shared_phm_rule', type=str2bool, default=True) + parser.add_argument('--factorized_phm', type=str2bool, default=True) + parser.add_argument('--add_adapter_cross_attn', type=str2bool, default=True) + parser.add_argument('--low_rank_rank', type=int, default=1) + parser.add_argument('--phm_init_range', type=float, default=0.01) + parser.add_argument('--shared_phm_rule_over_tasks', action="store_true") + + parser.add_argument('--vis_pooling_output', action="store_true") + parser.add_argument('--use_vis_adapter', action="store_true") + parser.add_argument('--use_separate_optimizer_for_visual', action="store_true") + parser.add_argument( + '--use_adam_for_visual', action="store_true", help="Use SGD if false" + ) + parser.add_argument('--freeze_ln_statistics', action="store_true") + parser.add_argument('--freeze_bn_statistics', action="store_true") + parser.add_argument('--add_layer_norm_before_adapter', action="store_true") + parser.add_argument('--add_layer_norm_after_adapter', action="store_true") + + parser.add_argument('--vis_adapter_type', type=str, default="middle-bottleneck") + parser.add_argument('--vis_reduction_factor', type=int, default=2) + parser.add_argument('--reduction_factor', type=int, default=16) + parser.add_argument('--use_data_augmentation', action="store_true") + parser.add_argument('--deepspeed', type=str, default=None) + parser.add_argument('--sparse_sample', action="store_true") + parser.add_argument('--remove_bn_vis_adapter', action="store_true") + parser.add_argument('--unfreeze_lm_head', action="store_true") + parser.add_argument('--use_lm_head_adapter', action="store_true") + + # dora + parser.add_argument('--use_dora', action="store_true") + parser.add_argument('--lora_settings', action="store_true") + parser.add_argument('--dora_simple', action="store_true") + + # unfreeze_layer_norm_encoder or decoder + parser.add_argument('--unfreeze_encoder_layer_norms', action="store_true") + parser.add_argument('--unfreeze_decoder_layer_norms', action="store_true") + + + # use bias tuning (bitfit) + parser.add_argument('--unfreeze_bias', action="store_true") + + # use lora + parser.add_argument('--use_lora', action="store_true") + parser.add_argument('--lora_dim', type=int, default=4) + parser.add_argument('--lora_alpha', type=float, default=32) + parser.add_argument('--use_single_lora', action="store_true") + + # Inference + parser.add_argument('--num_beams', type=int, default=1) + parser.add_argument('--gen_max_length', type=int, default=20) + + # Data + parser.add_argument('--caption_only', action='store_true') + parser.add_argument('--coco_only', action='store_true') + parser.add_argument('--caption_cocoonly', default=True, type=str2bool) + + parser.add_argument('--do_lower_case', action='store_true') + parser.add_argument('--oscar_tags', action='store_true') + + parser.add_argument('--prefix', type=str, default=None) + + parser.add_argument('--prompt', type=str, default="vqa: ") + parser.add_argument('--post_prompt', type=str, default="") + + parser.add_argument('--feature_type', type=str, default="butd", choices=feature_types) + + # VQA + parser.add_argument("--raw_label", action='store_true') + parser.add_argument("--answer_normalize", action='store_true') + parser.add_argument("--classifier", action='store_true') + parser.add_argument("--test_answerable", action='store_true') + + # Classification + parser.add_argument('--cls_task', type=str, default='tinyimagenet') + + # Multitask + parser.add_argument("--multitask_sampling", type=str, default='roundrobin') + parser.add_argument("--tasks", type=str, default='') + parser.add_argument("--use_tasks_prompts", action="store_true") + parser.add_argument("--testing", action="store_true") + parser.add_argument("--track_z", action="store_true") + parser.add_argument("--lambda_z", type=float, default=0.001) + + # Etc. + parser.add_argument('--comment', type=str, default='') + parser.add_argument("--dry", action='store_true') + + # ReFT + parser.add_argument('--layers', type=str, default="0;1;2;3;4;5") + parser.add_argument('--reft_rank', type=int, default=1) + parser.add_argument('--reft_image_rank', type=int, default=-1) + parser.add_argument('--reft_dropout', type=float, default=0.0) + parser.add_argument('--reft_image_dropout', type=float, default=0.0) + parser.add_argument('--positions', type=str, default="f1+l1") + parser.add_argument('--image_positions', type=str, default="f1+l1") + parser.add_argument('--share_weights', action='store_true') + + + # Parse the arguments. + if parse: + args = parser.parse_args() + # For interative engironmnet (ex. jupyter) + else: + args = parser.parse_known_args()[0] + + # Namespace => Dictionary + kwargs = vars(args) + kwargs.update(optional_kwargs) + + args = Config(**kwargs) + + # Bind optimizer class. + verbose = False + args.optimizer = get_optimizer(args.optim, verbose=verbose) + + # Set seeds + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + return args + + +class Config(object): + def __init__(self, **kwargs): + """Configuration Class: set kwargs as class attributes with setattr""" + for k, v in kwargs.items(): + setattr(self, k, v) + + @property + def config_str(self): + return pprint.pformat(self.__dict__) + + def __repr__(self): + """Pretty-print configurations in alphabetical order""" + config_str = 'Configurations\n' + config_str += self.config_str + return config_str + + def save(self, path): + with open(path, 'w') as f: + yaml.dump(self.__dict__, f, default_flow_style=False) + + @classmethod + def load(cls, path): + with open(path, 'r') as f: + kwargs = yaml.load(f) + + return Config(**kwargs) + + +if __name__ == '__main__': + args = parse_args(True) diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/trainer_base.py b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/trainer_base.py new file mode 100644 index 0000000..bdf2eff --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/trainer_base.py @@ -0,0 +1,364 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import torch.backends.cudnn as cudnn +import torch.multiprocessing as mp +import os +import re +import collections +from pathlib import Path +from packaging import version + +import numpy as np +from tqdm import tqdm +import torch +import torch.nn as nn +import logging +import shutil +from pprint import pprint + +from utils import load_state_dict, LossMeter, set_global_logging_level +import wandb +from pprint import pformat +import modeling_bart + +import math + +proj_dir = Path(__file__).resolve().parent.parent + +_use_native_amp = False +_use_apex = False + +# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex +if version.parse(torch.__version__) < version.parse("1.6"): + from transormers.file_utils import is_apex_available + if is_apex_available(): + from apex import amp + _use_apex = True +else: + _use_native_amp = True + from torch.cuda.amp import autocast + +class TrainerBase(object): + def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True): + self.args = args + + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + + self.verbose = True + + if self.args.tokenizer is None: + self.args.tokenizer = self.args.backbone + + if not self.verbose: + set_global_logging_level(logging.ERROR, ["transformers"]) + + self.deepspeed = args.deepspeed + + def create_config(self): + from transformers import BartConfig + + if 'bart' in self.args.backbone: + config_class = BartConfig + else: + return None + + config = config_class.from_pretrained(self.args.backbone) + + args = self.args + + + for k, v in vars(args).items(): + setattr(config, k, v) + + config.feat_dim = args.feat_dim + config.pos_dim = args.pos_dim + config.n_images = 2 + config.n_boxes = args.n_boxes + config.n_image_tokens = args.n_image_tokens + config.downsample = args.downsample + config.oneddownsample = args.oneddownsample + config.sparse_sample = args.sparse_sample + + config.mid_dim = args.mid_dim + config.reduction_factor = args.reduction_factor + + config.use_hyperformer = args.use_hyperformer + config.use_compacter = args.use_compacter + + tasks = re.split("[, ]+", args.tasks) # tranform to list + config.adapter_config = None + + config.dropout_rate = args.dropout + config.dropout = args.dropout + config.attention_dropout = args.dropout + config.activation_dropout = args.dropout + + config.losses = args.losses + + return config + + def create_model(self, model_class, config=None, **kwargs): + print(f'Building Model at GPU {self.args.gpu}') + + model_name = self.args.backbone + + model = model_class.from_pretrained( + model_name, + config=config, + **kwargs + ) + return model + + def print_trainable_params_percentage(self, model): + + orig_param_size = sum(p.numel() for p in model.parameters()) + + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + trainable_size = count_parameters(model) + + percentage = trainable_size / orig_param_size * 100 + + print(f"Trainable param percentage: {percentage:.2f}% ({trainable_size}/{orig_param_size})") + + return percentage + + def freeze_whole_model(self): + for n, p in self.model.named_parameters(): + p.requires_grad = False + + def partial_eval(self): + # the purpose is to fix some of the norm statistics + model = self.model + + def LM_LN_eval(model): + for name, sub_module in model.named_modules(): + if isinstance(sub_module, (modeling_bart.JointEncoder, modeling_bart.BartDecoder)): + # print(f"Change {name} to eval mode...") + sub_module.eval() + + def only_LN_eval(model): + for name, sub_module in model.named_modules(): + if "visual_embedding" in name: # skip trainable parameters + continue + if isinstance(sub_module, (nn.LayerNorm)): + # print(f"Change {name} to eval mode...") + sub_module.eval() # freeze the LN statistics and dropout + + def only_BN_eval(model): + for name, sub_module in model.named_modules(): + if isinstance(sub_module, (nn.BatchNorm2d)): + # print(f"Change {name} to eval mode...") + sub_module.eval() # freeze the LN statistics and dropout + + if self.args.freeze_ln_statistics: + only_LN_eval(model) + + if self.args.freeze_bn_statistics: + only_BN_eval(model) + + def unfreeze_parameters(self): + + + targets = ["visual_embedding"] + + # unfreeze the parameters in targets anyway + for n, p in self.model.named_parameters(): + if any(t in n for t in targets): + p.requires_grad = True + print(f"{n} is trainable...") + + if self.args.unfreeze_language_model: + targets = ["lm_head", "shared"] + for n, p in self.model.named_parameters(): + if any(t in n for t in targets): + p.requires_grad = True + print(f"{n} is trainable...") + for name, sub_module in self.model.named_modules(): + if isinstance(sub_module, (modeling_bart.JointEncoder, modeling_bart.BartDecoder)): + for param_name, param in sub_module.named_parameters(): + print(f"{param_name} is trainable...") + param.requires_grad = True + + if self.args.unfreeze_lm_head: + targets = ["lm_head", "shared"] # shared and lm_head share the same weight + for n, p in self.model.named_parameters(): + if any(t in n for t in targets): + p.requires_grad = True + print(f"{n} is trainable...") + + if self.args.unfreeze_bias: + targets = ["bias"] + # unfreeze the parameters in targets anyway + for n, p in self.model.named_parameters(): + if any(t in n for t in targets): + p.requires_grad = True + print(f"{n} is trainable...({p.numel()})") + + + if self.args.unfreeze_encoder_layer_norms: + target1 = "encoder." + target2 = "layer_norm" + target3 = "layernorm" + # unfreeze the parameters in targets anyway + for n, p in self.model.named_parameters(): + # if any(t in n for t in targets): + if target1 in n and (target2 in n or target3 in n): + p.requires_grad = True + print(f"{n} is trainable...({p.numel()})") + + if self.args.unfreeze_decoder_layer_norms: + target1 = "decoder." + target2 = "layer_norm" + target3 = "layernorm" + # unfreeze the parameters in targets anyway + for n, p in self.model.named_parameters(): + # if any(t in n for t in targets): + if target1 in n and (target2 in n or target3 in n): + p.requires_grad = True + print(f"{n} is trainable...({p.numel()})") + + + + for name, sub_module in self.model.named_modules(): + + if self.args.unfreeze_layer_norms: + if isinstance(sub_module, (nn.LayerNorm)): + print(f"{name} is trainable...") + for param_name, param in sub_module.named_parameters(): + param.requires_grad = True + + if self.args.unfreeze_batch_norms: + if isinstance(sub_module, (nn.BatchNorm2d)): + print(f"{name} is trainable...") + for param_name, param in sub_module.named_parameters(): + param.requires_grad = True + print(self.model) + + def create_tokenizer(self, **kwargs): + from transformers import BartTokenizer, BartTokenizerFast + + if 'bart' in self.args.tokenizer: + tokenizer_class = BartTokenizer + # tokenizer_class = BartTokenizerFast + + tokenizer_name = self.args.backbone + + tokenizer = tokenizer_class.from_pretrained( + tokenizer_name, + max_length=self.args.max_text_length, + do_lower_case=self.args.do_lower_case, + **kwargs + ) + + return tokenizer + + def create_optimizer_and_scheduler(self): + if self.verbose: + print('Building Optimizer') + + lr_scheduler = None + + from transformers.optimization import AdamW, get_linear_schedule_with_warmup + + no_decay = ["bias", "LayerNorm.weight"] + + if 'adamw' in self.args.optim: + optimizer_grouped_parameters = [ + { + "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optim = AdamW(optimizer_grouped_parameters, + lr=self.args.lr, eps=self.args.adam_eps) + + else: + optimizer_grouped_parameters = [ + { + "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optim = self.args.optimizer(optimizer_grouped_parameters, self.args.lr) + + batch_per_epoch = len(self.train_loader) + t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs + warmup_ratio = self.args.warmup_ratio + warmup_iters = int(t_total * warmup_ratio) + if self.verbose: + print("Batch per epoch: %d" % batch_per_epoch) + print("Total Iters: %d" % t_total) + print('Warmup ratio:', warmup_ratio) + print("Warm up Iters: %d" % warmup_iters) + + lr_scheduler = get_linear_schedule_with_warmup(optim, warmup_iters, t_total) + + return optim, lr_scheduler + + def load_checkpoint(self, ckpt_path): + state_dict = load_state_dict(ckpt_path, 'cpu') + + results = self.model.load_state_dict(state_dict, strict=False) + if self.verbose: + print('Model loaded from ', ckpt_path) + pprint(results) + + def init_weights(self): + + def init_bert_weights(module): + """ Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=1) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + self.model.apply(init_bert_weights) + self.model.init_weights() + + def predict(self): + pass + + def evaluate(self): + pass + + def save(self, name): + if not os.path.isdir(self.args.output): + os.makedirs(self.args.output, exist_ok=True) + + if self.deepspeed: + self.model.save_checkpoint(self.args.output, name) + else: + torch.save(self.model.state_dict(), os.path.join(self.args.output, "%s.pth" % name)) + + def load(self, path, loc=None): + if loc is None and hasattr(self.args, 'gpu'): + loc = f'cuda:{self.args.gpu}' + state_dict = torch.load("%s.pth" % path, map_location=loc) + + results = self.model.load_state_dict(state_dict, strict=False) + if self.verbose: + print('Model loaded from ', path) + pprint(results) diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/utils.py b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/utils.py new file mode 100644 index 0000000..18c4de3 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/utils.py @@ -0,0 +1,88 @@ +import re +import numpy as np +import torch +import torch.distributed as dist +import collections +import logging + +def get_area(pos): + """ + Args + pos: [B, N, 4] + (x1, x2, y1, y2) + + Return + area : [B, N] + """ + # [B, N] + height = pos[:, :, 3] - pos[:, :, 2] + width = pos[:, :, 1] - pos[:, :, 0] + area = height * width + return area + +def get_relative_distance(pos): + """ + Args + pos: [B, N, 4] + (x1, x2, y1, y2) + + Return + out : [B, N, N, 4] + """ + # B, N = pos.size()[:-1] + + # [B, N, N, 4] + relative_distance = pos.unsqueeze(1) - pos.unsqueeze(2) + + return relative_distance + + +class LossMeter(object): + def __init__(self, maxlen=100): + """Computes and stores the running average""" + self.vals = collections.deque([], maxlen=maxlen) + + def __len__(self): + return len(self.vals) + + def update(self, new_val): + self.vals.append(new_val) + + @property + def val(self): + return sum(self.vals) / len(self.vals) + + def __repr__(self): + return str(self.val) + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def load_state_dict(state_dict_path, loc='cpu'): + state_dict = torch.load(state_dict_path, map_location=loc) + # Change Multi GPU to single GPU + original_keys = list(state_dict.keys()) + for key in original_keys: + if key.startswith("module."): + new_key = key[len("module."):] + state_dict[new_key] = state_dict.pop(key) + return state_dict + + +def set_global_logging_level(level=logging.ERROR, prefices=[""]): + """ + Override logging levels of different modules based on their name as a prefix. + It needs to be invoked after the modules have been loaded so that their loggers have been initialized. + + Args: + - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR + - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional. + Default is `[""]` to match all active loggers. + The match is a case-sensitive `module_name.startswith(prefix)` + """ + prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })') + for name in logging.root.manager.loggerDict: + if re.match(prefix_re, name): + logging.getLogger(name).setLevel(level) diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/vqa_clip_data.py b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/vqa_clip_data.py new file mode 100644 index 0000000..65d5e11 --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/VL-T5/src/vqa_clip_data.py @@ -0,0 +1,691 @@ +from torch.utils.data import DataLoader, Dataset, Sampler +from pathlib import Path +from collections import defaultdict +import json +import random +from multiprocessing import Pool +import h5py +import pickle +import math +from tqdm import tqdm +import torch +import numpy as np +from copy import deepcopy +import re + +from torch.utils.data.distributed import DistributedSampler + +from transformers import BartTokenizer + +project_dir = Path(__file__).resolve().parent.parent +workspace_dir = project_dir.parent +dataset_dir = workspace_dir.joinpath('datasets/').resolve() +coco_dir = dataset_dir.joinpath('COCO') +vg_dir = dataset_dir.joinpath('VG') +coco_img_dir = coco_dir.joinpath('images/') +coco_feature_dir = coco_dir.joinpath('clip_features') +vqa_dir = dataset_dir.joinpath('vqa') + + +class VQAFineTuneDataset(Dataset): + def __init__(self, split='train', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train'): + super().__init__() + + self.raw_dataset = raw_dataset + self.topk = topk + self.verbose = verbose + self.args = args + + self.mode = mode + + # Loading datasets to data + self.sources = split.split(',') + if self.verbose: + print('Data sources: ', self.sources) + + if 'bart' in self.args.backbone: + self.tokenizer = BartTokenizer.from_pretrained( + args.backbone, + # max_length=self.args.max_text_length, + do_lower_case=self.args.do_lower_case) + + if args.use_vis_order_embedding: + additional_special_tokens = [f'' for i in range(100-1, -1, -1)] + \ + [f'' for i in range(100-1, -1, -1)] + special_tokens_dict = {'additional_special_tokens': additional_special_tokens} + num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict) + + self.answer_normalizer = VQAEvaluator() + + self.img_ids_to_source = {} + data_info_dicts = [] + for source in self.sources: + data_info_path = dataset_dir.joinpath(f'vqa/{source}.json') + with open(data_info_path) as f: + _data_info_dicts = json.load(f) + for _d in _data_info_dicts: + if 'vg_qa_full' == source: + self.img_ids_to_source[_d['img_id']] = 'vg' + elif 'train2014' in _d['img_id']: + self.img_ids_to_source[_d['img_id']] = 'train2014' + elif 'val2014' in _d['img_id']: + self.img_ids_to_source[_d['img_id']] = 'val2014' + else: + self.img_ids_to_source[_d['img_id']] = source + _d['source'] = source + + data_info_dicts.extend(_data_info_dicts) + if self.verbose: + print(f"Loaded {len(_data_info_dicts)} data from", source) + + data = data_info_dicts + + self.n_gpus = torch.cuda.device_count() + + self.rank = rank + + if isinstance(self.topk, float) and (0 < self.topk <= 1): + used_samples = int(self.topk * len(data)) + data = random.sample(data, used_samples) + if self.verbose: + print(f"Use only {len(data)} data") + + elif self.topk > 0: + data = data[:int(self.topk)] + if self.verbose: + print(f"Use only {len(data)} data") + + self.data = data + + if self.verbose: + print("# all sentences:", len(self.data)) + + self.n_boxes = args.n_boxes + + self.feature_type = self.args.feature_type + + if self.args.vis_pooling_output: + self.source_to_h5 = { + _type: coco_feature_dir.joinpath(f"data_clip_{_type}_fc") + for _type in ["RN50", "RN101", "RN50x4", "ViT"] + } + else: + self.source_to_h5 = { + _type: coco_feature_dir.joinpath(f"data_clip_{_type}_att") + for _type in ["RN50", "RN101", "RN50x4", "ViT"] + } + + self.h5_path = self.source_to_h5[self.feature_type] + + # self.source_to_h5 = { + # 'train': coco_feature_dir.joinpath(f'train2014_obj36.h5'), + # 'minival': coco_feature_dir.joinpath(f'val2014_obj36.h5'), + # 'nominival': coco_feature_dir.joinpath(f'val2014_obj36.h5'), + # 'test': coco_feature_dir.joinpath(f'test2015_obj36.h5'), + + # 'vg': dataset_dir.joinpath('VG/features').joinpath('vg_gqa_obj36.h5'), + + # 'train2014': coco_feature_dir.joinpath(f'train2014_obj36.h5'), + # 'val2014': coco_feature_dir.joinpath(f'val2014_obj36.h5'), + # } + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + out_dict = {} + out_dict['args'] = self.args + + datum = self.data[idx] + + ###### Image ###### + if self.args.use_vision: + img_id = datum['img_id'] + out_dict['img_id'] = img_id + + path = self.h5_path.joinpath(f"{img_id}.h5") + + with h5py.File(path, 'r') as f: + # feats = np.zeros(shape=(self.n_boxes, 2048), dtype=np.float32) + # try: + # f[f'{img_id}/features'].read_direct(feats) + # except KeyError: + # print('img_id', img_id) + # print(datum) + # exit() + + # feats = torch.from_numpy(feats) + feats = f[f"{img_id}/features"][...] + out_dict['vis_feats'] = feats # (L, D) + + # Normalize the boxes (to 0 ~ 1) + # img_h = f[f'{img_id}/img_h'][()] + # img_w = f[f'{img_id}/img_w'][()] + # boxes = f[f'{img_id}/boxes'][()] # (x1, y1, x2, y2) + # boxes[:, (0, 2)] /= img_w + # boxes[:, (1, 3)] /= img_h + # np.testing.assert_array_less(boxes, 1+1e-5) + # # np.testing.assert_array_less(boxes, 1+5e-2) + # np.testing.assert_array_less(-boxes, 0+1e-5) + # boxes = torch.from_numpy(boxes) + + # boxes.clamp_(min=0.0, max=1.0) + + boxes = torch.zeros(feats.shape[0], 4) # (L, 4) + + out_dict['boxes'] = boxes + + ###### Text ##### + # caption = datum['caption'] + if 'sent' in datum: + sent = datum['sent'] + elif 'question' in datum: + sent = datum['question'] + + input_ids = self.tokenizer.encode(f'{self.args.prompt}{sent}{self.args.post_prompt}', max_length=20, truncation=True) + + question_id = datum['question_id'] + out_dict['question_id'] = question_id + + + out_dict['sent'] = sent + out_dict['input_ids'] = torch.LongTensor(input_ids) + out_dict['input_length'] = len(input_ids) + # out_dict['target_ids'] = torch.LongTensor(target_ids) + # out_dict['target_length'] = len(target_ids) + + if 'is_topk_optimal' in datum: + out_dict['is_topk_optimal'] = datum['is_topk_optimal'] + + if 'label' in datum: + label = datum['label'] + out_dict['label'] = label + + # 3129 topk answers + if self.args.classifier: + target = torch.zeros(self.raw_dataset.num_answers) + for ans, score in label.items(): + target[self.raw_dataset.ans2label[ans]] = score + out_dict['target'] = target + + elif self.args.raw_label: + + # 10 raw answers + # ex) 'answers': [{'answer': 'net', 'answer_confidence': 'maybe', 'answer_id': 1}, + # {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 2}, + # {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 3}, + # {'answer': 'netting', 'answer_confidence': 'yes', 'answer_id': 4}, + # {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 5}, + # {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 6}, + # {'answer': 'mesh', 'answer_confidence': 'maybe', 'answer_id': 7}, + # {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 8}, + # {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 9}, + # {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 10}], + + answers = datum['answers'] + answer = random.choice(answers)['answer'] + + if self.args.answer_normalize: + answer = self.answer_normalizer.normalize_answer(answer) + + score = int(len(answers) > 0) + + out_dict['answer'] = answer + out_dict['score'] = score + out_dict['all_answers'] = [a['answer'] for a in answers] + + target_ids = self.tokenizer.encode(answer, max_length=10, truncation=True) + + out_dict['target_ids'] = torch.LongTensor(target_ids) + out_dict['target_length'] = len(target_ids) + + else: + # https://github.com/airsplay/lxmert/blob/master/src/pretrain/lxmert_pretrain.py#L191 + + answers = [] + scores = [] + for a, s in label.items(): + answers.append(a) + scores.append(s) + + score_sum = sum(scores) + + if score_sum == 0: + answer = '' + score = 0. + else: + prob = [score / score_sum for score in scores] + choice = np.random.multinomial(1, prob).argmax() + answer = answers[choice] + score = scores[choice] + assert len(answer) > 0, (sent, label, choice, answer) + + out_dict['answer'] = answer + out_dict['score'] = score + out_dict['all_answers'] = answers + + + target_ids = self.tokenizer.encode(answer, max_length=10, truncation=True) + + out_dict['target_ids'] = torch.LongTensor(target_ids) + out_dict['target_length'] = len(target_ids) + + return out_dict + + + def collate_fn(self, batch): + batch_entry = {} + + args = batch[0]['args'] + + B = len(batch) + + S_W_L = max(entry['input_length'] for entry in batch) + input_ids = torch.ones(B, S_W_L, dtype=torch.long) * self.tokenizer.pad_token_id + + if args.use_vision: + V_L = len(batch[0]['boxes']) + feat_dim = batch[0]['vis_feats'].shape[-1] + + boxes = torch.zeros(B, V_L, 4, dtype=torch.float) + vis_feats = torch.zeros(B, V_L, feat_dim, dtype=torch.float) + + if 'target' in batch[0]: + # targets = [] + targets = torch.zeros(B, len(batch[0]['target']), dtype=torch.float) + if 'target_ids' in batch[0]: + T_W_L = max(entry['target_length'] for entry in batch) + target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id + + sentences = [] + question_ids = [] + answers = [] + all_answers = [] + img_ids = [] + img_paths = [] + labels = [] + scores = [] + is_topk_optimal = [] + + for i, entry in enumerate(batch): + input_ids[i, :entry['input_length']] = entry['input_ids'] + + if args.use_vision: + boxes[i] += entry['boxes'] + vis_feats[i] += entry['vis_feats'] + # img_ids.append(entry['img_id']) + # img_paths.append(entry['img_path']) + + if 'target_ids' in entry: + target_ids[i, :entry['target_length']] = entry['target_ids'] + + if 'target' in entry: + targets[i] += entry['target'] + # targets.append(entry['target']) + + sentences.append(entry['sent']) + question_ids.append(entry['question_id']) + if 'answer' in entry: + answers.append(entry['answer']) + if 'all_answers' in entry: + all_answers.append(entry['all_answers']) + if 'score' in entry: + scores.append(entry['score']) + + if 'label' in entry: + labels.append(entry['label']) + + if 'is_topk_optimal' in entry: + is_topk_optimal.append(entry['is_topk_optimal']) + + batch_entry['input_ids'] = input_ids + if 'target_ids' in batch[0]: + word_mask = target_ids != self.tokenizer.pad_token_id + target_ids[~word_mask] = -100 + batch_entry['target_ids'] = target_ids + if 'target' in batch[0]: + # targets = torch.stack(targets, dim=0) + batch_entry['targets'] = targets + + if args.use_vision: + batch_entry['boxes'] = boxes + batch_entry['vis_feats'] = vis_feats + # batch_entry['img_id'] = img_ids + # batch_entry['img_paths'] = img_paths + + batch_entry['sent'] = sentences + batch_entry['question_ids'] = question_ids + batch_entry['answers'] = answers + batch_entry['all_answers'] = all_answers + batch_entry['scores'] = torch.FloatTensor(scores) + batch_entry['labels'] = labels + + batch_entry['args'] = args + batch_entry['task'] = 'vqa' + + return batch_entry + + +def get_loader(args, split='karpathy_train', mode='train', + batch_size=32, workers=4, distributed=False, gpu=0, topk=-1): + + verbose = (gpu == 0) + + _dset = VQADataset(split, verbose) + + dataset = VQAFineTuneDataset( + split, + raw_dataset=_dset, + rank=gpu, + topk=topk, + verbose=verbose, + args=args, + mode=mode) + + if distributed: + sampler = DistributedSampler(dataset) + else: + sampler = None + + if mode == 'train': + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=(sampler is None), + num_workers=workers, pin_memory=True, sampler=sampler, + collate_fn=dataset.collate_fn) + else: + loader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=workers, pin_memory=True, + sampler=sampler, + shuffle=None if (sampler is not None) else False, + collate_fn=dataset.collate_fn, + drop_last=False) + + if verbose: + loader.evaluator = VQAEvaluator(_dset) + + loader.task = 'vqa' + + return loader + + +class VQADataset: + """ + A VQA data example in json file: + { + "answer_type": "other", + "img_id": "COCO_train2014_000000458752", + "label": { + "net": 1 + }, + "question_id": 458752000, + "question_type": "what is this", + "sent": "What is this photo taken looking through?" + } + """ + + def __init__(self, splits: str, verbose=True): + self.name = splits + self.splits = splits.split(',') + + with open(dataset_dir.joinpath(f'vqa/v2_mscoco_train2014_annotations.json')) as f: + train2014_data = json.load(f) + with open(dataset_dir.joinpath(f'vqa/v2_mscoco_val2014_annotations.json')) as f: + val2014_data = json.load(f) + train2014_id2datum = {} + for datum in train2014_data['annotations']: + qid = datum['question_id'] + train2014_id2datum[qid] = datum + val2014_id2datum = {} + for datum in val2014_data['annotations']: + qid = datum['question_id'] + val2014_id2datum[qid] = datum + self.id2datum_gt = {**train2014_id2datum, **val2014_id2datum} + + # Loading datasets + self.data = [] + for split in self.splits: + self.data.extend( + json.load(open(vqa_dir.joinpath("%s.json" % split)))) + + if verbose: + print("Load %d data from split(s) %s." % + (len(self.data), self.name)) + + # Convert list to dict (for evaluation) + self.id2datum = { + datum['question_id']: datum + for datum in self.data + } + + # Topk Answers + self.ans2label = json.load( + open(vqa_dir.joinpath("trainval_ans2label.json"))) + self.label2ans = json.load( + open(vqa_dir.joinpath("trainval_label2ans.json"))) + assert len(self.ans2label) == len(self.label2ans) + + if verbose: + print('# Answers:', len(self.ans2label)) + + @property + def num_answers(self): + return len(self.ans2label) + + def __len__(self): + return len(self.data) + + +class VQAEvaluator: + def __init__(self, dataset: VQADataset = None): + self.dataset = dataset + + """https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py""" + + self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \ + "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \ + "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \ + "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \ + "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \ + "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \ + "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \ + "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \ + "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \ + "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \ + "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \ + "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \ + "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \ + "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \ + "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \ + "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \ + "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \ + "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \ + "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \ + "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \ + "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \ + "youll": "you'll", "youre": "you're", "youve": "you've"} + + self.manualMap = { 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10' + } + + self.articles = ['a', + 'an', + 'the' + ] + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(\,)(\d)") + self.punct = [';', r"/", '[', ']', '"', '{', '}', + '(', ')', '=', '+', '\\', '_', '-', + '>', '<', '@', '`', ',', '?', '!'] + + self.n = 2 + + def evaluate(self, quesid2ans: dict): + score = 0. + for quesid, ans in quesid2ans.items(): + datum = self.dataset.id2datum[quesid] + label = datum['label'] + if ans in label: + score += label[ans] + return score / len(quesid2ans) + + def dump_result(self, quesid2ans: dict, path): + """ + Dump results to a json file, which could be submitted to the VQA online evaluation. + VQA json file submission requirement: + results = [result] + result = { + "question_id": int, + "answer": str + } + :param quesid2ans: dict of quesid --> ans + :param path: The desired path of saved file. + """ + with open(path, 'w') as f: + result = [] + for ques_id, ans in quesid2ans.items(): + result.append({ + 'question_id': ques_id, + 'answer': ans + }) + json.dump(result, f, indent=4, sort_keys=True) + + def evaluate_raw(self, quesid2ans: dict, is_topk_optimal=None): + """https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py""" + + gts = self.dataset.id2datum_gt + + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + + accQA = [] + accQuesType = {} + accAnsType = {} + + # print("Computing accuracy") + + for quesId, resAns in tqdm(quesid2ans.items(), total=len(quesid2ans), ncols=80): + + quesId = int(quesId) + + datum = self.dataset.id2datum[quesId] + + if is_topk_optimal is None: + pass + elif 'is_topk_optimal' in datum: + if datum['is_topk_optimal'] != is_topk_optimal: + continue + + resAns = resAns.replace('\n', ' ') + resAns = resAns.replace('\t', ' ') + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + + gtAcc = [] + gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]['answers']: + ansDic['answer'] = self.processPunctuation(ansDic['answer']) + for gtAnsDatum in gts[quesId]['answers']: + otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum] + matchingAns = [item for item in otherGTAns if item['answer']==resAns] + acc = min(1, float(len(matchingAns))/3) + gtAcc.append(acc) + quesType = gts[quesId]['question_type'] + ansType = gts[quesId]['answer_type'] + avgGTAcc = float(sum(gtAcc))/len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + + + if len(accQA) == 0: + return { + 'overall': 0, + 'perQuestionType': {}, + 'perAnswerType': {} + } + else: + self.setAccuracy(accQA, accQuesType, accAnsType) + + return self.accuracy + + def normalize_answer(self, resAns): + resAns = resAns.replace('\n', ' ') + resAns = resAns.replace('\t', ' ') + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + resAns = resAns.replace(',', '') + return resAns + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = self.periodStrip.sub("", + outText, + re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = ' '.join(outText) + return outText + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100*acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100*acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100*acc, self.n) + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n) + self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType} + self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType} + diff --git a/examples/vlbart/ReftDora/image_video_text_understanding/download_backbones.py b/examples/vlbart/ReftDora/image_video_text_understanding/download_backbones.py new file mode 100644 index 0000000..3ed000f --- /dev/null +++ b/examples/vlbart/ReftDora/image_video_text_understanding/download_backbones.py @@ -0,0 +1,16 @@ + +from transformers import T5ForConditionalGeneration, T5Tokenizer +from transformers import BartForConditionalGeneration, BartTokenizer + +if __name__ == '__main__': + + + print('Downloading checkpoints if not cached') + print('T5-base') + model = T5ForConditionalGeneration.from_pretrained('t5-base', cache_dir="/nlp/scr/peterwz/.cache") + tokenizer = T5Tokenizer.from_pretrained('t5-base') + print('BART-base') + tokenizer = BartTokenizer.from_pretrained("facebook/bart-base", cache_dir="/nlp/scr/peterwz/.cache") + model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") + print('Done!') + diff --git a/pyreft/dataset.py b/pyreft/dataset.py index cf2fde5..bc170da 100644 --- a/pyreft/dataset.py +++ b/pyreft/dataset.py @@ -79,6 +79,9 @@ def get_intervention_locations(**kwargs): _first_n, _last_n = kwargs["first_n"], kwargs["last_n"] num_interventions = kwargs["num_interventions"] pad_mode = kwargs["pad_mode"] if "pad_mode" in kwargs else "first" + last_offset = kwargs["last_offset"] if "last_offset" in kwargs else 0 + last_position += last_offset + first_n = min(last_position // 2, _first_n) last_n = min(last_position // 2, _last_n) @@ -252,7 +255,6 @@ def compute_intervention_and_subspaces(self, id: int, data_item, result: dict, l return result - class ReftRawDataset(Dataset): def __init__( diff --git a/pyreft/interventions.py b/pyreft/interventions.py index 3a25f7a..d50b758 100644 --- a/pyreft/interventions.py +++ b/pyreft/interventions.py @@ -36,12 +36,19 @@ def __init__(self, **kwargs): super().__init__(**kwargs, keep_last_dim=True) rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True) self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder') + self.dtype = kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16 + # self.dtype = torch.float32 self.learned_source = torch.nn.Linear( self.embed_dim, kwargs["low_rank_dimension"]).to( - kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16) + self.convert_type(self.dtype) + ) self.dropout = torch.nn.Dropout(kwargs["dropout"] if "dropout" in kwargs else 0.0) self.act_fn = ACT2FN["linear"] if "act_fn" not in kwargs or kwargs["act_fn"] is None else ACT2FN[kwargs["act_fn"]] + def convert_type(self, dtype): + return torch.bfloat16 if dtype == "bfloat16" else dtype + + def forward( self, base, source=None, subspaces=None ): diff --git a/pyreft/reft_model.py b/pyreft/reft_model.py index 9ff6e0d..8534221 100644 --- a/pyreft/reft_model.py +++ b/pyreft/reft_model.py @@ -54,3 +54,25 @@ def print_trainable_parameters(self): f"model params: {all_model_parameters:,d} || trainable%: {100 * total_trainable_parameters / all_model_parameters}" ) + def unfreeze_intervention_parameters(self): + """ + Unfreeze intervention parameters. + """ + _linked_key_set = set([]) + trainable_intervention_parameters = {} + for k, v in self.interventions.items(): + if isinstance(v[0], pv.TrainableIntervention): + if k in self._intervention_reverse_link: + if not self._intervention_reverse_link[k] in _linked_key_set: + _linked_key_set.add(self._intervention_reverse_link[k]) + for n, p in v[0].named_parameters(): + p.requires_grad = True + trainable_intervention_parameters[k+"#"+n] = p + else: + for n, p in v[0].named_parameters(): + p.requires_grad = True + trainable_intervention_parameters[k+"#"+n] = p + #for n, p in trainable_intervention_parameters.items(): + # print("Grad of " +n + " is ", p.grad, " value is ", p.data.norm()) + return trainable_intervention_parameters + diff --git a/requirements.txt b/requirements.txt index 9fbd513..5e04bd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,17 @@ -torch>=2.0.0 +torch==2.0.1 +torchvision==0.15.2 # Removed flash-attn for now. # flash-attn>=2.5.6 --install-option='--no-build-isolation' pyvene>=0.1.2 transformers>=4.39.3 +# transformers==4.33.0 protobuf>=3.20.0 matplotlib>=3.7.4 ipywidgets>=8.1.1 plotnine>=0.12.4 huggingface-hub==0.23.0 -numpy>=1.26.4 +numpy>=0.26.4 +pycocotools accelerate>=0.29.1 sentencepiece>=0.1.96 evaluate>=0.4.1 @@ -22,3 +25,13 @@ ydata-profiling>=4.7.0 seaborn==0.12.2 # Colab notebook setup. gcsfs>=2024.2.0 +h5py +tqdm +numpy +pandas +ftfy +timm +pyyaml +sacrebleu +git+https://github.com/bckim92/language-evaluation.git +wget