modified inference/translate.py to allow batch text translations #243

Open · wants to merge 1 commit into base: main

35 changes: 30 additions & 5 deletions src/seamless_communication/inference/translator.py
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import List, Optional, Tuple, Union, cast
+from typing import List, Optional, Tuple, Union, cast, Any
 
 import torch
 import torch.nn as nn
@@ -20,6 +20,7 @@
 from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
+from torch.nn.functional import pad
 
 from seamless_communication.inference.generator import (
     SequenceGeneratorOptions,
@@ -153,6 +154,15 @@ def __init__(
             )
             self.vocoder.eval()
 
+    def batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
+        padding_size = max(tensor.shape[0] for tensor in tensors)
+        dims = len(tensors[0].shape)
+        padded_tensors = []
+        for tensor in tensors:
+            padding = [0] * 2 * dims
+            padding[-1] = padding_size - tensor.shape[0]
+            padded_tensors.append(pad(tensor, padding, "constant", pad_value))
+        return torch.stack([tensor for tensor in padded_tensors], dim=0)
     @classmethod
     def get_prediction(
         cls,
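
For reference, a minimal standalone sketch (not part of the diff) of what the new batch_tensors helper computes: every tensor is right-padded along its first dimension to the length of the longest one, then the results are stacked into a single batch. The token ids and pad index below are made-up illustration values.

import torch
from torch.nn.functional import pad

short_seq = torch.tensor([3, 7, 9])        # 3 hypothetical token ids
long_seq = torch.tensor([4, 8, 2, 5, 6])   # 5 hypothetical token ids
pad_idx = 1                                # assumed pad index, for illustration only

# Right-pad dim 0 of the shorter tensor to length 5, then stack into shape [2, 5].
padded_short = pad(short_seq, [0, long_seq.shape[0] - short_seq.shape[0]], "constant", pad_idx)
batched = torch.stack([padded_short, long_seq], dim=0)
# batched == tensor([[3, 7, 9, 1, 1],
#                    [4, 8, 2, 5, 6]])
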
@@ -215,7 +225,7 @@ def get_modalities_from_task_str(task_str: str) -> Tuple[Modality, Modality]:
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor, SequenceData],
+        input: Union[str, List[str], Tensor, SequenceData],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
@@ -232,7 +242,7 @@ def predict(
         The main method used to perform inference on all tasks.
 
         :param input:
-            Either text or path to audio or audio Tensor.
+            Either text (or a list of text to be batched) or path to audio or audio Tensor.
         :param task_str:
             String representing the task.
             Valid choices are "S2ST", "S2TT", "T2ST", "T2TT", "ASR"
@@ -267,6 +277,19 @@
 
         if isinstance(input, dict):
             src = cast(SequenceData, input)
+        elif isinstance(input, list):
+            self.token_encoder = self.text_tokenizer.create_encoder(
+                task="translation", lang=src_lang, mode="source", device=self.device
+            )
+            collated = [self.collate(self.token_encoder(text)) for text in input]
+            seqs = self.batch_tensors([item['seqs'].squeeze(0) for item in collated], self.text_tokenizer.vocab_info.pad_idx)
+            seq_lens = torch.cat([item['seq_lens'] for item in collated])
+            src = {'seqs': seqs, 'seq_lens': seq_lens, 'is_ragged': True}
+        elif isinstance(input, str):
+            self.token_encoder = self.text_tokenizer.create_encoder(
+                task="translation", lang=src_lang, mode="source", device=self.device
+            )
+            src = self.collate(self.token_encoder(input))
         elif input_modality == Modality.SPEECH:
             audio = input
             if isinstance(audio, str):
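
As a rough, hypothetical illustration of what the new list branch assembles (two sentences that encode to 4 and 6 tokens, with a made-up pad index of 1), src ends up as a SequenceData-style dict whose rows are padded to the longest sequence:

src = {
    'seqs': torch.tensor([[11, 12, 13, 14,  1,  1],    # padded to the longest row -> shape [2, 6]
                          [21, 22, 23, 24, 25, 26]]),  # hypothetical token ids
    'seq_lens': torch.tensor([4, 6]),                  # true, unpadded lengths
    'is_ragged': True,                                 # rows differ in length
}
# get_seqs_and_padding_mask(src) below then derives the padding mask from seq_lens.
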
@@ -307,6 +330,7 @@
 
         seqs, padding_mask = get_seqs_and_padding_mask(src)
 
+
         if text_generation_opts is None:
             text_generation_opts = SequenceGeneratorOptions(
                 beam_size=5, soft_max_seq_len=(1, 200)
@@ -349,7 +373,7 @@
                 unit_generation_ngram_filtering=unit_generation_ngram_filtering,
             )
         else:
-            assert isinstance(input, str)
+            #assert isinstance(input, str) Remove this assertion since input can be a list of strings
 
             src_texts = [input]
 
@@ -382,7 +406,6 @@ def predict(
             return texts, None
         else:
             assert units is not None
-
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
                 # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
@@ -426,3 +449,5 @@ def predict(
                 sample_rate=sample_rate,
             ),
         )
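
With the change applied, batched text translation could be driven roughly as in the sketch below. This is not part of the PR: the model card, vocoder card, and language codes are illustrative, and T2TT is used because it returns (texts, None) per the code above.

import torch
from seamless_communication.inference import Translator

# Assumed model/vocoder cards -- substitute whichever you normally load.
translator = Translator(
    "seamlessM4T_medium",
    vocoder_name_or_card="vocoder_36langs",
    device=torch.device("cuda:0"),
    dtype=torch.float16,
)

texts = ["Hello, how are you?", "The weather is nice today."]

# A List[str] now routes through the new batching branch added in this PR.
text_output, _ = translator.predict(
    input=texts,
    task_str="T2TT",
    tgt_lang="fra",
    src_lang="eng",
)

for src_text, hyp in zip(texts, text_output):
    print(f"{src_text} -> {hyp}")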