-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference_overall.py
80 lines (63 loc) · 3.75 KB
/
inference_overall.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
import gc
import json
import os
from types import SimpleNamespace
from typing import Dict, List

import spacy
import torch
import transformers
from tqdm import tqdm
from transformers import pipeline, Blip2Processor, Blip2ForConditionalGeneration
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

from vis_corrector_recap_w import Corrector
if __name__ == '__main__':
    # Entry point: load the stage-1 detailed captions, run every sample through
    # Corrector.correct (LLM + BLIP-2 + two InstructBLIP variants), print each
    # corrected result, and write all results to a JSON file.

    # ------------------------------------------------------------------
    # Command-line arguments
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description="Code for 'ReCaption'.")
    parser.add_argument(
        '--stage-1',  # accessed as args.stage_1 (argparse maps '-' to '_')
        default='./Each_stage_texts/mplug-Owl3/CODA_mPLUG_Owl3_detail_captions_neww.json',
        help="JSON file with the stage-1 detailed captions to correct",
    )
    # NOTE(review): --query is parsed but not referenced anywhere in this script.
    parser.add_argument('--query', default='Describe this image.', type=str,
                        help="text query for MLLM")
    parser.add_argument('--cache-dir', type=str, default='./cache_dir')
    parser.add_argument(
        '--detector-config',
        default='/home/fjq/MLLMs/Woodpecker/groundingdino/config/GroundingDINO_SwinT_OGC.py',
        type=str,
        help="Path to the detector config, "
             "in the form of 'path/to/GroundingDINO_SwinT_OGC.py'",
    )
    parser.add_argument(
        '--detector-model',
        default='/data/fjq/3.VLMS/Woodpecker/groundingdino_swint_ogc.pth',
        type=str,
        help="Path to the detector checkpoint, "
             "in the form of 'path/to/groundingdino_swint_ogc.pth'",
    )
    parser.add_argument(
        '--output',
        type=str,
        default='./Each_stage_texts/mplug-Owl3/Ours_mplug_owl3_refined_caption.json',
        help="Where to write the corrected captions (JSON)",
    )
    args = parser.parse_args()

    # Only these settings are forwarded to the Corrector.
    model_args = SimpleNamespace(
        cache_dir=args.cache_dir,
        detector_config=args.detector_config,
        detector_model_path=args.detector_model,
    )

    # ------------------------------------------------------------------
    # Models.
    # NOTE(review): all four models are loaded in float32 onto GPU 0, which
    # requires a very large amount of device memory -- confirm the target
    # hardware or consider a lower-precision dtype.
    # ------------------------------------------------------------------
    # Text-generation LLM (Meta-Llama-3-8B-Instruct). Named llm_pipeline so it
    # does not shadow the transformers.pipeline function imported above.
    llm_pipeline = transformers.pipeline(
        "text-generation", model='/data/lqf_llama/Meta-Llama-3-8B-Instruct',
        model_kwargs={"torch_dtype": torch.float32},
        device_map={"": 0},
    )

    # BLIP-2 with a Flan-T5-XXL language model.
    blip2_path = '/data/fjq/blip-2/blip2-flan-t5-xxl'
    model_blip = Blip2ForConditionalGeneration.from_pretrained(
        blip2_path, torch_dtype=torch.float32)
    model_blip.to("cuda:0")
    processor_blip = Blip2Processor.from_pretrained(blip2_path)

    # InstructBLIP with a Flan-T5-XXL language model.
    instructblip_t5_path = '/data/fjq/blip-2/Instructblip-flan-t5-xxl'
    model_instructblip = InstructBlipForConditionalGeneration.from_pretrained(
        instructblip_t5_path, torch_dtype=torch.float32)
    model_instructblip.to("cuda:0")
    processor_instructblip = InstructBlipProcessor.from_pretrained(instructblip_t5_path)

    # InstructBLIP with a Vicuna-13B language model.
    instructblip_vicuna_path = '/data/fjq/blip-2/Instructblip-vicuna-13b'
    model_instructblip_vicuna = InstructBlipForConditionalGeneration.from_pretrained(
        instructblip_vicuna_path, torch_dtype=torch.float32)
    model_instructblip_vicuna.to("cuda:0")
    processor_instructblip_vicuna = InstructBlipProcessor.from_pretrained(
        instructblip_vicuna_path)

    corrector = Corrector(model_args)

    # ------------------------------------------------------------------
    # Correction loop
    # ------------------------------------------------------------------
    # Detailed captions for all CODA images; presumably a list of per-image
    # samples in the format Corrector.correct expects -- confirm against the
    # stage-1 output.
    with open(args.stage_1, 'r', encoding='utf-8') as f:
        coda_detail_captions = json.load(f)

    coda_detail_captions_correct = []
    for sample in tqdm(coda_detail_captions, desc='Correcting captions'):
        output = corrector.correct(
            llm_pipeline,
            processor_blip, model_blip,
            processor_instructblip, model_instructblip,
            processor_instructblip_vicuna, model_instructblip_vicuna,
            sample,
        )
        # tqdm.write keeps per-sample output from corrupting the progress bar.
        tqdm.write(str(output))
        coda_detail_captions_correct.append(output)
    print(len(coda_detail_captions_correct))

    # Create the destination directory if needed so a long run is not lost at
    # the final write.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(coda_detail_captions_correct, f, indent=4, ensure_ascii=False)