From 16e31a202f386d9274545739e0be391db7a86752 Mon Sep 17 00:00:00 2001
From: mlinmg
Date: Mon, 4 Dec 2023 16:37:54 +0100
Subject: [PATCH] modified inference/translator.py to allow sending a list of
 text requests so processing can be batched for better efficiency

---
 .../inference/translator.py | 35 ++++++++++++++++---
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/src/seamless_communication/inference/translator.py b/src/seamless_communication/inference/translator.py
index 57bea931..fd634aee 100644
--- a/src/seamless_communication/inference/translator.py
+++ b/src/seamless_communication/inference/translator.py
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import List, Optional, Tuple, Union, cast
+from typing import List, Optional, Tuple, Union, cast, Any
 
 import torch
 import torch.nn as nn
@@ -20,6 +20,7 @@
 from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
+from torch.nn.functional import pad
 
 from seamless_communication.inference.generator import (
     SequenceGeneratorOptions,
@@ -153,6 +154,15 @@ def __init__(
         )
         self.vocoder.eval()
 
+    def batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
+        padding_size = max(tensor.shape[0] for tensor in tensors)
+        dims = len(tensors[0].shape)
+        padded_tensors = []
+        for tensor in tensors:
+            padding = [0] * 2 * dims
+            padding[-1] = padding_size - tensor.shape[0]
+            padded_tensors.append(pad(tensor, padding, "constant", pad_value))
+        return torch.stack(padded_tensors, dim=0)
     @classmethod
     def get_prediction(
         cls,
@@ -215,7 +225,7 @@ def get_modalities_from_task_str(task_str: str) -> Tuple[Modality, Modality]:
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor, SequenceData],
+        input: Union[str, List[str], Tensor, SequenceData],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
@@ -232,7 +242,7 @@ def predict(
         The main method used to perform inference on all tasks.
 
         :param input:
-            Either text or path to audio or audio Tensor.
+            Either text (or a list of texts to be batched) or path to audio or audio Tensor.
         :param task_str:
             String representing the task.
             Valid choices are "S2ST", "S2TT", "T2ST", "T2TT", "ASR"
@@ -267,6 +277,19 @@ def predict(
 
         if isinstance(input, dict):
             src = cast(SequenceData, input)
+        elif isinstance(input, list):
+            self.token_encoder = self.text_tokenizer.create_encoder(
+                task="translation", lang=src_lang, mode="source", device=self.device
+            )
+            collated = [self.collate(self.token_encoder(text)) for text in input]
+            seqs = self.batch_tensors([item["seqs"].squeeze(0) for item in collated], self.text_tokenizer.vocab_info.pad_idx)
+            seq_lens = torch.cat([item["seq_lens"] for item in collated])
+            src = {"seqs": seqs, "seq_lens": seq_lens, "is_ragged": True}
+        elif isinstance(input, str):
+            self.token_encoder = self.text_tokenizer.create_encoder(
+                task="translation", lang=src_lang, mode="source", device=self.device
+            )
+            src = self.collate(self.token_encoder(input))
         elif input_modality == Modality.SPEECH:
             audio = input
             if isinstance(audio, str):
@@ -307,6 +330,7 @@ def predict(
 
         seqs, padding_mask = get_seqs_and_padding_mask(src)
 
+
         if text_generation_opts is None:
             text_generation_opts = SequenceGeneratorOptions(
                 beam_size=5, soft_max_seq_len=(1, 200)
@@ -349,7 +373,7 @@ def predict(
                 unit_generation_ngram_filtering=unit_generation_ngram_filtering,
             )
         else:
-            assert isinstance(input, str)
+            assert isinstance(input, (str, list))
 
             src_texts = [input]
 
@@ -382,7 +406,6 @@ def predict(
             return texts, None
         else:
             assert units is not None
-
            if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
                 # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
@@ -426,3 +449,5 @@ def predict(
                 sample_rate=sample_rate,
             ),
         )
+
+
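-- 
Note (not part of the diff): a minimal usage sketch of the batched text path
introduced by this patch. The model/vocoder card names, language codes, and
device below are illustrative assumptions, not values taken from the patch;
one translation hypothesis per input text is the expected output.

    import torch

    from seamless_communication.inference import Translator

    # Card names and device are assumptions for this sketch.
    translator = Translator(
        "seamlessM4T_v2_large",
        "vocoder_v2",
        device=torch.device("cuda:0"),
        dtype=torch.float16,
    )

    # With this patch, predict() also accepts a list of source texts: each text
    # is tokenized, right-padded to a common length by batch_tensors(), and the
    # whole batch is decoded in a single generator call.
    src_texts = [
        "The weather is lovely today.",
        "Batching several requests amortizes the per-call overhead.",
    ]

    text_output, _ = translator.predict(
        input=src_texts,
        task_str="T2TT",
        tgt_lang="ita",
        src_lang="eng",
    )

    for src, hyp in zip(src_texts, text_output):
        print(f"{src} -> {hyp}")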