Batching #48
import time
import torch
from transformers import AutoModelForCausalLM
from PIL import Image
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images
import concurrent.futures
# Initialize the model and processor
model_path = "deepseek-ai/deepseek-vl-1.3b-chat"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
# Function to load and process images and text per thread
def process_conversation(conversation_piece):
    # Load images directly using the entire conversation piece
    pil_images = load_pil_images(conversation_piece)
    prepare_inputs = vl_chat_processor(
        conversations=conversation_piece,
        images=pil_images,
        force_batchify=True
    ).to(vl_gpt.device)
    return prepare_inputs
n_threads = 8
conversation = [
    [
        {"role": "User", "content": "Thoroughly describe <image_placeholder>.", "images": ["../../man_wave.png"]},
        {"role": "Assistant", "content": ""}
    ] for _ in range(n_threads)
]
start = time.time()
# Using ThreadPoolExecutor to parallelize image loading and input preparation
with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
    futures = [executor.submit(process_conversation, conv) for conv in conversation]
    # Collect in submission order so each batch row matches its conversation
    # (as_completed would yield results in completion order instead).
    results = [f.result() for f in futures]
print("Time for preprocessing: ", time.time() - start)
# Aggregate results from threads
input_ids = torch.cat([res.input_ids for res in results], dim=0)
pixel_values = torch.cat([res.pixel_values for res in results], dim=0)
attention_mask = torch.cat([res.attention_mask for res in results], dim=0)
images_seq_mask = torch.cat([res.images_seq_mask for res in results], dim=0)
images_emb_mask = torch.cat([res.images_emb_mask for res in results], dim=0)
sft_format = [res.sft_format for res in results]
# Run model to get the response
inputs_embeds = vl_gpt.prepare_inputs_embeds(
    input_ids=input_ids,
    pixel_values=pixel_values,
    images_seq_mask=images_seq_mask,
    images_emb_mask=images_emb_mask
)
outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=1,
    do_sample=False,
    use_cache=True
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(sft_format[0], answer) # Assuming sft_format is consistent across threads
end = time.time()
print("Time taken to process: ", end - start)
The |
Nope. It's just a toy demo, not meant for production use.
I know, that's why I at least tried to quickly "parallelize" the processor.
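When the processor is parallelized with a thread pool like this, the order in which results are collected matters, because each row of the concatenated batch has to line up with the conversation it came from. The sketch below is only an illustration using the standard library: `fake_preprocess` and its sleep delays are hypothetical stand-ins for the real per-sample preprocessing. It shows that `as_completed` yields results in completion order, while iterating the futures list or using `executor.map` keeps the input order.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_preprocess(item):
    """Hypothetical stand-in for per-sample preprocessing (not the real processor)."""
    index, delay = item
    time.sleep(delay)  # simulate samples that take different amounts of time
    return index


# (index, simulated processing time in seconds)
items = [(0, 0.2), (1, 0.0), (2, 0.1)]

with ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(fake_preprocess, it) for it in items]
    # Yields each future as it finishes, so the order depends on timing.
    completion_order = [f.result() for f in as_completed(futures)]
    # Iterating the list of futures keeps the order in which they were submitted.
    submission_order = [f.result() for f in futures]
    # executor.map also returns results in input order.
    mapped_order = list(executor.map(fake_preprocess, items))

print("as_completed:", completion_order)
print("submission order:", submission_order)
print("executor.map:", mapped_order)
```

With identical conversations the difference is invisible, but as soon as the prompts differ, completion-order collection can pair an answer with the wrong input.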
You could use a custom dataset class and a DataLoader to do the batching. This is how I run it, and it is quite fast.
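The code from that comment is not preserved here, so the following is not the commenter's implementation, only a minimal sketch of the general idea in plain Python. Everything in it is an assumption for illustration: `ConversationDataset`, `collate_left_pad`, `iter_batches`, `PAD_ID`, and the toy token ids are hypothetical names and values. The dataset follows the map-style protocol (`__len__` and `__getitem__`), and the collate function left-pads each batch to its longest sample, which is the usual choice for decoder-only generation.

```python
from typing import Dict, Iterator, List, Sequence

PAD_ID = 0  # hypothetical pad token id; use the tokenizer's real one


class ConversationDataset:
    """Map-style dataset: one already-tokenized prompt per index (toy data)."""

    def __init__(self, token_ids: Sequence[List[int]]):
        self.token_ids = list(token_ids)

    def __len__(self) -> int:
        return len(self.token_ids)

    def __getitem__(self, idx: int) -> List[int]:
        return self.token_ids[idx]


def collate_left_pad(samples: List[List[int]]) -> Dict[str, List[List[int]]]:
    """Left-pad every sample to the longest one and build the attention mask."""
    max_len = max(len(s) for s in samples)
    input_ids, attention_mask = [], []
    for s in samples:
        pad = max_len - len(s)
        input_ids.append([PAD_ID] * pad + s)
        attention_mask.append([0] * pad + [1] * len(s))
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def iter_batches(dataset, batch_size: int) -> Iterator[Dict[str, List[List[int]]]]:
    """Stand-in for a data loader: yield collated batches in index order."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield collate_left_pad([dataset[i] for i in range(start, stop)])


dataset = ConversationDataset([[5, 6, 7], [8], [9, 10], [11, 12, 13, 14], [15]])
batches = list(iter_batches(dataset, batch_size=2))
for batch in batches:
    print(batch["input_ids"], batch["attention_mask"])
```

In a PyTorch setup, a dataset with this shape and a collate function like this one would typically be passed to `torch.utils.data.DataLoader` through its `collate_fn` argument, with `num_workers` handling the parallel preprocessing instead of a hand-written thread pool.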
Oh wow, that's a very cool approach, thank you so much for sharing it! I will try it out ASAP.
I'm also interested in making inference faster.
Is this code "optimal" for batched inference and preprocessing?