From 2ee42c7ff20ddd2cf732990bb7602158f225ca77 Mon Sep 17 00:00:00 2001 From: juberti Date: Tue, 15 Oct 2024 12:34:50 -0700 Subject: [PATCH 01/17] v1 --- ultravox/data/datasets.py | 1015 ++++-------------- ultravox/training/config_base.py | 8 +- ultravox/training/configs/llama_whisper.yaml | 105 ++ ultravox/training/train.py | 13 +- 4 files changed, 340 insertions(+), 801 deletions(-) create mode 100644 ultravox/training/configs/llama_whisper.yaml diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index a0f63950..670db2e2 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -1,13 +1,13 @@ import abc import base64 import dataclasses -import enum import io import itertools import logging import os import tempfile import warnings +from contextlib import closing from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence @@ -21,57 +21,94 @@ import torch import torch.nn.functional as F import transformers +from pydantic import BaseModel from torch.utils import data -from ultravox.data import dataset_config +import enum + from ultravox.data import text_proc +from ultravox.evaluation.eval_types import EvalConfig +from ultravox.utils import device_helpers SAMPLE_RATE = 16000 -TRANSCRIBE_INPUT_TASK = "transcribe_input" -TRANSCRIBE_OUTPUT_TASK = "transcribe_output" -ANSWER_TASK = "answer" - -TRANSCRIBE_PROMPTS = [ - # from Gazelle - "Transcribe\n<|audio|>", - "Transcribe exactly what is said here\n<|audio|>", - "Repeat exactly what is written here: <|audio|>", - "Write exactly what was said: <|audio|>", - "First listen to the clip. Then, transcribe exactly what is said. <|audio|>", - # from https://arxiv.org/pdf/2402.08846 - "Transcribe speech to text: <|audio|>", - # from GPT-4 - "Capture every word from the audio verbatim\n<|audio|>", - "Convert speech to text from audio\n<|audio|>", - "Listen and transcribe the complete text from audio\n<|audio|>", - "Record in writing what is spoken in audio\n<|audio|>", - "Transcribe the spoken words from audio with exact wording and punctuation\n<|audio|>", -] -ANSWER_PROMPTS = [ - # from Gazelle - "Listen to <|audio|> and respond to it", - "Listen and respond: <|audio|>", - "Respond to <|audio|>", - "Respond to the user <|audio|>", - "<|audio|>", - "<|audio|>", # repeated to emphasize not needing a prompt for Q&A tasks - "Respond to this question: \n<|audio|>", - "Continue the conversation after <|audio|>", - "First listen to the clip: <|audio|>\n How would you respond?", - "<|audio|> - respond", - "<|audio|>\n Respond to the question", -] - # TODO(juberti): set these in the environment so they don't need to be hard-coded here. os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account.json" os.environ["GOOGLE_CLOUD_PROJECT"] = "fixie-training" - # Silence the spurious warnings coming from the MosaicML streaming library. logging.getLogger("streaming.base.dataset").setLevel(logging.ERROR) +class DatasetSplit(str, enum.Enum): + TRAIN = "train" + VALIDATION = "validation" + + +# Global arguments for voice datasets. 
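+# These are run-level settings (batching, shuffling, split selection) that apply
+# to every dataset in a job; per-dataset settings live in DatasetConfig below.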
+@dataclasses.dataclass +class VoiceDatasetArgs: + """Global arguments for voice datasets.""" + + batch_size: int = 4 + """Batch size for train, eval, or validation.""" + shuffle: bool = False + """Whether to shuffle the dataset.""" + shuffle_seed: int = 42 + """Seed for shuffling the dataset.""" + max_audio_duration_secs: Optional[float] = None + """Whether to skip samples with audio longer than this duration.""" + split: DatasetSplit = DatasetSplit.TRAIN + """Which split of the dataset to use.""" + + def __post_init__(self): + if isinstance(self.split, str): + self.split = DatasetSplit(self.split.lower()) + + +class DatasetSplitConfig(BaseModel): + name: str + """Name of the split""" + num_samples: int + """Number of samples in the split""" + is_validation: bool = False + + def __post_init__(self): + if self.name == "validation": + self.is_validation = True + + +class DatasetConfig(BaseModel): + path: str = "" + """Directory of the dataset, or huggingface dataset name; must be set for "generic" datasets. If not set, it is automatically inferred for predefined dataset types.""" + subset: Optional[str] = None + """Name of the dataset, or huggingface dataset config/subset name""" + splits: List[DatasetSplit] = [] + """List of splits to use, e.g. [{"name": "train", "num_samples": 1000}, {"name": "validation", "num_samples": 100}]""" + user_template: str = "<|audio|>" + """Template for the user's message""" + user_template_args: Dict[str, str] = {} + """Optional arguments (e.g., target language) for the user template""" + assistant_template: str = "{{text}}" + """Template for the assistant's message""" + transcript_template: str = "{{text}}" + """Template for the transcript""" + audio_field: Optional[str] = "audio" + """Field in the dataset that contains the audio, use None if the dataset does not contain audio""" + use_mds: bool = False + """Set to True to load the dataset from GCP (using MDS) instead of Hugging Face""" + mds_batch_size: int = 32 + """Batch size for MDS""" + + class Config: + extra = "forbid" + # do not allow undefined parameters + + def model_post_init(self, __context: Any) -> None: + if not self.splits: + raise ValueError("At least one split must be provided") + + @dataclasses.dataclass class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq): # when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel @@ -213,42 +250,6 @@ def add_past_messages(self, past_messages: List[Dict[str, str]]): """For evaluations, the known transcript of the audio.""" -class DatasetSplit(str, enum.Enum): - TRAIN = "train" - VALIDATION = "validation" - - -@dataclasses.dataclass -class VoiceDatasetArgs: - data_dir: Optional[str] = None - prompt: Optional[str] = None - """A specific prompt to use for the dataset.""" - num_prompts: int = 1 - """If `prompt` is not set, the number of canned prompts to use.""" - include_audio: bool = True - """Whether to include audio in the samples.""" - include_context: bool = True - """Whether to include additional textual context from the dataset to the prompt.""" - max_context_length: int = 1500 - """Maximum length of context to include in the prompt. 
Otherwise, skip the sample.""" - shuffle: bool = False - """Whether to shuffle the dataset.""" - shuffle_seed: int = 42 - """Seed for shuffling the dataset.""" - max_audio_duration_secs: Optional[float] = None - """Whether to skip samples with audio longer than this duration.""" - use_mds: bool = False - """Whether to load the dataset from GCP (using MDS) or Hugging Face.""" - mds_batch_size: int = 32 - """Batch size for MDS.""" - split: DatasetSplit = DatasetSplit.TRAIN - """Which split of the dataset to use.""" - - def __post_init__(self): - if isinstance(self.split, str): - self.split = DatasetSplit(self.split.lower()) - - def _get_messages( *turns: str, sys_prompt: Optional[str] = None, assistant_last: bool = True ) -> List[Dict[str, str]]: @@ -289,23 +290,21 @@ class VoiceDataset(SizedIterableDataset): Wraps a Hugging Face dataset or MDS-formatted dataset from GCP. """ - BASE_AUDIO_COLUMNS = ["audio"] - def __init__(self, args: VoiceDatasetArgs) -> None: super().__init__() self._args = args - self._session: Optional[requests.Session] = None self._rng = np.random.default_rng(self._args.shuffle_seed) - self._weight = 1.0 # the default weight for the dataset + if True: # device_helpers.get_local_rank() == 0: + logging.info( + f"Created VoiceDataset with config:\n{self._config.model_dump_json(indent=2)}" + ) - def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None: + def _init_dataset(self, dataset: data.Dataset, num_samples: int) -> None: self._dataset = dataset - # Only required when using epochs when training dataset. - self._estimated_length = estimated_length + self._length = num_samples - @property - def weight(self) -> float: - return self._weight + def __len__(self): + return self._length def _load_audio_dataset( self, @@ -316,7 +315,6 @@ def _load_audio_dataset( shuffle: Optional[bool] = None, streaming: bool = True, ) -> data.Dataset: - logging.info(f"Loading dataset {path} {name} {split} {shuffle} {streaming}") if shuffle is None: shuffle = self._args.shuffle if self._args.use_mds: @@ -337,13 +335,15 @@ def _load_audio_dataset( shuffle_seed=self._args.shuffle_seed, ) else: + # HF datasets sometimes fails to download due to network issues, so retry a few times. dataset = datasets.load_dataset( - path, name, split=split, trust_remote_code=True, streaming=streaming + path, + name, + split=split, + trust_remote_code=True, + streaming=streaming, + download_config=datasets.DownloadConfig(max_retries=10), ) - for column_name in self.BASE_AUDIO_COLUMNS: - dataset = dataset.cast_column( - column_name, datasets.Audio(sampling_rate=SAMPLE_RATE) - ) if shuffle: dataset = dataset.shuffle(seed=self._args.shuffle_seed) return dataset @@ -352,30 +352,42 @@ def __iter__(self): actual_length = 0 for _, row in enumerate(self._dataset): sample = self._get_sample(row) - if sample is not None: + if sample is None: + raise ValueError( + f"Sample is None in dataset {self._config.alias} for row {row}" + ) + + if self._config.audio_field is not None: + # If audio_field is set, make sure the sample has audio data. 
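+                # A missing or zero-length waveform indicates a corrupt row, so fail
+                # fast here rather than silently training on bad data.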
+ if sample.audio is None: + raise ValueError( + f"Audio field ({self._config.audio_field}) is None in dataset {self._config.alias} for sample {sample}" + ) + if sample.audio.shape[-1] == 0: + raise ValueError( + f"Audio length is 0 in dataset {self._config.alias} for sample {sample}" + ) if ( - self._args.max_audio_duration_secs is None - or sample.audio is None - or sample.audio.shape[-1] / SAMPLE_RATE - <= self._args.max_audio_duration_secs + self._args.max_audio_duration_secs is not None + and sample.audio.shape[-1] / SAMPLE_RATE + > self._args.max_audio_duration_secs ): - yield sample + warnings.warn( + f"Audio length ({sample.audio.shape[-1] / SAMPLE_RATE}s) exceeds max audio duration ({self._args.max_audio_duration_secs}s) in dataset {self._config.alias}, skipping sample." + ) + continue + + yield sample actual_length += 1 - # If len(dataset) == 0 most likely the dataset is a validation dataset, - # or the training is using max_steps instead of num_epochs. - if actual_length > len(self) and len(self) > 1: + if actual_length == len(self) + 1: warnings.warn( - f"The estimated length {self._estimated_length} has been exceeded for type {type(self._dataset)}. Make sure to update." + f"The presumed length {self._length} has been exceeded for dataset {self._config.alias}. Make sure to update." ) - - if actual_length != len(self) and len(self) > 1: + if actual_length != len(self): warnings.warn( - f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update." + f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for dataset {self._config.alias}. Make sure to update." ) - def __len__(self): - return self._estimated_length - @abc.abstractmethod def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: """ @@ -383,36 +395,10 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: Returns None if the sample should be skipped. """ - def _choice(self, prompts: List[str]) -> str: - return self._rng.choice(prompts[: self._args.num_prompts]) - - def _get_answer_prompt(self) -> str: - if self._args.prompt: - return self._args.prompt - return self._choice(ANSWER_PROMPTS) - - def _get_transcribe_prompt(self) -> str: - if self._args.prompt: - return self._args.prompt - return self._choice(TRANSCRIBE_PROMPTS) - - def _get_answer_messages( - self, question: str, answer: str, context: Optional[str] = None - ) -> List[Dict[str, str]]: - prompt = self._get_answer_prompt() if self._args.include_audio else question - user_content = f"{context}\n\n{prompt}" if context else prompt - return _get_messages(user_content, answer) - - def _get_transcribe_messages(self, text: str) -> List[Dict[str, str]]: - prompt = self._get_transcribe_prompt() - if not self._args.include_audio: - prompt = prompt.replace("<|audio|>", text) - return _get_messages(prompt, text) - def _get_audio( - self, row: transformers.BatchFeature, column_name: str = "audio" + self, row: transformers.BatchFeature, column_name: Optional[str] = "audio" ) -> np.ndarray: - if column_name not in self.BASE_AUDIO_COLUMNS: + if column_name not in self._config.base_audio_columns: raise ValueError( f"Unknown audio column: {column_name}. This is likely a bug and the audio might not be resampled to {SAMPLE_RATE} Hz." 
) @@ -430,34 +416,30 @@ def _get_audio( assert sampling_rate == SAMPLE_RATE return audio - def _load_audio(self, base_url: str, folder: str, filename: str) -> np.ndarray: - if self._args.data_dir: - audio_path = f"{self._args.data_dir}/{folder}/{filename}" - audio = audio_from_file(audio_path) - else: + def _load_audio( + self, base_url: Optional[str], data_dir: Optional[str], filename: str + ) -> np.ndarray: + if base_url is not None: url = f"{base_url}/{filename}" # hack for GCS bucket naming - if self._session is None: - self._session = requests.Session() - response = self._session.get(url) - response.raise_for_status() - audio = audio_from_buf(response.content) - return audio - - def _get_transcribe_sample( - self, - row: transformers.BatchFeature, - tcol: str = "text", - tproc: Optional[Callable[[str], str]] = None, - ) -> Optional[VoiceSample]: - try: - text = tproc(row[tcol]) if tproc else row[tcol] - except text_proc.FormatASRError: - return None - return self._make_sample( - self._get_transcribe_messages(text), - self._get_audio(row), - audio_transcript=text, - ) + try: + with closing(requests.Session()) as session: + response = session.get(url) + response.raise_for_status() + return audio_from_buf(response.content) + except requests.RequestException as e: + raise ValueError( + f"Failed to load audio from URL: {url}. Error: {str(e)}" + ) + elif data_dir is not None: + audio_path = os.path.join(data_dir, filename) + try: + return audio_from_file(audio_path) + except IOError as e: + raise ValueError( + f"Failed to load audio file: {audio_path}. Error: {str(e)}" + ) + else: + raise ValueError("Either base_url or data_dir must be provided") def _make_sample( self, @@ -465,659 +447,109 @@ def _make_sample( audio: np.ndarray, audio_transcript: Optional[str] = None, ) -> VoiceSample: - if not self._args.include_audio: + if self._config.audio_field is None: return VoiceSample(messages) return VoiceSample(messages, audio, audio_transcript=audio_transcript) -class LibriSpeechDummyDataset(VoiceDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "hf-internal-testing/librispeech_asr_dummy", - "clean", - split="validation", - streaming=False, # not supported by the dummy dataset - ) - self._init_dataset(dataset, 73) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text) - - -# Making EmptyDataset a SizedIterableDataset to be compatible with using epochs during training. 
-class EmptyDataset(SizedIterableDataset): - def __init__(self, estimated_length: int = 1) -> None: - self._estimated_length = estimated_length - - def __iter__(self): - return iter([]) - - def __len__(self): - return self._estimated_length - - -class AnyInstructDataset(VoiceDataset): - """ - Metadata file format: - {"chat": [ - {"role": "USER", "message": "Write a sentence based on this summary: iraqi embassy in jakarta removes saddam hussein 's photo", "speech": "chunk_00000/0001.mp3"}, - {"role": "AnyGPT", "message": "The building in Jakarta where people from Iraq work, took down a picture of a man named Saddam Hussein.", "speech": "chunk_00000/0002.mp3"} - ]} - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - # TODO(juberti): convert to MDS - super().__init__(args) - dataset = datasets.load_dataset( - "json", - "anyinstruct", - data_files="https://huggingface.co/datasets/fnlp/AnyInstruct/resolve/main/speech_conv/metadata.jsonl", - split="train", - ) - dataset = dataset.train_test_split( - test_size=0.01, seed=args.shuffle_seed, shuffle=True - ) - dataset = dataset["train" if args.split == DatasetSplit.TRAIN else "test"] - # TODO: make num_shards configurable if need be - dataset = dataset.to_iterable_dataset(num_shards=16) - if args.shuffle: - dataset = dataset.shuffle(seed=args.shuffle_seed) - self._init_dataset(dataset) - - def _load_anyinstruct_audio(self, filename: str): - return super()._load_audio( - "https://storage.googleapis.com/train-anyinstruct-speechconv-v1", - "anyinstruct/speech", - filename, - ) - - -class AnyInstructAnswerDataset(AnyInstructDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - chat = row["chat"] - return self._make_sample( - self._get_answer_messages(chat[0]["message"], chat[1]["message"]), - self._load_anyinstruct_audio(chat[0]["speech"]), - audio_transcript=chat[0]["message"], - ) - - -class AnyInstructInputDataset(AnyInstructDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - audio_transcript = row["chat"][0]["message"] - return self._make_sample( - self._get_transcribe_messages(audio_transcript), - self._load_anyinstruct_audio(row["chat"][0]["speech"]), - audio_transcript=audio_transcript, - ) - - -class AnyInstructOutputDataset(AnyInstructDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: +class GenericDataset(VoiceDataset): + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: super().__init__(args) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - audio_transcript = row["chat"][1]["message"] - return self._make_sample( - self._get_transcribe_messages(audio_transcript), - self._load_anyinstruct_audio(row["chat"][1]["speech"]), - audio_transcript=audio_transcript, - ) - - -class BoolQDataset(VoiceDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/boolq-audio", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - question = row["question"] - answer = "True" if row["answer"] else "False" - context = row["passage"] if self._args.include_context else None - return self._make_sample( - self._get_answer_messages(question, answer, context), - 
self._get_audio(row), - audio_transcript=row["question"], - ) - - -class BoolQInputDataset(BoolQDataset): - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tcol="question") - - -class QAVoiceDatasetMixin(VoiceDataset): - SEPARATORS = ["\n\n", "\n", "\n----\n"] - QUERY_PREFIX = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"] - CONTEXT_PREFIX = [ - "Passage: ", - "Passage:\n", - "Context: ", - "Context:\n", - "Background: ", - "Background:\n", - ] - ANSWER_PREFIX = [ - "Answer: ", - "A: ", - "", - "The answer is: ", - "Result: ", - "Conclusion: ", - ] - # In most cases there is no extra prompt-suffix needed - PROMPT_SUFFIXES = [""] - - # TODO: combine `_get_query_prompt` and `_get_answer_messages` into a single method - # and use this mixin for all non-ASR datasets. - def _get_query_prompt(self, question_str: str, context: str) -> Optional[str]: - """ - Creates a random prompt for a QA sample with a passage and question. - - Example prompt: - Passage: {context} - Question: {question} - {optional-prompt-suffix} - """ - if len(context) > self._args.max_context_length: - # Skip samples with long context - return None - - if self._args.prompt: - prompt = self._args.prompt - else: - prompt = self._choice(self.PROMPT_SUFFIXES) - - # Separate either with 1 or 2 newlines - separator = self._choice(self.SEPARATORS) - - query_prompt = self._choice(self.QUERY_PREFIX) - question = "<|audio|>" if self._args.include_audio else question_str - prompt = f"{query_prompt}{question}{separator}{prompt}" - - if self._args.include_context: - context_prompt = self._choice(self.CONTEXT_PREFIX) - prompt = f"{context_prompt}{context}{separator}{prompt}" - - return prompt.strip() - - -class BoolQWithExtendedAnswerDataset(BoolQDataset, QAVoiceDatasetMixin): - """ - A version of BoolQ that includes the context in the prompt and a longer explanation in the answer. - """ - - PROMPT_SUFFIXES = [ - "Provide a short explanation, then respond with True/False on the last line", - "Explain briefly, concluding with True/False on a new line." - "Write a quick explanation, and finish with True/False on the last line" - "Summarize in a few words, and end with True/False on a new line." 
- "Give a brief explanation first, then answer with True/False on the final line", - "Start with a concise explanation, and end with a True/False response on the last line.", - "Explain briefly and follow up with True/False at the end", - "Write a short explanation, then state True/False on a new line.", - "First, offer a brief explanation, and then reply with True/False at the end.", - "Present a concise explanation, ending with a True/False answer on the final line", - "Start with a brief explanation, and then answer with True/False at the end.", - ] - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - """ - Example conversation: - <|user|> Passage: {context} - Question: {question} - Provide a short explanation, then respond with True/False on the last line - <|assistant|> {short_explanation} - Answer: {answer} - """ - answer = "True" if row["answer"] else "False" - answer_prompt = self._choice(self.ANSWER_PREFIX) - user_message = self._get_query_prompt( - question_str=row["question"], context=row["passage"] - ) - if user_message is None: - # Skips samples with long context - return None - - messages = _get_messages( - user_message, f"{row['explanation']}\n{answer_prompt}{answer}" - ) - - return self._make_sample( - messages, self._get_audio(row), audio_transcript=row["question"] - ) - - -class HeySQuADHumanDataset(QAVoiceDatasetMixin): - """ - HeySQuAD is a large-scale Spoken Question Answering (SQA) dataset which includes 76k human-spoken questions, - 97k machine-generated questions, and their corresponding textual answers from the SQuAD QA dataset. - https://arxiv.org/abs/2304.13689 - - This dataset is the human-spoken version of HeySQuAD. - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/HeySQuAD_human", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - """ - Example conversation - <|user|> Context: {context} - Question: {question} - <|assistant|> {answer} - """ - if row["is_impossible"] or not row["answers"]: - # Skip samples with no answer - return None - - prompt = self._get_query_prompt( - question_str=row["question"], context=row["context"] - ) - if prompt is None: - # Skips samples with long context - return None - - messages = _get_messages(prompt, row["answers"][0]["text"]) - return self._make_sample( - messages, self._get_audio(row), audio_transcript=row["question"] - ) - - -class SlueSQA5Dataset(QAVoiceDatasetMixin): - """ - SLUE-SQA-5 Dataset contains question texts, question audio, answer text, document text, and document audio from these datasets: - * SQuAD1.1 (for questions whose question_id starts with 'squad-') - * Natural Questions (for questions whose question_id starts with 'nq-') - * TriviaQA (for questions whose question_id starts with 'triviaqa-') - The following datasets are supposed to be included, but I haven't found them everywhere: - * WebQuestions (for questions whose question_id starts with 'wq-') - * CuratedTREC (for questions whose question_id starts with 'trec-') - * Spoken Wikipedia - - - Splits: train, validation, test, verified_test - """ - - BASE_AUDIO_COLUMNS = ["question_audio", "document_audio"] - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "asapp/slue-phase-2", "sqa5", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row: 
transformers.BatchFeature) -> Optional[VoiceSample]: - """ - Example conversation - <|user|> Context: {context} - Question: {question} - <|assistant|> {answer} - """ - prompt = self._get_query_prompt( - question_str=row["raw_question_text"], context=row["raw_document_text"] - ) - if prompt is None: - # Skips samples with long context - return None - - messages = _get_messages(prompt, row["answer_spans"]["answer"][0]) - return self._make_sample( - messages, - self._get_audio(row, "question_audio"), - audio_transcript=row["raw_question_text"], - ) - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class LibriSpeechDataset(VoiceDataset): - """ - LibriSpeech is a corpus of approximately 1000 hours of 16kHz read - English speech. The data is derived from read audiobooks from the - LibriVox project. A simple automatic procedure was used to select - the audio in the first two sets to be, on average, of higher - recording quality and with accents closer to US English. - https://huggingface.co/datasets/librispeech_asr - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - # TODO(juberti): convert to MDS, in a way that preserves the same - # concatenation of the three splits. MDS can interleave but not - # concatenate, it seems. - super().__init__(args) - ds: Any - if args.split == DatasetSplit.VALIDATION: - ds = self._load_audio_dataset("librispeech_asr", split="validation.clean") - else: - splits = ["train.clean.100", "train.clean.360", "train.other.500"] - ds = datasets.concatenate_datasets( - [ - self._load_audio_dataset("librispeech_asr", split=s, shuffle=False) - for s in splits - ] + self._config = config + split_names = [ + split.name + for split in config.splits + if split.is_validation == (self._args.split == DatasetSplit.VALIDATION) + ] + dsets = [] + total_samples = 0 + for split_name in split_names: + ds = self._load_audio_dataset(config.path, config.name, split=split_name) + ds = ds.cast_column( + config.audio_field, datasets.Audio(sampling_rate=SAMPLE_RATE) ) - if self._args.shuffle: - ds = ds.shuffle(seed=self._args.shuffle_seed) - self._init_dataset(ds) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text) - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class GigaSpeechDataset(VoiceDataset): - """ - GigaSpeech is an evolving, multi-domain English speech recognition corpus - with 10,000 hours of high quality labeled audio suitable for supervised training. - "s" split is 250 hours. Non-commercial use only. - https://huggingface.co/datasets/speechcolab/gigaspeech - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "speechcolab/gigaspeech", "xl", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text) - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class VoxPopuliDataset(VoiceDataset): - """ - VoxPopuli is a large-scale multilingual speech corpus for representation learning, - semi-supervised learning and interpretation. - "en" split is 543 hours. 
- https://huggingface.co/datasets/facebook/voxpopuli - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "facebook/voxpopuli", "en", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tcol="raw_text") - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class CommonVoiceDataset(VoiceDataset): - """ - The Common Voice dataset consists of a unique MP3 and corresponding text file - https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1 - Dataset({ - features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'], - num_rows: 1090061 - }) - NOTE: requires HF login - """ - - def __init__(self, args: VoiceDatasetArgs, lang: str = "en") -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "mozilla-foundation/common_voice_16_1", lang, split=args.split.value - ) - self._init_dataset(dataset) + dsets.append(ds) + total_samples += len(ds) + dataset = datasets.concatenate_datasets(dsets) + super()._init_dataset(dataset, total_samples) def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tcol="sentence") - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class CoVoST2Dataset(VoiceDataset): - """ - CoVoST 2 is a large-scale multilingual speech translation corpus covering translations from 21 languages into English - and from English into 15 languages. The dataset is created using Mozilla's open-source Common Voice 4 database of - crowdsourced voice recordings. There are 2,900 hours of speech represented in the corpus. - - The original Hugging Face dataset link: https://huggingface.co/datasets/facebook/covost2 - Since this dataset requires audio files to be downloaded separately, a new dataset is created with the audio files: - https://huggingface.co/datasets/fixie-ai/covost2 - - Due to the scale of the dataset and the audio files being repeated, only a portion of the dataset was converted. - See [this issue](https://github.com/fixie-ai/ultravox/issues/50) for more information. - - Supported subsets (En -> X): - 'en_de', 'en_tr', 'en_fa', 'en_sv-SE', 'en_mn', 'en_zh-CN', 'en_cy', - 'en_ca', 'en_sl', 'en_et', 'en_id', 'en_ar', 'en_ta', 'en_lv', 'en_ja' - Supported subsets (X -> En): - 'fr_en', 'zh-CN_en', 'es_en' - """ - - CODE_TO_LANG = { - "en": "English", - "de": "German", - "tr": "Turkish", - "fa": "Persian", - "sv-SE": "Swedish", - "mn": "Mongolian", - "zh-CN": "Chinese", - "cy": "Welsh", - "ca": "Catalan", - "sl": "Slovenian", - "et": "Estonian", - "id": "Indonesian", - "ar": "Arabic", - "ta": "Tamil", - "lv": "Latvian", - "ja": "Japanese", - "fr": "French", - "es": "Spanish", - } - - # We currently don't use this dataset for training, so mainly the first prompt it ever used. - # The "no explanation" part is important, specially for evaluations, but it's not repeated - # in all prompts to avoid being too repetitive in training. - TRANSLATE_PROMPTS = [ - "Translate the following into {target}, without any explanation: <|audio|>", - "Translate the following into {target} language, no explanation needed: <|audio|>", - "Please convert the following into {target}. Be concise.\n<|audio|>", - "Could you translate this to {target} language? 
No commentary necessary.\n<|audio|>", - "Translate the text below to {target}.\n<|audio|>", - "Translate the subsequent text into {target} language. <|audio|>", - "Can you translate this into the {target} language?\n<|audio|>", - "Transform the following to {target}: <|audio|>", - ] - - def __init__(self, args: VoiceDatasetArgs, subset: str) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/covost2", subset, split=args.split.value - ) - langs = subset.split("_") - assert len(langs) == 2, f"Invalid subset: {subset}" - self.source_lang = self.CODE_TO_LANG[langs[0]] - self.target_lang = self.CODE_TO_LANG[langs[1]] - self._init_dataset(dataset) - - def _get_sample(self, row) -> VoiceSample: - prompt = self._choice(self.TRANSLATE_PROMPTS).format(target=self.target_lang) - - transcript = row["sentence"] - translation = row["translation"] - if not self._args.include_audio: - prompt = prompt.replace("<|audio|>", transcript) - - return self._make_sample( - _get_messages(prompt, translation), - self._get_audio(row), - audio_transcript=transcript, - ) - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class PeopleSpeechDataset(VoiceDataset): - """ - The People's Speech Dataset is among the world's largest English speech - recognition corpus. It includes 30,000+ hours of transcribed speech in - English languages with a diverse set of speakers. - https://huggingface.co/datasets/MLCommons/peoples_speech - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/peoples_speech", "clean", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tcol="text") - - -class SodaDataset(VoiceDataset): - BASE_AUDIO_COLUMNS = ["audio_second_last_turn"] - - SYS_PROMPTS = [ - "Follow the flow of the conversation and respond just like a human would in the same situation.", - "Engage in the conversation naturally, responding as a human would.", - "Follow the dialogue and reply like a person in that situation.", - "Participate in the chat and answer as if you were a human.", - "Interact smoothly and respond just like a person would.", - "Stay in the moment and reply as a human would in the conversation.", - "Flow with the discussion and respond naturally, as a person would.", - "Keep the dialogue going and answer like a human would.", - "Follow along and reply in a way a person would in the chat.", - "Stay engaged in the conversation and respond like a human.", - "Maintain the flow of the chat and answer just as a person would.", - ] - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/soda-audio", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> VoiceSample: - turns = row["dialogue"] - # Make sure the last turn is the assistant's - roles = ["user", "assistant"] if len(turns) % 2 == 0 else ["assistant", "user"] - - sys_prompt = self._choice(self.SYS_PROMPTS) - - messages = _get_messages(*turns[:-1], sys_prompt=sys_prompt) - - messages[-1]["content"] = row["alt_last_turn"] - if self._args.include_audio: - messages[-2]["content"] = "<|audio|>" - - return self._make_sample( - messages, - audio=self._get_audio(row, "audio_second_last_turn"), - audio_transcript=turns[-2], - ) - - -class GenericVoiceDataset(VoiceDataset): - def __init__( - 
self, args: VoiceDatasetArgs, config: dataset_config.DataDictConfig - ) -> None: - super().__init__(args) - dataset = datasets.concatenate_datasets( - [ - self._load_audio_dataset( - config.path, - name=config.name, - split=s, - streaming=config.streaming, - shuffle=False, - ) - for s in config.splits - ] - ) - # shuffling is only supported on huggingface datasets for now, not MDS - if self._args.shuffle: - dataset = dataset.shuffle(seed=self._args.shuffle_seed) - - if config.num_samples: - dataset = Range(dataset, config.num_samples, config.total_samples) - - self._weight = config.weight - - self.user_template = config.user_template - self.assistant_template = config.assistant_template - self.transcript_template = config.transcript_template - - super()._init_dataset(dataset, config.total_samples) - - def _get_sample(self, row) -> VoiceSample: try: user_content = jinja2.Template( - self.user_template, undefined=jinja2.StrictUndefined - ).render(**row, text_proc=text_proc, dataset=self) + self._config.user_template, undefined=jinja2.StrictUndefined + ).render( + **row, + text_proc=text_proc, + dataset=self, + **self._config.user_template_args, + ) assistant_content = jinja2.Template( - self.assistant_template, undefined=jinja2.StrictUndefined + self._config.assistant_template, undefined=jinja2.StrictUndefined ).render(**row, text_proc=text_proc, dataset=self) transcript = jinja2.Template( - self.transcript_template, undefined=jinja2.StrictUndefined + self._config.transcript_template, undefined=jinja2.StrictUndefined ).render(**row, text_proc=text_proc, dataset=self) except jinja2.TemplateError as e: print(f"Error rendering template: {e}") - print(f"user_template: {self.user_template}") - print(f"assistant_template: {self.assistant_template}") - print(f"transcript_template: {self.transcript_template}") + print(f"user_template: {self._config.user_template}") + print(f"assistant_template: {self._config.assistant_template}") + print(f"transcript_template: {self._config.transcript_template}") print(f"sample keys: {list(row.keys())}") raise ValueError( - f"Template rendering failed. Make sure all keys in the template exist in the sample." + "Template rendering failed. Make sure all keys in the template exist in the sample." ) from e return self._make_sample( _get_messages(user_content, assistant_content), - self._get_audio(row), + self._get_audio(row, self._config.audio_field), audio_transcript=transcript, ) -def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: - DATASET_MAP: Dict[str, Any] = { - "anyinstruct": AnyInstructAnswerDataset, - "anyinstruct_in": AnyInstructInputDataset, - "anyinstruct_out": AnyInstructOutputDataset, - "boolq": BoolQDataset, - "boolq_in": BoolQInputDataset, - "boolq_extended": BoolQWithExtendedAnswerDataset, - "heysquad_human": HeySQuADHumanDataset, - "slue_sqa5": SlueSQA5Dataset, - "gigaspeech": GigaSpeechDataset, - "librispeech": LibriSpeechDataset, - "voxpopuli": VoxPopuliDataset, - "commonvoice": CommonVoiceDataset, - "covost2": CoVoST2Dataset, - "peoplespeech": PeopleSpeechDataset, - "soda": SodaDataset, - "dummy": LibriSpeechDummyDataset, - } - if isinstance(name, dataset_config.DataDictConfig): - return GenericVoiceDataset(args, name) - else: - name, *ext = name.split(":") - return DATASET_MAP[name](args, *ext) +# Making EmptyDataset a SizedIterableDataset to be compatible with using epochs during training. 
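+# It reports a length of 0 and yields no samples, so it can be used anywhere a
+# SizedIterableDataset is required but no data is available.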
+class EmptyDataset(SizedIterableDataset): + def __iter__(self): + return iter([]) + + def __len__(self): + return 0 + + +DATASET_MAP: Dict[str, Any] = {} + + +def register_datasets(datasets: Dict): + for dataset in datasets: + DATASET_MAP[dataset] = create_dataset(dataset, datasets[dataset]) + + +def create_dataset( + args: VoiceDatasetArgs, config: DatasetConfig +) -> SizedIterableDataset: + configs = [] + while True: + configs.append(config) + base = config.get("base") + if not base: + break + config = base + merged_config = configs[-1] + for config in configs[:-1]: + merged_config.update(config) + del merged_config["base"] + return GenericDataset(args, merged_config) class StopStrategy(str, Enum): - FIRST_EXHAUSTED = "first_exhausted" - LAST_EXHAUSTED = "last_exhausted" - NEVER_STOP = "never_stop" + FIRST_EXHAUSTED = "FIRST_EXHAUSTED" + LAST_EXHAUSTED = "LAST_EXHAUSTED" + NEVER_STOP = "NEVER_STOP" + + +@dataclasses.dataclass +class DatasetAndWeight: + dataset: SizedIterableDataset + weight: float class InterleaveDataset(SizedIterableDataset): @@ -1125,7 +557,7 @@ class InterleaveDataset(SizedIterableDataset): def __init__( self, - datasets: Sequence[SizedIterableDataset], + datasets: Sequence[DatasetAndWeight], stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, static: bool = False, @@ -1137,13 +569,12 @@ def __init__( seed: Optional seed for reproducibility. static: If true, the datasets are interleaved in a static order with equal weights. """ - self._datasets = datasets + self._datasets = [dataset for dataset, _ in datasets] self._rng = np.random.default_rng(seed) self._static = static - self._stop_strategy = stop_strategy - weights = [getattr(ds, "weight", 1) for ds in datasets] + weights = [weight for _, weight in datasets] total_weight = sum(weights) self._normalized_probs = [w / total_weight for w in weights] @@ -1235,8 +666,4 @@ def __iter__(self): yield sample def __len__(self): - return ( - self._num_samples - if self._num_samples is not None - else self._estimated_length - ) + return self._length diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 9ba67054..ce67a0f1 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -17,8 +17,12 @@ @dataclasses.dataclass class TrainConfig: - data_sets: List[str] - val_sets: List[str] + # data-defined datasets + datasets: List[Dict] + # training sets and weights + train_sets: Dict[str, float] + # validation sets and weights + val_sets: Dict[str, float] # language model to use text_model: str # audio encoder model to use diff --git a/ultravox/training/configs/llama_whisper.yaml b/ultravox/training/configs/llama_whisper.yaml new file mode 100644 index 00000000..8e7c8b3e --- /dev/null +++ b/ultravox/training/configs/llama_whisper.yaml @@ -0,0 +1,105 @@ +# llama3.1-8b + whisper-medium, for development + +exp_name: "llama3.1-8b-whisper" +text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct" +audio_model: "openai/whisper-medium" + +loss_config: + loss_function: "KL_Divergence" + +max_steps: 20 # x8x24 = 2,764,800 + +# This would go in a datasets.yaml file and we could either use pyyaml-include to include it +# or we could just add this logic to the training script. This file can also include its own datasets +# key, with locally defined datasets. 
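+#
+# A rough sketch (assumed wiring, not part of this patch) of how the training
+# script could consume this file using register_datasets() from datasets.py:
+#
+#   import yaml
+#   from ultravox.data import datasets
+#
+#   with open("datasets.yaml") as f:
+#       datasets.register_datasets(yaml.safe_load(f)["datasets"])
+#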
+datasets: + librispeech: + path: "fixie-ai/librispeech_asr" + user_template: "<|audio|>" + assistant_template: "" + transcript_template: "{{ text }}" + + # Note the inheritance here + librispeech-clean: + base: "librispeech" # this could also be done via "<<": *librispeech, although that approach is less flexible + name: "clean" + splits: + - "train.100" # 28_539 samples + num_samples: 28_539 + - "train.360" # 104_014 samples + num_samples: 104_014 + + librispeech-other: + base: "librispeech" + subset: "other" + splits: + - "train.500" # 148_688 samples + num_samples: 148_688 + + covost2: + path: "fixie-ai/covost2" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + + # Note the inheritance here + covost2-es_en: + base: "covost2" + subset: "es_en" + splits: + - name: "train" + num_samples: 100000 + - name: "validation" + num_samples: 15531 + user_template_args: + target: "English" + + covost2-en_zh: + base: "covost2" + subset: "en_zh-CN" + splits: + - name: "train" + num_samples: 100000 + - name: "validation" + num_samples: 15531 + user_template_args: + target: "Chinese" + + covost-foo: + base: "covost2" + subset: "foo" + splits: + - name: "eval" + num_samples: 22222 + is_validation: true + + + covost-bar: + base: "covost2" + subset: "foo" + splits: + - name: "eval2" + num_samples: 22222 + is_validation: true + + covost-small: + base: "covost2" + subset: "foo" + splits: + - name: "eval3" + num_samples: 22 + is_validation: true + + +# This is the new approach to weighting, which keeps this out of the dataset config +train_datasets: + librispeech-clean: 0.5 + librispeech-other: 2.0 + covost2-es_en: 1.0 + covost2-en_zh: 1.0 + +val_datasets: + covost2-es_en: 1.0 + covost2-en_zh: 1.0 + covost-foo: 1.0 + covost-bar: 1.0 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 5dde3f62..98c5cda9 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -142,8 +142,8 @@ def train(args: config_base.TrainConfig): with model_load_context: model = ultravox_model.UltravoxModel(config) - assert model.get_input_embeddings().num_embeddings == len( - text_tokenizer + assert ( + model.get_input_embeddings().num_embeddings == len(text_tokenizer) ), f"Model and tokenizer mismatch: {model.get_input_embeddings().num_embeddings} != {len(text_tokenizer)}" model.language_model.config.use_cache = False @@ -197,6 +197,9 @@ def train(args: config_base.TrainConfig): model.to(device=torch.device(args.device, index=local_rank)) logging.info(f"Using device (world_size): {model.device} ({world_size})") + # Register custom datasets + datasets.register_datasets(args.datasets) + # Prepare dataset, subsetting if needed train_dataset: data.IterableDataset val_datasets: Dict[str, data.IterableDataset] @@ -205,12 +208,12 @@ def train(args: config_base.TrainConfig): # called "matchtrain" that uses the same data as the training set. 
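+    # Each val set is registered twice: once as-is and once under a "text_"
+    # prefix, presumably for a text-only variant of the same data.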
val_sets = dict( # [("matchtrain", args.data_sets)] # FIXME: see issue https://github.com/fixie-ai/ultravox/issues/58 - [(x, [x]) for x in args.val_sets] - + [(f"text_{x}", [x]) for x in args.val_sets] + [(x, [x]) for x in args.val_sets] + [(f"text_{x}", [x]) for x in args.val_sets] ) train_dataset = prepare_dataset( train_args=args, - dataset_names=args.data_sets, + datasets=args.datasets, + train_sets=args.train_sets, train_on_inputs=args.train_on_inputs, stop_strategy=args.stop_strategy, processor=processor, From cbeddf4d0584e9072d8d228e0a3699dd1f858c0f Mon Sep 17 00:00:00 2001 From: juberti Date: Tue, 15 Oct 2024 12:42:35 -0700 Subject: [PATCH 02/17] v2 --- ultravox/data/datasets.py | 31 +++---------------------------- ultravox/training/train.py | 4 +--- 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 670db2e2..35c05ff1 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -27,8 +27,6 @@ import enum from ultravox.data import text_proc -from ultravox.evaluation.eval_types import EvalConfig -from ultravox.utils import device_helpers SAMPLE_RATE = 16000 @@ -52,6 +50,8 @@ class VoiceDatasetArgs: batch_size: int = 4 """Batch size for train, eval, or validation.""" + include_audio: bool = True + """Whether to include audio in the samples.""" shuffle: bool = False """Whether to shuffle the dataset.""" shuffle_seed: int = 42 @@ -416,38 +416,13 @@ def _get_audio( assert sampling_rate == SAMPLE_RATE return audio - def _load_audio( - self, base_url: Optional[str], data_dir: Optional[str], filename: str - ) -> np.ndarray: - if base_url is not None: - url = f"{base_url}/{filename}" # hack for GCS bucket naming - try: - with closing(requests.Session()) as session: - response = session.get(url) - response.raise_for_status() - return audio_from_buf(response.content) - except requests.RequestException as e: - raise ValueError( - f"Failed to load audio from URL: {url}. Error: {str(e)}" - ) - elif data_dir is not None: - audio_path = os.path.join(data_dir, filename) - try: - return audio_from_file(audio_path) - except IOError as e: - raise ValueError( - f"Failed to load audio file: {audio_path}. 
Error: {str(e)}" - ) - else: - raise ValueError("Either base_url or data_dir must be provided") - def _make_sample( self, messages: List[Dict[str, str]], audio: np.ndarray, audio_transcript: Optional[str] = None, ) -> VoiceSample: - if self._config.audio_field is None: + if not self._args.include_audio: return VoiceSample(messages) return VoiceSample(messages, audio, audio_transcript=audio_transcript) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 98c5cda9..5a46e7ee 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -37,7 +37,6 @@ def prepare_dataset( train_args: config_base.TrainConfig, - dataset_names: List[str], data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, @@ -212,8 +211,7 @@ def train(args: config_base.TrainConfig): ) train_dataset = prepare_dataset( train_args=args, - datasets=args.datasets, - train_sets=args.train_sets, + datasets=args.train_sets, train_on_inputs=args.train_on_inputs, stop_strategy=args.stop_strategy, processor=processor, From 850b1fc9a1e1a71e6c75581e3f71addb2288ddb7 Mon Sep 17 00:00:00 2001 From: juberti Date: Tue, 15 Oct 2024 12:44:05 -0700 Subject: [PATCH 03/17] sr --- ultravox/data/datasets.py | 7 ++----- ultravox/training/train.py | 9 +++++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 35c05ff1..2d6dadd1 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -1,21 +1,20 @@ import abc import base64 import dataclasses +import enum import io import itertools import logging import os import tempfile import warnings -from contextlib import closing from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence import datasets import jinja2 import librosa import numpy as np -import requests import soundfile as sf import streaming as mds import torch @@ -24,8 +23,6 @@ from pydantic import BaseModel from torch.utils import data -import enum - from ultravox.data import text_proc SAMPLE_RATE = 16000 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 5a46e7ee..1206083e 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -7,7 +7,7 @@ import os import subprocess from datetime import datetime -from typing import Dict, List, Optional +from typing import Dict, Optional import accelerate import datasets as hf_datasets @@ -141,8 +141,8 @@ def train(args: config_base.TrainConfig): with model_load_context: model = ultravox_model.UltravoxModel(config) - assert ( - model.get_input_embeddings().num_embeddings == len(text_tokenizer) + assert model.get_input_embeddings().num_embeddings == len( + text_tokenizer ), f"Model and tokenizer mismatch: {model.get_input_embeddings().num_embeddings} != {len(text_tokenizer)}" model.language_model.config.use_cache = False @@ -207,7 +207,8 @@ def train(args: config_base.TrainConfig): # called "matchtrain" that uses the same data as the training set. 
val_sets = dict( # [("matchtrain", args.data_sets)] # FIXME: see issue https://github.com/fixie-ai/ultravox/issues/58 - [(x, [x]) for x in args.val_sets] + [(f"text_{x}", [x]) for x in args.val_sets] + [(x, [x]) for x in args.val_sets] + + [(f"text_{x}", [x]) for x in args.val_sets] ) train_dataset = prepare_dataset( train_args=args, From ff5910922771fc9700b1bd7fb93d7982adbfe082 Mon Sep 17 00:00:00 2001 From: juberti Date: Tue, 15 Oct 2024 22:24:36 -0700 Subject: [PATCH 04/17] sr --- ultravox/data/dataset_config.py | 23 -- ultravox/data/datasets.py | 276 ++++++++++--------- ultravox/data/datasets_test.py | 168 +++++++---- ultravox/tools/data_tool.py | 10 +- ultravox/tools/infer_api.py | 2 +- ultravox/tools/infer_tool.py | 10 - ultravox/tools/push_to_hub.py | 2 +- ultravox/training/config_base.py | 25 +- ultravox/training/configs/llama_whisper.yaml | 44 ++- ultravox/training/evaluation.py | 3 - ultravox/training/train.py | 66 ++--- 11 files changed, 304 insertions(+), 325 deletions(-) delete mode 100644 ultravox/data/dataset_config.py diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py deleted file mode 100644 index ea1f0a60..00000000 --- a/ultravox/data/dataset_config.py +++ /dev/null @@ -1,23 +0,0 @@ -import dataclasses -from typing import List, Optional - -from pydantic import BaseModel - - -class DataDictConfig(BaseModel): - # Path to the dataset, or huggingface dataset id - path: str - # Name of the dataset, or huggingface dataset config/subset - name: Optional[str] = None - splits: List[str] = dataclasses.field(default_factory=list) - num_samples: Optional[int] = None - total_samples: int = 1 - weight: float = 1.0 - streaming: bool = True - user_template: str = "<|audio|>" - assistant_template: str = "{{text}}" - transcript_template: str = "{{text}}" - - def __post_init__(self): - if not self.splits: - raise ValueError("At least one split must be provided") diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 2d6dadd1..9d2dd2f0 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -8,10 +8,9 @@ import os import tempfile import warnings -from enum import Enum from typing import Any, Dict, List, Optional, Sequence -import datasets +import datasets as hf_datasets import jinja2 import librosa import numpy as np @@ -20,7 +19,6 @@ import torch import torch.nn.functional as F import transformers -from pydantic import BaseModel from torch.utils import data from ultravox.data import text_proc @@ -63,7 +61,8 @@ def __post_init__(self): self.split = DatasetSplit(self.split.lower()) -class DatasetSplitConfig(BaseModel): +@dataclasses.dataclass +class DatasetSplitConfig: name: str """Name of the split""" num_samples: int @@ -75,35 +74,30 @@ def __post_init__(self): self.is_validation = True -class DatasetConfig(BaseModel): +@dataclasses.dataclass +class DatasetConfig: + base: Optional[str] = None + """Base dataset config to inherit from.""" path: str = "" """Directory of the dataset, or huggingface dataset name; must be set for "generic" datasets. If not set, it is automatically inferred for predefined dataset types.""" subset: Optional[str] = None - """Name of the dataset, or huggingface dataset config/subset name""" - splits: List[DatasetSplit] = [] - """List of splits to use, e.g. 
[{"name": "train", "num_samples": 1000}, {"name": "validation", "num_samples": 100}]""" + """Name of the dataset, or huggingface dataset config/subset name.""" + splits: List[DatasetSplitConfig] = dataclasses.field(default_factory=list) + """List of splits to use, e.g. [{"name": "train", "num_samples": 1000}, {"name": "validation", "num_samples": 100}].""" user_template: str = "<|audio|>" - """Template for the user's message""" - user_template_args: Dict[str, str] = {} - """Optional arguments (e.g., target language) for the user template""" + """Template for the user message.""" + user_template_args: Dict[str, str] = dataclasses.field(default_factory=dict) + """Optional arguments (e.g., target language) for the user template.""" assistant_template: str = "{{text}}" - """Template for the assistant's message""" + """Template for the assistant message.""" transcript_template: str = "{{text}}" - """Template for the transcript""" + """Template for the transcript.""" audio_field: Optional[str] = "audio" - """Field in the dataset that contains the audio, use None if the dataset does not contain audio""" + """Field in the dataset that contains the audio, use None if the dataset does not contain audio.""" use_mds: bool = False - """Set to True to load the dataset from GCP (using MDS) instead of Hugging Face""" + """Set to True to load the dataset from GCP (using MDS) instead of Hugging Face.""" mds_batch_size: int = 32 - """Batch size for MDS""" - - class Config: - extra = "forbid" - # do not allow undefined parameters - - def model_post_init(self, __context: Any) -> None: - if not self.splits: - raise ValueError("At least one split must be provided") + """Batch size for the dataset when using MDS.""" @dataclasses.dataclass @@ -291,10 +285,6 @@ def __init__(self, args: VoiceDatasetArgs) -> None: super().__init__() self._args = args self._rng = np.random.default_rng(self._args.shuffle_seed) - if True: # device_helpers.get_local_rank() == 0: - logging.info( - f"Created VoiceDataset with config:\n{self._config.model_dump_json(indent=2)}" - ) def _init_dataset(self, dataset: data.Dataset, num_samples: int) -> None: self._dataset = dataset @@ -303,47 +293,56 @@ def _init_dataset(self, dataset: data.Dataset, num_samples: int) -> None: def __len__(self): return self._length - def _load_audio_dataset( + def _load_hf_dataset( self, path: str, name: Optional[str] = None, *, split: Optional[str] = None, - shuffle: Optional[bool] = None, streaming: bool = True, + audio_field: Optional[str] = None, ) -> data.Dataset: - if shuffle is None: - shuffle = self._args.shuffle - if self._args.use_mds: - gcs_path = path.replace("/", "_") - if name: - gcs_path += f"/{name}" - if split: - gcs_path += f"/{split}" - url = f"gs://fixie-datasets/mds/{gcs_path}" - temp_dir = os.path.join( - tempfile.gettempdir(), f"mds_{gcs_path.replace('/', '_')}" - ) - return mds.StreamingDataset( - remote=url, - local=temp_dir, - batch_size=self._args.mds_batch_size, - shuffle=shuffle, - shuffle_seed=self._args.shuffle_seed, - ) - else: - # HF datasets sometimes fails to download due to network issues, so retry a few times. - dataset = datasets.load_dataset( - path, - name, - split=split, - trust_remote_code=True, - streaming=streaming, - download_config=datasets.DownloadConfig(max_retries=10), + # HF datasets sometimes fails to download due to network issues, so retry a few times. 
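+        # DownloadConfig(max_retries=10) below handles the retries at the HTTP level.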
+ dataset = hf_datasets.load_dataset( + path, + name, + split=split, + trust_remote_code=True, + streaming=streaming, + download_config=hf_datasets.DownloadConfig(max_retries=10), + ) + if audio_field is not None: + dataset = dataset.cast_column( + audio_field, hf_datasets.Audio(sampling_rate=SAMPLE_RATE) ) - if shuffle: - dataset = dataset.shuffle(seed=self._args.shuffle_seed) - return dataset + if self._args.shuffle: + dataset = dataset.shuffle(seed=self._args.shuffle_seed) + return dataset + + def _load_mds_dataset( + self, + path: str, + name: Optional[str] = None, + *, + split: Optional[str] = None, + batch_size: int = 1, + ) -> data.Dataset: + gcs_path = path.replace("/", "_") + if name: + gcs_path += f"/{name}" + if split: + gcs_path += f"/{split}" + url = f"gs://fixie-datasets/mds/{gcs_path}" + temp_dir = os.path.join( + tempfile.gettempdir(), f"mds_{gcs_path.replace('/', '_')}" + ) + return mds.StreamingDataset( + remote=url, + local=temp_dir, + batch_size=batch_size, + shuffle=self._args.shuffle, + shuffle_seed=self._args.shuffle_seed, + ) def __iter__(self): actual_length = 0 @@ -354,23 +353,20 @@ def __iter__(self): f"Sample is None in dataset {self._config.alias} for row {row}" ) - if self._config.audio_field is not None: + if self._args.include_audio: # If audio_field is set, make sure the sample has audio data. if sample.audio is None: - raise ValueError( - f"Audio field ({self._config.audio_field}) is None in dataset {self._config.alias} for sample {sample}" - ) + raise ValueError(f"Audio is None for sample {sample}") if sample.audio.shape[-1] == 0: - raise ValueError( - f"Audio length is 0 in dataset {self._config.alias} for sample {sample}" - ) + raise ValueError(f"Audio length is 0 for sample {sample}") if ( self._args.max_audio_duration_secs is not None and sample.audio.shape[-1] / SAMPLE_RATE > self._args.max_audio_duration_secs ): + duration = sample.audio.shape[-1] / SAMPLE_RATE warnings.warn( - f"Audio length ({sample.audio.shape[-1] / SAMPLE_RATE}s) exceeds max audio duration ({self._args.max_audio_duration_secs}s) in dataset {self._config.alias}, skipping sample." + f"Audio length ({duration}s) exceeds max audio duration ({self._args.max_audio_duration_secs}s), skipping sample." ) continue @@ -378,11 +374,11 @@ def __iter__(self): actual_length += 1 if actual_length == len(self) + 1: warnings.warn( - f"The presumed length {self._length} has been exceeded for dataset {self._config.alias}. Make sure to update." + f"The presumed length {self._length} has been exceeded for split {self._dataset.split}. Make sure to update." ) if actual_length != len(self): warnings.warn( - f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for dataset {self._config.alias}. Make sure to update." + f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for split {self._dataset.split}. Make sure to update." ) @abc.abstractmethod @@ -395,11 +391,6 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: def _get_audio( self, row: transformers.BatchFeature, column_name: Optional[str] = "audio" ) -> np.ndarray: - if column_name not in self._config.base_audio_columns: - raise ValueError( - f"Unknown audio column: {column_name}. This is likely a bug and the audio might not be resampled to {SAMPLE_RATE} Hz." - ) - # Hugging Face datasets have an Audio object, with array and sampling_rate fields. # For MDS, this object is flattened into audio_array and audio_sampling_rate fields. 
if column_name in row: @@ -413,6 +404,11 @@ def _get_audio( assert sampling_rate == SAMPLE_RATE return audio + def _make_messages( + self, user_content: str, assistant_content: str + ) -> List[Dict[str, str]]: + return _get_messages(user_content, assistant_content) + def _make_sample( self, messages: List[Dict[str, str]], @@ -428,22 +424,30 @@ class GenericDataset(VoiceDataset): def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: super().__init__(args) self._config = config - split_names = [ - split.name - for split in config.splits - if split.is_validation == (self._args.split == DatasetSplit.VALIDATION) - ] dsets = [] total_samples = 0 - for split_name in split_names: - ds = self._load_audio_dataset(config.path, config.name, split=split_name) - ds = ds.cast_column( - config.audio_field, datasets.Audio(sampling_rate=SAMPLE_RATE) - ) - dsets.append(ds) - total_samples += len(ds) - dataset = datasets.concatenate_datasets(dsets) + for split in config.splits: + if split.is_validation == (self._args.split == DatasetSplit.VALIDATION): + if not config.use_mds: + ds = self._load_hf_dataset( + config.path, + config.subset, + split=split.name, + audio_field=config.audio_field, + ) + else: + ds = self._load_mds_dataset( + config.path, + name=config.subset, + split=split.name, + batch_size=config.mds_batch_size, + ) + dsets.append(ds) + total_samples += split.num_samples + dataset = ds if len(dsets) == 1 else hf_datasets.concatenate_datasets(dsets) super()._init_dataset(dataset, total_samples) + if True: # device_helpers.get_local_rank() == 0: + logging.info(f"Created GenericDataset with config:\n{self._config}") def _get_sample(self, row) -> Optional[VoiceSample]: try: @@ -453,6 +457,7 @@ def _get_sample(self, row) -> Optional[VoiceSample]: **row, text_proc=text_proc, dataset=self, + include_audio=self._args.include_audio, **self._config.user_template_args, ) assistant_content = jinja2.Template( @@ -478,58 +483,68 @@ def _get_sample(self, row) -> Optional[VoiceSample]: ) -# Making EmptyDataset a SizedIterableDataset to be compatible with using epochs during training. 
class EmptyDataset(SizedIterableDataset): + def __init__(self, length: int = 1) -> None: + self._length = length + def __iter__(self): return iter([]) def __len__(self): - return 0 + return self._length -DATASET_MAP: Dict[str, Any] = {} +BOOLQ_CONFIG = DatasetConfig( + path="fixie-ai/boolq-audio", + splits=[ + DatasetSplitConfig(name="train", num_samples=10000), + DatasetSplitConfig(name="validation", num_samples=1000), + ], + user_template="{{passage}}\n\n{{'<|audio|>' if include_audio else question}}", + assistant_template="{{'True' if answer else 'False'}}", + transcript_template="{{question}}", +) +DATASET_MAP: Dict[str, DatasetConfig] = { + "boolq": BOOLQ_CONFIG, +} -def register_datasets(datasets: Dict): - for dataset in datasets: - DATASET_MAP[dataset] = create_dataset(dataset, datasets[dataset]) +def register_datasets(datasets: Dict[str, DatasetConfig]): + for name, config in datasets.items(): + DATASET_MAP[name] = config -def create_dataset( - args: VoiceDatasetArgs, config: DatasetConfig -) -> SizedIterableDataset: + +def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: + assert name in DATASET_MAP, f"Unknown dataset: {name}" configs = [] - while True: + temp: Optional[str] = name + while temp: + config = DATASET_MAP[temp] configs.append(config) - base = config.get("base") - if not base: - break - config = base - merged_config = configs[-1] - for config in configs[:-1]: - merged_config.update(config) - del merged_config["base"] + temp = config.base + merged_config = dataclasses.replace(configs[-1]) + for config in reversed(configs[:-1]): + merged_config = dataclasses.replace(merged_config, **dataclasses.asdict(config)) + merged_config = dataclasses.replace(merged_config, base=None) + if not merged_config.splits: + raise ValueError(f"Dataset {name} has no splits") return GenericDataset(args, merged_config) -class StopStrategy(str, Enum): +class StopStrategy(str, enum.Enum): FIRST_EXHAUSTED = "FIRST_EXHAUSTED" LAST_EXHAUSTED = "LAST_EXHAUSTED" NEVER_STOP = "NEVER_STOP" -@dataclasses.dataclass -class DatasetAndWeight: - dataset: SizedIterableDataset - weight: float - - class InterleaveDataset(SizedIterableDataset): """Interleaves multiple IterableDataset objects based on normalized weights.""" def __init__( self, - datasets: Sequence[DatasetAndWeight], + datasets: Sequence[SizedIterableDataset], + weights: Optional[Sequence[float]] = None, stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, static: bool = False, @@ -537,16 +552,18 @@ def __init__( """ Args: datasets: A list of SizedIterableDataset objects. + weights: A list of weights for each dataset. stop_strategy: Strategy for stopping iteration. seed: Optional seed for reproducibility. static: If true, the datasets are interleaved in a static order with equal weights. 
""" - self._datasets = [dataset for dataset, _ in datasets] + self._datasets = datasets self._rng = np.random.default_rng(seed) self._static = static self._stop_strategy = stop_strategy - weights = [weight for _, weight in datasets] + if weights is None: + weights = [1.0] * len(datasets) total_weight = sum(weights) self._normalized_probs = [w / total_weight for w in weights] @@ -609,31 +626,16 @@ class Range(SizedIterableDataset): """Limits the number of samples from another dataset.""" def __init__( - self, - dataset: data.IterableDataset, - num_samples: Optional[int] = None, - total_samples: Optional[int] = None, + self, dataset: SizedIterableDataset, num_samples: Optional[int] = None ) -> None: self._dataset = dataset - self._num_samples = num_samples - - if isinstance(self._dataset, SizedIterableDataset): - self._estimated_length = len(self._dataset) - else: - if total_samples is None: - raise ValueError( - "total_samples must be provided for non-SizedIterableDataset." - ) - self._estimated_length = total_samples - - if self._num_samples is not None and self._num_samples > self._estimated_length: - # Issuing a warning here instead of raising an error to accomodate for specific classes of VoiceDataset - # Once we migrate entirely to GenericVoiceDataset, we can raise an error here. - warnings.warn("num_samples is greater than total_samples.") + self._length = num_samples or len(dataset) + if self._length > len(dataset): + raise ValueError("num_samples exceeds dataset length.") def __iter__(self): for i, sample in enumerate(self._dataset): - if self._num_samples is not None and i >= self._num_samples: + if i >= self._length: break yield sample diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index faacd3fa..7d861980 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -8,28 +8,22 @@ from torch.utils import data from transformers.feature_extraction_utils import BatchFeature -from ultravox.data import dataset_config from ultravox.data import datasets class FakeSizedIterableDataset(datasets.SizedIterableDataset): """Fake version of datasets.SizedIterableDataset""" - def __init__(self, n, start=0, weight=1, estimated_length=0): + def __init__(self, n, start=0, length=0): self.data = range(start, start + n) - self._weight = weight - self._estimated_length = estimated_length - - @property - def weight(self) -> float: - return self._weight + self._length = length def __iter__(self): for sample in self.data: yield sample def __len__(self): - return self._estimated_length + return self._length class FakeHuggingFaceIterableDataset(hf_datasets.IterableDataset): @@ -43,37 +37,48 @@ def __init__(self, n): } for i in range(n) ] + self._split = "fake" def __iter__(self): return (i for i in self.data) class FakeTranscribeDataset(datasets.VoiceDataset): - """Fake version of our VoiceDataset using a transcribe prompt.""" + """Fake version of our VoiceDataset.""" def __init__(self, n: int, args: Optional[datasets.VoiceDatasetArgs] = None): - super().__init__(args or datasets.VoiceDatasetArgs()) - + super().__init__( + args or datasets.VoiceDatasetArgs(), + ) self._init_dataset(FakeHuggingFaceIterableDataset(n), n) def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]: - return self._get_transcribe_sample(row) + messages = self._make_messages("<|audio|>", row["text"]) + return self._make_sample(messages, np.zeros(256), row["text"]) -class FakeGenericDataset(datasets.VoiceDataset): - """Fake version of GenericDataset.""" +class 
FakeGenericDataset(datasets.GenericDataset): + """Fake version of GenericDataset, hooked to return a FakeHuggingFaceIterableDataset.""" def __init__( self, n: int, - config: dataset_config.DataDictConfig, + config: datasets.DatasetConfig, args: Optional[datasets.VoiceDatasetArgs] = None, ): - super().__init__(args or datasets.VoiceDatasetArgs()) - self._init_dataset(FakeHuggingFaceIterableDataset(n), config.total_samples) + self._n = n + super().__init__(args or datasets.VoiceDatasetArgs(), config) - def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]: - return self._get_transcribe_sample(row) + def _load_hf_dataset( + self, + path: str, + name: Optional[str] = None, + *, + split: Optional[str] = None, + streaming: bool = True, + audio_field: Optional[str] = None, + ) -> data.Dataset: + return FakeHuggingFaceIterableDataset(self._n) class FakeDataproc(datasets.Dataproc): @@ -135,10 +140,11 @@ def test_interleaved_never_stop(): def test_interleaved_random(): - ds1 = FakeSizedIterableDataset(4, weight=10) - ds2 = FakeSizedIterableDataset(2, start=10, weight=1) + ds1 = FakeSizedIterableDataset(4) + ds2 = FakeSizedIterableDataset(2, start=10) s = datasets.InterleaveDataset( [ds1, ds2], + [10.0, 1.0], ) # stop_strategy=last_exhausted will stop interleaving when the last dataset is exhausted (attempted after exhaustion) assert list(s) == [ @@ -178,11 +184,13 @@ def test_interleaved_with_multiprocessing(): def test_range(): - ds = FakeSizedIterableDataset(10, estimated_length=10) + ds = FakeSizedIterableDataset(10, length=10) s = datasets.Range(ds, 5) assert len(s) == 5 assert list(s) == [0, 1, 2, 3, 4] - s = datasets.Range(ds, 100) + with pytest.raises(ValueError, match="exceeds dataset length"): + s = datasets.Range(ds, 100) + s = datasets.Range(ds, 10) assert list(s) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] s = datasets.Range(ds) assert len(s) == 10 @@ -195,7 +203,7 @@ def test_transcribe_dataset(): sample = next(iter(ds)) assert isinstance(sample, datasets.VoiceSample) assert sample.messages == [ - {"role": "user", "content": "Transcribe\n<|audio|>"}, + {"role": "user", "content": "<|audio|>"}, {"role": "assistant", "content": "0"}, ] assert np.array_equal(sample.audio, np.zeros(256)) @@ -203,22 +211,87 @@ def test_transcribe_dataset(): assert sample.audio_transcript == "0" -def test_num_prompts(): - ds = FakeTranscribeDataset(5, datasets.VoiceDatasetArgs(num_prompts=3)) - samples = list(ds) - assert samples[0].messages[0]["content"] == "Transcribe\n<|audio|>" - assert ( - samples[1].messages[0]["content"] - == "Repeat exactly what is written here: <|audio|>" +def test_generic_dataset(): + mock_config = datasets.DatasetConfig( + path="mock_path", + splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], ) - assert ( - samples[2].messages[0]["content"] - == "Transcribe exactly what is said here\n<|audio|>" + ds = FakeGenericDataset(5, mock_config) + assert len(ds) == 5 + sample = next(iter(ds)) + assert isinstance(sample, datasets.VoiceSample) + assert sample.messages == [ + {"role": "user", "content": "<|audio|>"}, + {"role": "assistant", "content": "0"}, + ] + assert np.array_equal(sample.audio, np.zeros(256)) + assert sample.sample_rate == 16000 + assert sample.audio_transcript == "0" + + +def test_generic_dataset_custom_templates(): + mock_config = datasets.DatasetConfig( + path="mock_path", + splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], + user_template="Listen to the following and respond with 'xyzzy':\n<|audio|>", + 
assistant_template="xyzzy", + transcript_template="{{text}}", ) - assert ( - samples[3].messages[0]["content"] - == "Transcribe exactly what is said here\n<|audio|>" + ds = FakeGenericDataset(5, mock_config) + assert len(ds) == 5 + sample = next(iter(ds)) + assert isinstance(sample, datasets.VoiceSample) + assert sample.messages == [ + { + "role": "user", + "content": "Listen to the following and respond with 'xyzzy':\n<|audio|>", + }, + {"role": "assistant", "content": "xyzzy"}, + ] + assert np.array_equal(sample.audio, np.zeros(256)) + assert sample.sample_rate == 16000 + assert sample.audio_transcript == "0" + + +def test_generic_dataset_length_mismatch(): + mock_config = datasets.DatasetConfig( + path="mock_path", + splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], ) + ds = FakeGenericDataset(10, mock_config) + assert len(ds) == 5 + + pattern = r"(has been exceeded|Mismatch between presumed length)" + with pytest.warns(UserWarning, match=pattern): + list(ds) + + mock_config = datasets.DatasetConfig( + path="mock_path", + splits=[datasets.DatasetSplitConfig(name="fake", num_samples=10)], + ) + ds = FakeGenericDataset(5, mock_config) + assert len(ds) == 10 + + with pytest.warns(UserWarning, match="Mismatch between presumed length"): + list(ds) + + +def test_generic_dataset_multiple_splits(): + mock_config = datasets.DatasetConfig( + path="mock_path", + splits=[ + datasets.DatasetSplitConfig(name="train", num_samples=90), + datasets.DatasetSplitConfig(name="validation", num_samples=10), + ], + ) + ds = FakeGenericDataset(100, mock_config) + assert len(ds) == 90 + ds = FakeGenericDataset( + 100, + mock_config, + datasets.VoiceDatasetArgs(split=datasets.DatasetSplit.VALIDATION), + ) + assert len(ds) == 10 def _create_sine_wave( @@ -294,7 +367,7 @@ def test_create_sample__float64(): def test_create_sample__raises_on_unsupported_dtype(): with pytest.raises(AssertionError): array = np.ndarray(shape=(16000,), dtype=np.uint8) - sample = datasets.VoiceSample.from_prompt_and_raw( + _ = datasets.VoiceSample.from_prompt_and_raw( "Transcribe\n<|audio|>", array, 16000 ) @@ -321,20 +394,3 @@ def test_get_messages(): {"role": "user", "content": "B"}, {"role": "assistant", "content": "C"}, ] - - -def test_voice_dataset_size(): - mock_config = dataset_config.DataDictConfig(path="mock_path", total_samples=5) - ds = FakeGenericDataset(10, mock_config) - assert len(ds) == 5 - - pattern = r"(has been exceeded|Mismatch between estimated length)" - with pytest.warns(UserWarning, match=pattern): - list(ds) - - mock_config = dataset_config.DataDictConfig(path="mock_path", total_samples=10) - ds = FakeGenericDataset(5, mock_config) - assert len(ds) == 10 - - with pytest.warns(UserWarning, match="Mismatch between estimated length"): - list(ds) diff --git a/ultravox/tools/data_tool.py b/ultravox/tools/data_tool.py index bc7a565a..c275d0c5 100644 --- a/ultravox/tools/data_tool.py +++ b/ultravox/tools/data_tool.py @@ -11,9 +11,6 @@ parser.add_argument( "--num-samples", "-n", type=int, default=5, help="Number of samples to display" ) -parser.add_argument( - "--num-prompts", type=int, default=1, help="Number of prompts to use" -) parser.add_argument("--play", "-p", action="store_true", help="Play the audio samples") parser.add_argument( "--write", "-w", action="store_true", help="Write audio samples out as WAV files" @@ -21,14 +18,11 @@ parser.add_argument("--playback-rate", "-r", type=float, help="Playback rate") parser.add_argument("--shuffle", "-s", action="store_true", help="Shuffle the samples") 
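 # Hypothetical invocation with the remaining flags (the module path and the
 # positional dataset argument are assumptions, not shown in this hunk):
 #   python -m ultravox.tools.data_tool boolq -n 2 --shuffle --seed 42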
parser.add_argument("--seed", type=int, help="Shuffle seed") -parser.add_argument("--mds", action="store_true", help="Use MDS datasets") def main(args: argparse.Namespace): data_args = datasets.VoiceDatasetArgs( - num_prompts=args.num_prompts, shuffle=args.shuffle, - use_mds=args.mds, split=args.data_split, ) if args.seed is not None: @@ -41,8 +35,8 @@ def main(args: argparse.Namespace): assert len(messages) >= 2, f"Bad sample (messages) {len(messages)}" assert messages[-2]["role"] == "user", f"Bad sample (Q role): {messages}" assert messages[-1]["role"] == "assistant", f"Bad sample (A role): {messages}" - answer = messages[-2]["content"].replace("\n", "\\n") - print(f"Q: {messages[-1]['content']} [\"{sample.audio_transcript}\"]") + answer = messages[-1]["content"].replace("\n", "\\n") + print(f"Q: {messages[-2]['content']} [\"{sample.audio_transcript}\"]") print(f"A: {answer}") if args.play: audio = sample.audio diff --git a/ultravox/tools/infer_api.py b/ultravox/tools/infer_api.py index 184ba11c..a1d217f6 100644 --- a/ultravox/tools/infer_api.py +++ b/ultravox/tools/infer_api.py @@ -84,7 +84,7 @@ def _build_messages(self, sample: datasets.VoiceSample): url = datasets.audio_to_data_uri(sample.audio, sample.sample_rate) parts = [ {"type": "text", "text": fragments[0]}, - {"type": "image_url", "image_url": {"url": url}}, + {"type": "audio_url", "audio_url": {"url": url}}, {"type": "text", "text": fragments[1]}, ] last_turn = {"role": "user", "content": parts} diff --git a/ultravox/tools/infer_tool.py b/ultravox/tools/infer_tool.py index 36ea4a0d..edc3cfa5 100644 --- a/ultravox/tools/infer_tool.py +++ b/ultravox/tools/infer_tool.py @@ -51,12 +51,6 @@ class InferArgs: data_split: datasets.DatasetSplit = simple_parsing.field( default=datasets.DatasetSplit.VALIDATION, alias="-s" ) - # Directory for existing data - data_dir: Optional[str] = None - # Use dataset context - context: bool = False - # Load datasets using MDS - mds: bool = False # Number of dataset samples to process num_samples: int = simple_parsing.field(default=1, alias="-n") # Shuffle the dataset @@ -181,12 +175,8 @@ def oneshot_infer(inference: base.VoiceInference, args: InferArgs): def dataset_infer(inference: base.VoiceInference, args: InferArgs): assert args.data_sets, "At least one data set must be provided" ds_args = datasets.VoiceDatasetArgs( - data_dir=args.data_dir, - prompt=args.prompt, include_audio=not args.text_only, - include_context=args.context, shuffle=args.shuffle, - use_mds=args.mds, split=args.data_split, ) if args.seed is not None: diff --git a/ultravox/tools/push_to_hub.py b/ultravox/tools/push_to_hub.py index 01945321..1a4669d0 100644 --- a/ultravox/tools/push_to_hub.py +++ b/ultravox/tools/push_to_hub.py @@ -49,7 +49,7 @@ def main(args: UploadToHubArgs): loaded_pipe = transformers.pipeline( model=args.hf_upload_model, trust_remote_code=True ) - ds = datasets.BoolQDataset(datasets.VoiceDatasetArgs()) + ds = datasets.create_dataset("boolq", datasets.VoiceDatasetArgs()) sample = next(iter(ds)) generated = loaded_pipe( {"audio": sample.audio, "turns": sample.messages[:-1]}, max_new_tokens=10 diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index ce67a0f1..92fa82a8 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -5,12 +5,11 @@ import re import sys from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import simple_parsing import torch -from ultravox.data import dataset_config from 
ultravox.data import datasets from ultravox.model import ultravox_config @@ -18,7 +17,7 @@ @dataclasses.dataclass class TrainConfig: # data-defined datasets - datasets: List[Dict] + data_sets: Dict[str, datasets.DatasetConfig] # training sets and weights train_sets: Dict[str, float] # validation sets and weights @@ -28,15 +27,6 @@ class TrainConfig: # audio encoder model to use audio_model: str - # The data_dicts field complements data_sets, allowing for the inclusion of - # new datasets in the config. - # - # Due to simple_parsing's lack of support for containers of dataclass types, - # we first parse the data_dicts as a list of dictionaries. After parsing, - # we convert these dictionaries to DataDictConfig objects using Pydantic - # to enforce type constraints and validation, in the __post_init__ method. - data_dicts: Optional[List[Dict[str, Any]]] = None - do_train: bool = True do_eval: bool = True @@ -102,15 +92,8 @@ class TrainConfig: loss_config: Optional[ultravox_config.LossConfig] = None def __post_init__(self): - if self.data_dicts: - self.data_dicts = [ - dataset_config.DataDictConfig(**data_dict) - for data_dict in self.data_dicts - ] - # For now, self.data_dicts is a hack to allow for the inclusion of new datasets using the - # GenericVoiceDataset class, without changing how existing datasets are specified in - # self.data_sets. In the future, all datasets will be updated to use the DataDictConfig class. - self.data_sets.extend(self.data_dicts) + for name, config in self.data_sets.items(): + self.data_sets[name] = datasets.DatasetConfig(**config) assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): diff --git a/ultravox/training/configs/llama_whisper.yaml b/ultravox/training/configs/llama_whisper.yaml index 8e7c8b3e..1c007d2d 100644 --- a/ultravox/training/configs/llama_whisper.yaml +++ b/ultravox/training/configs/llama_whisper.yaml @@ -12,7 +12,7 @@ max_steps: 20 # x8x24 = 2,764,800 # This would go in a datasets.yaml file and we could either use pyyaml-include to include it # or we could just add this logic to the training script. This file can also include its own datasets # key, with locally defined datasets. -datasets: +data_sets: librispeech: path: "fixie-ai/librispeech_asr" user_template: "<|audio|>" @@ -22,23 +22,23 @@ datasets: # Note the inheritance here librispeech-clean: base: "librispeech" # this could also be done via "<<": *librispeech, although that approach is less flexible - name: "clean" + subset: "clean" splits: - - "train.100" # 28_539 samples + - name: "train.100" # 28_539 samples num_samples: 28_539 - - "train.360" # 104_014 samples + - name: "train.360" # 104_014 samples num_samples: 104_014 - + librispeech-other: base: "librispeech" subset: "other" splits: - - "train.500" # 148_688 samples + - name: "train.500" # 148_688 samples num_samples: 148_688 covost2: path: "fixie-ai/covost2" - user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" assistant_template: "{{ translation }}" transcript_template: "{{ sentence }}" @@ -48,14 +48,14 @@ datasets: subset: "es_en" splits: - name: "train" - num_samples: 100000 + num_samples: 100000 - name: "validation" num_samples: 15531 user_template_args: target: "English" covost2-en_zh: - base: "covost2" + base: "covost2" subset: "en_zh-CN" splits: - name: "train" @@ -69,36 +69,34 @@ datasets: base: "covost2" subset: "foo" splits: - - name: "eval" - num_samples: 22222 - is_validation: true - + - name: "eval" + num_samples: 22222 + is_validation: true covost-bar: base: "covost2" subset: "foo" splits: - - name: "eval2" - num_samples: 22222 - is_validation: true + - name: "eval2" + num_samples: 22222 + is_validation: true covost-small: base: "covost2" subset: "foo" splits: - - name: "eval3" - num_samples: 22 - is_validation: true - - + - name: "eval3" + num_samples: 22 + is_validation: true + # This is the new approach to weighting, which keeps this out of the dataset config -train_datasets: +train_sets: librispeech-clean: 0.5 librispeech-other: 2.0 covost2-es_en: 1.0 covost2-en_zh: 1.0 -val_datasets: +val_sets: covost2-es_en: 1.0 covost2-en_zh: 1.0 covost-foo: 1.0 diff --git a/ultravox/training/evaluation.py b/ultravox/training/evaluation.py index 36492d67..69deff76 100644 --- a/ultravox/training/evaluation.py +++ b/ultravox/training/evaluation.py @@ -109,7 +109,6 @@ class EvalScenario: def evaluate( inference: infer.LocalInference, - data_dir: Optional[str] = None, num_samples: int = 200, num_procs: int = 8, max_new_tokens: Optional[int] = None, @@ -127,10 +126,8 @@ def evaluate( for task in EVAL_SCENARIOS: ds_args = datasets.VoiceDatasetArgs( - data_dir=data_dir, split=datasets.DatasetSplit.VALIDATION, include_audio=task.include_audio, - include_context=task.include_context, ) ds = datasets.Range(datasets.create_dataset(task.dataset, ds_args), num_samples) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 1206083e..f33ceab0 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -1,5 +1,4 @@ import contextlib -import copy import dataclasses import gc import glob @@ -18,7 +17,6 @@ import transformers import wandb import wandb.sdk -from torch.utils import data from ultravox.data import datasets from ultravox.model import data_processing @@ -37,6 +35,7 @@ def prepare_dataset( train_args: config_base.TrainConfig, + data_sets_and_weights: Dict[str, float], data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, @@ -44,7 +43,9 @@ def prepare_dataset( num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) ) -> datasets.SizedIterableDataset: - data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] + data_names = list(data_sets_and_weights.keys()) + data_weights = list(data_sets_and_weights.values()) + data_sets = [datasets.create_dataset(ds, data_args) for ds in data_names] # If we're using epochs to train, validate the dataset length is appropriate. 
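     # (max_steps == 0 means epoch-based training, so every dataset in the mix
     # must report a meaningful length; the assert below enforces len(ds) > 1.)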
    if train_args.max_steps == 0:
         for ds in data_sets:
             assert (
                 len(ds) > 1
             ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training"
 
-    interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy)
+    interleave = datasets.InterleaveDataset(
+        data_sets, data_weights, stop_strategy=stop_strategy
+    )
     ds_with_proc = data_processing.UltravoxDataproc(
         interleave,
         processor=processor,
@@ -197,71 +200,50 @@ def train(args: config_base.TrainConfig):
     logging.info(f"Using device (world_size): {model.device} ({world_size})")
 
     # Register custom datasets
-    datasets.register_datasets(args.datasets)
+    datasets.register_datasets(args.data_sets)
 
     # Prepare dataset, subsetting if needed
-    train_dataset: data.IterableDataset
-    val_datasets: Dict[str, data.IterableDataset]
-    # We use multiple validation sets here so that the results are comparable even when training set changes
-    # To make sure we can compare training and validation loss (e.g. for fine-tuning), we keep a special set
-    # called "matchtrain" that uses the same data as the training set.
-    val_sets = dict(
-        # [("matchtrain", args.data_sets)]  # FIXME: see issue https://github.com/fixie-ai/ultravox/issues/58
-        [(x, [x]) for x in args.val_sets]
-        + [(f"text_{x}", [x]) for x in args.val_sets]
-    )
+    train_dataset: datasets.SizedIterableDataset
+    val_dataset: datasets.SizedIterableDataset
     train_dataset = prepare_dataset(
         train_args=args,
-        datasets=args.train_sets,
+        data_sets_and_weights=args.train_sets,
         train_on_inputs=args.train_on_inputs,
         stop_strategy=args.stop_strategy,
         processor=processor,
         num_samples=args.num_samples,
         data_args=datasets.VoiceDatasetArgs(
-            num_prompts=args.num_prompts,
-            data_dir=args.data_dir,
            shuffle=args.shuffle_data,
            shuffle_seed=args.shuffle_seed,
            max_audio_duration_secs=args.max_audio_duration_secs,
-            use_mds=args.mds,
-            mds_batch_size=args.batch_size,
        ),
        include_alt_fields=model.loss_config.requires_alt_fields,
    )
    if is_master:
        val_ds_args = datasets.VoiceDatasetArgs(
-            num_prompts=1,
            split=datasets.DatasetSplit.VALIDATION,
-            data_dir=args.data_dir,
            shuffle=False,
            max_audio_duration_secs=16,
-            use_mds=args.mds,
-            mds_batch_size=args.batch_size,
        )
-        val_ds_args_text = copy.copy(val_ds_args)
-        val_ds_args_text.include_audio = False
-        val_datasets = {
-            k: prepare_dataset(
-                train_args=args,
-                dataset_names=val_sets[k],
-                train_on_inputs=args.train_on_inputs,
-                stop_strategy=args.stop_strategy,
-                processor=processor,
-                num_samples=args.val_num_samples,
-                data_args=val_ds_args_text if k.startswith("text_") else val_ds_args,
-                include_alt_fields=model.loss_config.requires_alt_fields,
-            )
-            for k in val_sets
-        }
+        val_dataset = prepare_dataset(
+            train_args=args,
+            data_sets_and_weights=args.val_sets,
+            train_on_inputs=args.train_on_inputs,
+            stop_strategy=args.stop_strategy,
+            processor=processor,
+            num_samples=args.val_num_samples,
+            data_args=val_ds_args,
+            include_alt_fields=model.loss_config.requires_alt_fields,
+        )
        logging.info(
-            f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})"
+            f"Loaded {len(args.train_sets)} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})"
        )
    else:
        # When using DDP with split_batches=True, the primary process will distribute the batches to the workers.
        # The point of this is to avoid unnecessary data processing/downloading in the workers.
# When using epochs to train, emptydataset must have a length equal to the training set train_dataset = datasets.EmptyDataset(len(train_dataset)) - val_datasets = {k: datasets.EmptyDataset() for k in val_sets} + val_dataset = datasets.EmptyDataset(len(val_dataset)) # Set up the data loader data_collator = datasets.DataCollatorForSeq2SeqWithAudio( @@ -273,7 +255,7 @@ def train(args: config_base.TrainConfig): trainer = transformers.Seq2SeqTrainer( model, train_dataset=train_dataset, - eval_dataset=val_datasets, + eval_dataset=val_dataset, data_collator=data_collator, tokenizer=text_tokenizer, args=transformers.Seq2SeqTrainingArguments( From 04aff6f3212b3d65fbf44510e76d0b9f8c1e5335 Mon Sep 17 00:00:00 2001 From: juberti Date: Wed, 16 Oct 2024 13:33:39 -0700 Subject: [PATCH 05/17] More tests --- ultravox/data/datasets.py | 14 ++++--- ultravox/data/datasets_test.py | 49 ++++++++++++++++++++++ ultravox/model/ultravox_processing.py | 2 +- ultravox/training/config_base.py | 20 ++++----- ultravox/training/configs/meta_config.yaml | 15 ++++--- 5 files changed, 78 insertions(+), 22 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 9d2dd2f0..473c8860 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F import transformers +from simple_parsing import helpers from torch.utils import data from ultravox.data import text_proc @@ -38,7 +39,6 @@ class DatasetSplit(str, enum.Enum): VALIDATION = "validation" -# Global arguments for voice datasets. @dataclasses.dataclass class VoiceDatasetArgs: """Global arguments for voice datasets.""" @@ -62,20 +62,22 @@ def __post_init__(self): @dataclasses.dataclass -class DatasetSplitConfig: +class DatasetSplitConfig(helpers.Serializable): name: str """Name of the split""" num_samples: int """Number of samples in the split""" - is_validation: bool = False + split_type: DatasetSplit = DatasetSplit.TRAIN + """Type of split, i.e., train or validation.""" def __post_init__(self): + """Automatically set is_validation if it's a validation split.""" if self.name == "validation": - self.is_validation = True + self.split_type = DatasetSplit.VALIDATION @dataclasses.dataclass -class DatasetConfig: +class DatasetConfig(helpers.Serializable): base: Optional[str] = None """Base dataset config to inherit from.""" path: str = "" @@ -427,7 +429,7 @@ def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: dsets = [] total_samples = 0 for split in config.splits: - if split.is_validation == (self._args.split == DatasetSplit.VALIDATION): + if split.split_type == self._args.split: if not config.use_mds: ds = self._load_hf_dataset( config.path, diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 7d861980..4351f8c0 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -211,6 +211,55 @@ def test_transcribe_dataset(): assert sample.audio_transcript == "0" +def test_dataset_config(): + config = datasets.DatasetConfig( + path="mock_path", + splits=[ + datasets.DatasetSplitConfig(name="clean", num_samples=5000), + datasets.DatasetSplitConfig(name="other", num_samples=10000), + datasets.DatasetSplitConfig(name="validation", num_samples=1000), + datasets.DatasetSplitConfig( + name="another_validation", + num_samples=1000, + split_type=datasets.DatasetSplit.VALIDATION, + ), + ], + ) + assert config.path == "mock_path" + assert len(config.splits) == 4 + assert config.splits[0].name == "clean" + assert 
config.splits[0].num_samples == 5000 + assert config.splits[0].split_type == datasets.DatasetSplit.TRAIN + assert config.splits[1].name == "other" + assert config.splits[1].num_samples == 10000 + assert config.splits[1].split_type == datasets.DatasetSplit.TRAIN + assert config.splits[2].name == "validation" + assert config.splits[2].num_samples == 1000 + assert config.splits[2].split_type == datasets.DatasetSplit.VALIDATION + assert config.splits[3].name == "another_validation" + assert config.splits[3].num_samples == 1000 + assert config.splits[3].split_type == datasets.DatasetSplit.VALIDATION + + +def test_dataset_config_serialization(): + config = datasets.DatasetConfig( + path="mock_path", + splits=[ + datasets.DatasetSplitConfig(name="clean", num_samples=5000), + datasets.DatasetSplitConfig(name="other", num_samples=10000), + ], + ) + serialized = config.dumps_yaml() + deserialized = datasets.DatasetConfig.loads_yaml(serialized) + assert isinstance(deserialized, datasets.DatasetConfig) + assert deserialized.path == "mock_path" + assert len(deserialized.splits) == 2 + assert deserialized.splits[0].name == "clean" + assert deserialized.splits[0].num_samples == 5000 + assert deserialized.splits[1].name == "other" + assert deserialized.splits[1].num_samples == 10000 + + def test_generic_dataset(): mock_config = datasets.DatasetConfig( path="mock_path", diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 211f7f0a..3da068f6 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -62,7 +62,7 @@ def __init__( super().__init__(audio_processor=audio_processor, tokenizer=tokenizer) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): config: UltravoxConfig = transformers.AutoConfig.from_pretrained( pretrained_model_name_or_path, **kwargs ) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 92fa82a8..fbdf5116 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -16,16 +16,19 @@ @dataclasses.dataclass class TrainConfig: + # Language model to use + text_model: str + # Audio encoder model to use + audio_model: str + # data-defined datasets - data_sets: Dict[str, datasets.DatasetConfig] + data_sets: Dict[str, datasets.DatasetConfig] = dataclasses.field( + default_factory=dict + ) # training sets and weights - train_sets: Dict[str, float] + train_sets: Dict[str, float] = dataclasses.field(default_factory=dict) # validation sets and weights - val_sets: Dict[str, float] - # language model to use - text_model: str - # audio encoder model to use - audio_model: str + val_sets: Dict[str, float] = dataclasses.field(default_factory=dict) do_train: bool = True do_eval: bool = True @@ -92,9 +95,6 @@ class TrainConfig: loss_config: Optional[ultravox_config.LossConfig] = None def __post_init__(self): - for name, config in self.data_sets.items(): - self.data_sets[name] = datasets.DatasetConfig(**config) - assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): self.device = "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index d3764d29..274d4aed 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,8 +1,13 @@ text_model: 
"meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" -data_sets: ["gigaspeech"] -val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] +train_sets: + gigaspeech: 1.0 +val_sets: + heysquad_human: 1.0 + anyinstruct: 1.0 + soda: 1.0 + peoplespeech: 1.0 stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False @@ -15,9 +20,9 @@ eval_num_samples: 2000 eval_max_new_tokens: 32 eval_num_procs: 16 -optimizer: "adamw_torch" # options: adamw_torch, adamw_bnb_8bit -lr_scheduler: "cosine" # options: linear, cosine, cosine_with_restarts, etc. -lr: 2.e-3 +optimizer: "adamw_torch" # options: adamw_torch, adamw_bnb_8bit +lr_scheduler: "cosine" # options: linear, cosine, cosine_with_restarts, etc. +lr: 2.e-3 grad_accum_steps: 1 lr_warmup_steps: 1000 max_steps: 10_000 From 3ec1bf768eef987d14eff34b82c62c398292fa94 Mon Sep 17 00:00:00 2001 From: juberti Date: Wed, 16 Oct 2024 17:50:36 -0700 Subject: [PATCH 06/17] data_tool testing --- ultravox/data/datasets.py | 385 +++++++++++++++++++++++++++++++++----- 1 file changed, 336 insertions(+), 49 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 473c8860..955b9b8d 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -80,27 +80,44 @@ def __post_init__(self): class DatasetConfig(helpers.Serializable): base: Optional[str] = None """Base dataset config to inherit from.""" - path: str = "" + path: Optional[str] = None """Directory of the dataset, or huggingface dataset name; must be set for "generic" datasets. If not set, it is automatically inferred for predefined dataset types.""" subset: Optional[str] = None """Name of the dataset, or huggingface dataset config/subset name.""" - splits: List[DatasetSplitConfig] = dataclasses.field(default_factory=list) + splits: Optional[List[DatasetSplitConfig]] = None """List of splits to use, e.g. 
[{"name": "train", "num_samples": 1000}, {"name": "validation", "num_samples": 100}].""" - user_template: str = "<|audio|>" + user_template: Optional[str] = None """Template for the user message.""" - user_template_args: Dict[str, str] = dataclasses.field(default_factory=dict) + user_template_args: Optional[Dict[str, str]] = None """Optional arguments (e.g., target language) for the user template.""" - assistant_template: str = "{{text}}" + assistant_template: Optional[str] = None """Template for the assistant message.""" - transcript_template: str = "{{text}}" + transcript_template: Optional[str] = None """Template for the transcript.""" - audio_field: Optional[str] = "audio" + audio_field: Optional[str] = None """Field in the dataset that contains the audio, use None if the dataset does not contain audio.""" - use_mds: bool = False + use_mds: Optional[bool] = None """Set to True to load the dataset from GCP (using MDS) instead of Hugging Face.""" - mds_batch_size: int = 32 + mds_batch_size: Optional[int] = None """Batch size for the dataset when using MDS.""" + def __post_init__(self): + """Set defaults only if this is a root config, so that said defaults in a subclass don't act as overrides.""" + DEFAULTS = { + "splits": [], + "user_template": "<|audio|>", + "user_template_args": {}, + "assistant_template": "{{text}}", + "transcript_template": "{{text}}", + "audio_field": "audio", + "use_mds": False, + "mds_batch_size": 32, + } + if self.base is None: + for attr, default_value in DEFAULTS.items(): + if getattr(self, attr) is None: + setattr(self, attr, default_value) + @dataclasses.dataclass class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq): @@ -448,8 +465,6 @@ def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: total_samples += split.num_samples dataset = ds if len(dsets) == 1 else hf_datasets.concatenate_datasets(dsets) super()._init_dataset(dataset, total_samples) - if True: # device_helpers.get_local_rank() == 0: - logging.info(f"Created GenericDataset with config:\n{self._config}") def _get_sample(self, row) -> Optional[VoiceSample]: try: @@ -496,44 +511,6 @@ def __len__(self): return self._length -BOOLQ_CONFIG = DatasetConfig( - path="fixie-ai/boolq-audio", - splits=[ - DatasetSplitConfig(name="train", num_samples=10000), - DatasetSplitConfig(name="validation", num_samples=1000), - ], - user_template="{{passage}}\n\n{{'<|audio|>' if include_audio else question}}", - assistant_template="{{'True' if answer else 'False'}}", - transcript_template="{{question}}", -) - -DATASET_MAP: Dict[str, DatasetConfig] = { - "boolq": BOOLQ_CONFIG, -} - - -def register_datasets(datasets: Dict[str, DatasetConfig]): - for name, config in datasets.items(): - DATASET_MAP[name] = config - - -def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: - assert name in DATASET_MAP, f"Unknown dataset: {name}" - configs = [] - temp: Optional[str] = name - while temp: - config = DATASET_MAP[temp] - configs.append(config) - temp = config.base - merged_config = dataclasses.replace(configs[-1]) - for config in reversed(configs[:-1]): - merged_config = dataclasses.replace(merged_config, **dataclasses.asdict(config)) - merged_config = dataclasses.replace(merged_config, base=None) - if not merged_config.splits: - raise ValueError(f"Dataset {name} has no splits") - return GenericDataset(args, merged_config) - - class StopStrategy(str, enum.Enum): FIRST_EXHAUSTED = "FIRST_EXHAUSTED" LAST_EXHAUSTED = "LAST_EXHAUSTED" @@ -643,3 +620,313 @@ def 
__iter__(self): def __len__(self): return self._length + + +CONTINUATION_USER_TEMPLATE = ( + "Continue the following text using less than 50 words:\n\n<|audio|>" +) +CONTINUATION_ASSISTANT_TEMPLATE = "{{continuation}}" +TRANSCRIPTION_USER_TEMPLATE = "Transcribe\n<|audio|>" + +BOOLQ_CONFIG = DatasetConfig( + path="fixie-ai/boolq-audio", + splits=[ + DatasetSplitConfig(name="train", num_samples=10000), + DatasetSplitConfig(name="validation", num_samples=1000), + ], + user_template="{{passage}}\n\n{{'<|audio|>' if include_audio else question}}", + assistant_template="{{'True' if answer else 'False'}}", + transcript_template="{{question}}", +) + +CV_BASE_CONFIG = DatasetConfig( + path="fixie-ai/common_voice_17_0", + assistant_template="{{sentence}}", + transcript_template="{{sentence}}", +) + +CV_EN_CONFIG = DatasetConfig( + base="commonvoice", + subset="en", + splits=[DatasetSplitConfig(name="train", num_samples=1_101_170)], +) + +CV_AR_CONFIG = DatasetConfig( + base="commonvoice", + subset="ar", + splits=[DatasetSplitConfig(name="train", num_samples=28_369)], +) + +CV_DE_CONFIG = DatasetConfig( + base="commonvoice", + subset="de", + splits=[DatasetSplitConfig(name="train", num_samples=589_100)], +) + +CV_ES_CONFIG = DatasetConfig( + base="commonvoice", + subset="es", + splits=[DatasetSplitConfig(name="train", num_samples=336_846)], +) + +CV_FR_CONFIG = DatasetConfig( + base="commonvoice", + subset="fr", + splits=[DatasetSplitConfig(name="train", num_samples=558_054)], +) + +CV_IT_CONFIG = DatasetConfig( + base="commonvoice", + subset="it", + splits=[DatasetSplitConfig(name="train", num_samples=169_771)], +) + +CV_JA_CONFIG = DatasetConfig( + base="commonvoice", + subset="ja", + splits=[DatasetSplitConfig(name="train", num_samples=10_039)], +) + +CV_PT_CONFIG = DatasetConfig( + base="commonvoice", + subset="pt", + splits=[DatasetSplitConfig(name="train", num_samples=21_968)], +) + +CV_RU_CONFIG = DatasetConfig( + base="commonvoice", + subset="ru", + splits=[DatasetSplitConfig(name="train", num_samples=26_377)], +) + +GS_XL_CONFIG = DatasetConfig( + path="speechcolab/gigaspeech", + subset="xl", + splits=[ + DatasetSplitConfig(name="train", num_samples=1_000_000), + DatasetSplitConfig(name="validation", num_samples=10_000), + ], + assistant_template="{{text_proc.format_asr_text(text)}}", + transcript_template="{{text_proc.format_asr_text(text)}}", +) + +LS_BASE_CONFIG = DatasetConfig( + path="fixie-ai/librispeech_asr", + assistant_template="{{text_proc.format_asr_text(text)}}", + transcript_template="{{text_proc.format_asr_text(text)}}", +) + +LS_CLEAN_CONFIG = DatasetConfig( + base="librispeech", + subset="clean", + splits=[ + DatasetSplitConfig(name="train.100", num_samples=28_539), + DatasetSplitConfig(name="train.360", num_samples=104_014), + ], + user_template_args={"foo": "bar"}, +) + +LS_OTHER_CONFIG = DatasetConfig( + base="librispeech", + subset="other", + splits=[ + DatasetSplitConfig(name="train.500", num_samples=148_688), + ], +) + +PS_CLEAN_CONFIG = DatasetConfig( + path="fixie-ai/peoples_speech", + subset="clean", + splits=[ + DatasetSplitConfig(name="train", num_samples=1_000_000), + DatasetSplitConfig(name="validation", num_samples=10_000), + ], +) + +VP_EN_CONFIG = DatasetConfig( + path="facebook/voxpopuli", + subset="en", + splits=[ + DatasetSplitConfig(name="train", num_samples=1_000_000), + DatasetSplitConfig(name="validation", num_samples=10_000), + ], + assistant_template="{{raw_text}}", + transcript_template="{{raw_text}}", +) + +CV_EN_TRANS_CONFIG = DatasetConfig( + 
base="commonvoice-en", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_AR_TRANS_CONFIG = DatasetConfig( + base="commonvoice-ar", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_DE_TRANS_CONFIG = DatasetConfig( + base="commonvoice-de", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_ES_TRANS_CONFIG = DatasetConfig( + base="commonvoice-es", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_FR_TRANS_CONFIG = DatasetConfig( + base="commonvoice-fr", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_IT_TRANS_CONFIG = DatasetConfig( + base="commonvoice-it", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_JA_TRANS_CONFIG = DatasetConfig( + base="commonvoice-ja", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_PT_TRANS_CONFIG = DatasetConfig( + base="commonvoice-pt", user_template=TRANSCRIPTION_USER_TEMPLATE +) +CV_RU_TRANS_CONFIG = DatasetConfig( + base="commonvoice-ru", user_template=TRANSCRIPTION_USER_TEMPLATE +) + +LS_CLEAN_TRANS_CONFIG = DatasetConfig( + base="librispeech-clean", user_template=TRANSCRIPTION_USER_TEMPLATE +) +LS_OTHER_TRANS_CONFIG = DatasetConfig( + base="librispeech-other", user_template=TRANSCRIPTION_USER_TEMPLATE +) + +PS_CLEAN_TRANS_CONFIG = DatasetConfig( + base="peoples_speech", user_template=TRANSCRIPTION_USER_TEMPLATE +) + +CV_EN_CONT_CONFIG = DatasetConfig( + base="commonvoice-en", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_AR_CONT_CONFIG = DatasetConfig( + base="commonvoice-ar", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_DE_CONT_CONFIG = DatasetConfig( + base="commonvoice-de", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_ES_CONT_CONFIG = DatasetConfig( + base="commonvoice-es", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_FR_CONT_CONFIG = DatasetConfig( + base="commonvoice-fr", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_IT_CONT_CONFIG = DatasetConfig( + base="commonvoice-it", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_JA_CONT_CONFIG = DatasetConfig( + base="commonvoice-ja", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_PT_CONT_CONFIG = DatasetConfig( + base="commonvoice-pt", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) +CV_RU_CONT_CONFIG = DatasetConfig( + base="commonvoice-ru", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) + +LS_CLEAN_CONT_CONFIG = DatasetConfig( + base="librispeech-clean", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) + +LS_OTHER_CONT_CONFIG = DatasetConfig( + base="librispeech-other", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) + +PS_CLEAN_CONT_CONFIG = DatasetConfig( + base="peoplespeech", + user_template=CONTINUATION_USER_TEMPLATE, + assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, +) + + +DATASET_MAP: Dict[str, DatasetConfig] = { + "boolq": BOOLQ_CONFIG, + "commonvoice": CV_BASE_CONFIG, + "commonvoice-en": CV_EN_CONFIG, + "commonvoice-ar": CV_AR_CONFIG, + "commonvoice-de": CV_DE_CONFIG, + "commonvoice-es": CV_ES_CONFIG, + "commonvoice-fr": CV_FR_CONFIG, + "commonvoice-it": CV_IT_CONFIG, + 
"commonvoice-ja": CV_JA_CONFIG, + "commonvoice-pt": CV_PT_CONFIG, + "commonvoice-ru": CV_RU_CONFIG, + "commonvoice-en-transcription": CV_EN_TRANS_CONFIG, + "commonvoice-ar-transcription": CV_AR_TRANS_CONFIG, + "commonvoice-de-transcription": CV_DE_TRANS_CONFIG, + "commonvoice-es-transcription": CV_ES_TRANS_CONFIG, + "commonvoice-fr-transcription": CV_FR_TRANS_CONFIG, + "commonvoice-it-transcription": CV_IT_TRANS_CONFIG, + "commonvoice-ja-transcription": CV_JA_TRANS_CONFIG, + "commonvoice-pt-transcription": CV_PT_TRANS_CONFIG, + "commonvoice-ru-transcription": CV_RU_TRANS_CONFIG, + "commonvoice-en-continuation": CV_EN_CONT_CONFIG, + "commonvoice-ar-continuation": CV_AR_CONT_CONFIG, + "commonvoice-de-continuation": CV_DE_CONT_CONFIG, + "commonvoice-es-continuation": CV_ES_CONT_CONFIG, + "commonvoice-fr-continuation": CV_FR_CONT_CONFIG, + "commonvoice-it-continuation": CV_IT_CONT_CONFIG, + "commonvoice-ja-continuation": CV_JA_CONT_CONFIG, + "commonvoice-pt-continuation": CV_PT_CONT_CONFIG, + "commonvoice-ru-continuation": CV_RU_CONT_CONFIG, + "gigaspeech": GS_XL_CONFIG, + "librispeech": LS_BASE_CONFIG, + "librispeech-clean": LS_CLEAN_CONFIG, + "librispeech-other": LS_OTHER_CONFIG, + "librispeech-clean-transcription": LS_CLEAN_TRANS_CONFIG, + "librispeech-other-transcription": LS_OTHER_TRANS_CONFIG, + "librispeech-clean-continuation": LS_CLEAN_CONT_CONFIG, + "librispeech-other-continuation": LS_OTHER_CONT_CONFIG, + "peoplespeech": PS_CLEAN_CONFIG, + "peoplespeech-transcription": PS_CLEAN_TRANS_CONFIG, + "peoplespeech-continuation": PS_CLEAN_CONT_CONFIG, + "voxpopuli": VP_EN_CONFIG, +} + + +def register_datasets(datasets: Dict[str, DatasetConfig]): + for name, config in datasets.items(): + DATASET_MAP[name] = config + + +def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: + assert name in DATASET_MAP, f"Unknown dataset: {name}" + # Make a list of configs from root->base. + configs = [] + temp: Optional[str] = name + while temp: + config = DATASET_MAP[temp] + configs.insert(0, config) + temp = config.base + # Set the root config, and then apply any non-None overrides from the subclasses. + merged_config = dataclasses.replace(configs[0]) + for config in configs[1:]: + for field in dataclasses.fields(config): + value = getattr(config, field.name) + if field.name != "base" and value is not None: + merged_config = dataclasses.replace( + merged_config, **{field.name: value} + ) + # Sanity check. 
+ if not merged_config.splits: + raise ValueError(f"Dataset {name} has no splits") + return GenericDataset(args, merged_config) From 76fec6eeab37ccf8993a82090e509802e56d6c47 Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 14:55:19 -0700 Subject: [PATCH 07/17] cr --- ultravox/data/datasets.py | 226 ++++++++++++------ ultravox/data/datasets_test.py | 79 ++++-- ultravox/training/config_base.py | 41 +++- ultravox/training/configs/llama_whisper.yaml | 103 -------- ultravox/training/configs/meta_config.yaml | 8 +- ultravox/training/configs/release_config.yaml | 216 +++-------------- ultravox/training/train.py | 18 +- 7 files changed, 284 insertions(+), 407 deletions(-) delete mode 100644 ultravox/training/configs/llama_whisper.yaml diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 955b9b8d..b3015a2b 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -36,6 +36,7 @@ class DatasetSplit(str, enum.Enum): TRAIN = "train" + TEST = "test" VALIDATION = "validation" @@ -64,7 +65,7 @@ def __post_init__(self): @dataclasses.dataclass class DatasetSplitConfig(helpers.Serializable): name: str - """Name of the split""" + """Name of the split.""" num_samples: int """Number of samples in the split""" split_type: DatasetSplit = DatasetSplit.TRAIN @@ -72,12 +73,16 @@ class DatasetSplitConfig(helpers.Serializable): def __post_init__(self): """Automatically set is_validation if it's a validation split.""" - if self.name == "validation": + if self.name == "test": + self.split_type = DatasetSplit.TEST + elif self.name == "validation": self.split_type = DatasetSplit.VALIDATION @dataclasses.dataclass class DatasetConfig(helpers.Serializable): + name: str + """Name of the dataset.""" base: Optional[str] = None """Base dataset config to inherit from.""" path: Optional[str] = None @@ -441,6 +446,9 @@ def _make_sample( class GenericDataset(VoiceDataset): def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + assert config.splits is not None + assert config.path is not None + assert config.mds_batch_size is not None super().__init__(args) self._config = config dsets = [] @@ -467,6 +475,10 @@ def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: super()._init_dataset(dataset, total_samples) def _get_sample(self, row) -> Optional[VoiceSample]: + assert self._config.user_template is not None + assert self._config.user_template_args is not None + assert self._config.assistant_template is not None + assert self._config.transcript_template is not None try: user_content = jinja2.Template( self._config.user_template, undefined=jinja2.StrictUndefined @@ -629,6 +641,7 @@ def __len__(self): TRANSCRIPTION_USER_TEMPLATE = "Transcribe\n<|audio|>" BOOLQ_CONFIG = DatasetConfig( + name="boolq", path="fixie-ai/boolq-audio", splits=[ DatasetSplitConfig(name="train", num_samples=10000), @@ -640,66 +653,77 @@ def __len__(self): ) CV_BASE_CONFIG = DatasetConfig( + name="commonvoice", path="fixie-ai/common_voice_17_0", assistant_template="{{sentence}}", transcript_template="{{sentence}}", ) CV_EN_CONFIG = DatasetConfig( + name="commonvoice-en", base="commonvoice", subset="en", splits=[DatasetSplitConfig(name="train", num_samples=1_101_170)], ) CV_AR_CONFIG = DatasetConfig( + name="commonvoice-ar", base="commonvoice", subset="ar", splits=[DatasetSplitConfig(name="train", num_samples=28_369)], ) CV_DE_CONFIG = DatasetConfig( + name="commonvoice-de", base="commonvoice", subset="de", splits=[DatasetSplitConfig(name="train", num_samples=589_100)], ) 
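 # A minimal sketch of how create_dataset() resolves a `base` chain, using the
 # configs above (e.g., "commonvoice-de" inherits from "commonvoice"):
 #   merged.path               == "fixie-ai/common_voice_17_0"   # from the base
 #   merged.subset             == "de"                           # leaf override
 #   merged.assistant_template == "{{sentence}}"                 # from the base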
CV_ES_CONFIG = DatasetConfig( + name="commonvoice-es", base="commonvoice", subset="es", splits=[DatasetSplitConfig(name="train", num_samples=336_846)], ) CV_FR_CONFIG = DatasetConfig( + name="commonvoice-fr", base="commonvoice", subset="fr", splits=[DatasetSplitConfig(name="train", num_samples=558_054)], ) CV_IT_CONFIG = DatasetConfig( + name="commonvoice-it", base="commonvoice", subset="it", splits=[DatasetSplitConfig(name="train", num_samples=169_771)], ) CV_JA_CONFIG = DatasetConfig( + name="commonvoice-ja", base="commonvoice", subset="ja", splits=[DatasetSplitConfig(name="train", num_samples=10_039)], ) CV_PT_CONFIG = DatasetConfig( + name="commonvoice-pt", base="commonvoice", subset="pt", splits=[DatasetSplitConfig(name="train", num_samples=21_968)], ) CV_RU_CONFIG = DatasetConfig( + name="commonvoice-ru", base="commonvoice", subset="ru", splits=[DatasetSplitConfig(name="train", num_samples=26_377)], ) GS_XL_CONFIG = DatasetConfig( + name="gigaspeech", path="speechcolab/gigaspeech", subset="xl", splits=[ @@ -711,22 +735,24 @@ def __len__(self): ) LS_BASE_CONFIG = DatasetConfig( + name="librispeech", path="fixie-ai/librispeech_asr", assistant_template="{{text_proc.format_asr_text(text)}}", transcript_template="{{text_proc.format_asr_text(text)}}", ) LS_CLEAN_CONFIG = DatasetConfig( + name="librispeech-clean", base="librispeech", subset="clean", splits=[ DatasetSplitConfig(name="train.100", num_samples=28_539), DatasetSplitConfig(name="train.360", num_samples=104_014), ], - user_template_args={"foo": "bar"}, ) LS_OTHER_CONFIG = DatasetConfig( + name="librispeech-other", base="librispeech", subset="other", splits=[ @@ -735,6 +761,7 @@ def __len__(self): ) PS_CLEAN_CONFIG = DatasetConfig( + name="peoplespeech", path="fixie-ai/peoples_speech", subset="clean", splits=[ @@ -744,6 +771,7 @@ def __len__(self): ) VP_EN_CONFIG = DatasetConfig( + name="voxpopuli-en", path="facebook/voxpopuli", subset="en", splits=[ @@ -755,178 +783,230 @@ def __len__(self): ) CV_EN_TRANS_CONFIG = DatasetConfig( - base="commonvoice-en", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-en-transcription", + base="commonvoice-en", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_AR_TRANS_CONFIG = DatasetConfig( - base="commonvoice-ar", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-ar-transcription", + base="commonvoice-ar", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_DE_TRANS_CONFIG = DatasetConfig( - base="commonvoice-de", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-de-transcription", + base="commonvoice-de", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_ES_TRANS_CONFIG = DatasetConfig( - base="commonvoice-es", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-es-transcription", + base="commonvoice-es", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_FR_TRANS_CONFIG = DatasetConfig( - base="commonvoice-fr", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-fr-transcription", + base="commonvoice-fr", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_IT_TRANS_CONFIG = DatasetConfig( - base="commonvoice-it", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-it-transcription", + base="commonvoice-it", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_JA_TRANS_CONFIG = DatasetConfig( - base="commonvoice-ja", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-ja-transcription", + base="commonvoice-ja", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_PT_TRANS_CONFIG = DatasetConfig( - base="commonvoice-pt", 
user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-pt-transcription", + base="commonvoice-pt", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_RU_TRANS_CONFIG = DatasetConfig( - base="commonvoice-ru", user_template=TRANSCRIPTION_USER_TEMPLATE + name="commonvoice-ru-transcription", + base="commonvoice-ru", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) LS_CLEAN_TRANS_CONFIG = DatasetConfig( - base="librispeech-clean", user_template=TRANSCRIPTION_USER_TEMPLATE + name="librispeech-clean-transcription", + base="librispeech-clean", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) LS_OTHER_TRANS_CONFIG = DatasetConfig( - base="librispeech-other", user_template=TRANSCRIPTION_USER_TEMPLATE + name="librispeech-other-transcription", + base="librispeech-other", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) PS_CLEAN_TRANS_CONFIG = DatasetConfig( - base="peoples_speech", user_template=TRANSCRIPTION_USER_TEMPLATE + name="peoplespeech-clean-transcription", + base="peoplespeech", + user_template=TRANSCRIPTION_USER_TEMPLATE, ) CV_EN_CONT_CONFIG = DatasetConfig( + name="commonvoice-en-continuation", base="commonvoice-en", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_AR_CONT_CONFIG = DatasetConfig( + name="commonvoice-ar-continuation", base="commonvoice-ar", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_DE_CONT_CONFIG = DatasetConfig( + name="commonvoice-de-continuation", base="commonvoice-de", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_ES_CONT_CONFIG = DatasetConfig( + name="commonvoice-es-continuation", base="commonvoice-es", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_FR_CONT_CONFIG = DatasetConfig( + name="commonvoice-fr-continuation", base="commonvoice-fr", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_IT_CONT_CONFIG = DatasetConfig( + name="commonvoice-it-continuation", base="commonvoice-it", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_JA_CONT_CONFIG = DatasetConfig( + name="commonvoice-ja-continuation", base="commonvoice-ja", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_PT_CONT_CONFIG = DatasetConfig( + name="commonvoice-pt-continuation", base="commonvoice-pt", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) CV_RU_CONT_CONFIG = DatasetConfig( + name="commonvoice-ru-continuation", base="commonvoice-ru", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) LS_CLEAN_CONT_CONFIG = DatasetConfig( + name="librispeech-clean-continuation", base="librispeech-clean", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) - LS_OTHER_CONT_CONFIG = DatasetConfig( + name="librispeech-other-continuation", base="librispeech-other", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) PS_CLEAN_CONT_CONFIG = DatasetConfig( + name="peoplespeech-clean-continuation", base="peoplespeech", user_template=CONTINUATION_USER_TEMPLATE, assistant_template=CONTINUATION_ASSISTANT_TEMPLATE, ) - -DATASET_MAP: Dict[str, DatasetConfig] = { - "boolq": BOOLQ_CONFIG, - "commonvoice": CV_BASE_CONFIG, - "commonvoice-en": CV_EN_CONFIG, - "commonvoice-ar": CV_AR_CONFIG, - 
"commonvoice-de": CV_DE_CONFIG, - "commonvoice-es": CV_ES_CONFIG, - "commonvoice-fr": CV_FR_CONFIG, - "commonvoice-it": CV_IT_CONFIG, - "commonvoice-ja": CV_JA_CONFIG, - "commonvoice-pt": CV_PT_CONFIG, - "commonvoice-ru": CV_RU_CONFIG, - "commonvoice-en-transcription": CV_EN_TRANS_CONFIG, - "commonvoice-ar-transcription": CV_AR_TRANS_CONFIG, - "commonvoice-de-transcription": CV_DE_TRANS_CONFIG, - "commonvoice-es-transcription": CV_ES_TRANS_CONFIG, - "commonvoice-fr-transcription": CV_FR_TRANS_CONFIG, - "commonvoice-it-transcription": CV_IT_TRANS_CONFIG, - "commonvoice-ja-transcription": CV_JA_TRANS_CONFIG, - "commonvoice-pt-transcription": CV_PT_TRANS_CONFIG, - "commonvoice-ru-transcription": CV_RU_TRANS_CONFIG, - "commonvoice-en-continuation": CV_EN_CONT_CONFIG, - "commonvoice-ar-continuation": CV_AR_CONT_CONFIG, - "commonvoice-de-continuation": CV_DE_CONT_CONFIG, - "commonvoice-es-continuation": CV_ES_CONT_CONFIG, - "commonvoice-fr-continuation": CV_FR_CONT_CONFIG, - "commonvoice-it-continuation": CV_IT_CONT_CONFIG, - "commonvoice-ja-continuation": CV_JA_CONT_CONFIG, - "commonvoice-pt-continuation": CV_PT_CONT_CONFIG, - "commonvoice-ru-continuation": CV_RU_CONT_CONFIG, - "gigaspeech": GS_XL_CONFIG, - "librispeech": LS_BASE_CONFIG, - "librispeech-clean": LS_CLEAN_CONFIG, - "librispeech-other": LS_OTHER_CONFIG, - "librispeech-clean-transcription": LS_CLEAN_TRANS_CONFIG, - "librispeech-other-transcription": LS_OTHER_TRANS_CONFIG, - "librispeech-clean-continuation": LS_CLEAN_CONT_CONFIG, - "librispeech-other-continuation": LS_OTHER_CONT_CONFIG, - "peoplespeech": PS_CLEAN_CONFIG, - "peoplespeech-transcription": PS_CLEAN_TRANS_CONFIG, - "peoplespeech-continuation": PS_CLEAN_CONT_CONFIG, - "voxpopuli": VP_EN_CONFIG, -} - - -def register_datasets(datasets: Dict[str, DatasetConfig]): - for name, config in datasets.items(): +INTERNAL_DATASETS = [ + BOOLQ_CONFIG, + CV_BASE_CONFIG, + CV_EN_CONFIG, + CV_AR_CONFIG, + CV_DE_CONFIG, + CV_ES_CONFIG, + CV_FR_CONFIG, + CV_IT_CONFIG, + CV_JA_CONFIG, + CV_PT_CONFIG, + CV_RU_CONFIG, + CV_EN_TRANS_CONFIG, + CV_AR_TRANS_CONFIG, + CV_DE_TRANS_CONFIG, + CV_ES_TRANS_CONFIG, + CV_FR_TRANS_CONFIG, + CV_IT_TRANS_CONFIG, + CV_JA_TRANS_CONFIG, + CV_PT_TRANS_CONFIG, + CV_RU_TRANS_CONFIG, + CV_EN_CONT_CONFIG, + CV_AR_CONT_CONFIG, + CV_DE_CONT_CONFIG, + CV_ES_CONT_CONFIG, + CV_FR_CONT_CONFIG, + CV_IT_CONT_CONFIG, + CV_JA_CONT_CONFIG, + CV_PT_CONT_CONFIG, + CV_RU_CONT_CONFIG, + GS_XL_CONFIG, + LS_BASE_CONFIG, + LS_CLEAN_CONFIG, + LS_OTHER_CONFIG, + LS_CLEAN_TRANS_CONFIG, + LS_OTHER_TRANS_CONFIG, + LS_CLEAN_CONT_CONFIG, + LS_OTHER_CONT_CONFIG, + PS_CLEAN_CONFIG, + PS_CLEAN_TRANS_CONFIG, + PS_CLEAN_CONT_CONFIG, + VP_EN_CONFIG, +] +DATASET_MAP: Dict[str, DatasetConfig] = {} + + +def register_datasets(datasets: List[DatasetConfig]): + for config in datasets: + name = config.name + assert name not in DATASET_MAP, f"Dataset {name} already registered" DATASET_MAP[name] = config +def unregister_datasets(datasets: List[str]): + for name in datasets: + del DATASET_MAP[name] + + +def _merge_configs(configs: List[DatasetConfig]) -> DatasetConfig: + merged_config = dataclasses.replace(configs[0]) + for config in configs[1:]: + for field in dataclasses.fields(config): + value = getattr(config, field.name) + if field.name != "base" and value is not None: + merged_config = dataclasses.replace( + merged_config, **{field.name: value} + ) + return merged_config + + def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: assert name in DATASET_MAP, f"Unknown dataset: 
{name}" # Make a list of configs from root->base. - configs = [] + configs: List[DatasetConfig] = [] temp: Optional[str] = name while temp: config = DATASET_MAP[temp] configs.insert(0, config) temp = config.base # Set the root config, and then apply any non-None overrides from the subclasses. - merged_config = dataclasses.replace(configs[0]) - for config in configs[1:]: - for field in dataclasses.fields(config): - value = getattr(config, field.name) - if field.name != "base" and value is not None: - merged_config = dataclasses.replace( - merged_config, **{field.name: value} - ) + merged_config = _merge_configs(configs) # Sanity check. + if not merged_config.path: + raise ValueError(f"Dataset {name} has no path") if not merged_config.splits: raise ValueError(f"Dataset {name} has no splits") return GenericDataset(args, merged_config) + + +register_datasets(INTERNAL_DATASETS) diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 4351f8c0..193453fe 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -213,6 +213,7 @@ def test_transcribe_dataset(): def test_dataset_config(): config = datasets.DatasetConfig( + name="fake_dataset", path="mock_path", splits=[ datasets.DatasetSplitConfig(name="clean", num_samples=5000), @@ -225,6 +226,7 @@ def test_dataset_config(): ), ], ) + assert config.name == "fake_dataset" assert config.path == "mock_path" assert len(config.splits) == 4 assert config.splits[0].name == "clean" @@ -243,7 +245,8 @@ def test_dataset_config(): def test_dataset_config_serialization(): config = datasets.DatasetConfig( - path="mock_path", + name="fake_dataset", + path="fake_path", splits=[ datasets.DatasetSplitConfig(name="clean", num_samples=5000), datasets.DatasetSplitConfig(name="other", num_samples=10000), @@ -252,7 +255,8 @@ def test_dataset_config_serialization(): serialized = config.dumps_yaml() deserialized = datasets.DatasetConfig.loads_yaml(serialized) assert isinstance(deserialized, datasets.DatasetConfig) - assert deserialized.path == "mock_path" + assert deserialized.name == "fake_dataset" + assert deserialized.path == "fake_path" assert len(deserialized.splits) == 2 assert deserialized.splits[0].name == "clean" assert deserialized.splits[0].num_samples == 5000 @@ -261,11 +265,12 @@ def test_dataset_config_serialization(): def test_generic_dataset(): - mock_config = datasets.DatasetConfig( - path="mock_path", + config = datasets.DatasetConfig( + name="fake_dataset", + path="fake_path", splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], ) - ds = FakeGenericDataset(5, mock_config) + ds = FakeGenericDataset(5, config) assert len(ds) == 5 sample = next(iter(ds)) assert isinstance(sample, datasets.VoiceSample) @@ -279,14 +284,15 @@ def test_generic_dataset(): def test_generic_dataset_custom_templates(): - mock_config = datasets.DatasetConfig( - path="mock_path", + config = datasets.DatasetConfig( + name="fake_dataset", + path="fake_path", splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], user_template="Listen to the following and respond with 'xyzzy':\n<|audio|>", assistant_template="xyzzy", transcript_template="{{text}}", ) - ds = FakeGenericDataset(5, mock_config) + ds = FakeGenericDataset(5, config) assert len(ds) == 5 sample = next(iter(ds)) assert isinstance(sample, datasets.VoiceSample) @@ -302,23 +308,57 @@ def test_generic_dataset_custom_templates(): assert sample.audio_transcript == "0" +def test_generic_dataset_merge_configs(): + base_config = datasets.DatasetConfig( + 
name="fake_base", + path="fake_path", + splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], + ) + mid_config = datasets.DatasetConfig( + name="fake_mid", + base="fake_base", + user_template="fake_user_template", + user_template_args={"a": 1}, + transcript_template="fake_transcript_template", + ) + leaf_config = datasets.DatasetConfig( + name="fake_leaf", + base="fake_mid", + audio_field="fake_audio_field", + ) + config = datasets._merge_configs([base_config, mid_config, leaf_config]) + assert config.name == "fake_leaf" + assert config.base is None + assert config.path == "fake_path" + assert config.splits[0].name == "fake" + assert config.splits[0].num_samples == 5 + assert config.splits[0].split_type == datasets.DatasetSplit.TRAIN + assert config.user_template == "fake_user_template" + assert config.user_template_args == {"a": 1} + assert config.assistant_template == "{{text}}" # the default + assert config.transcript_template == "fake_transcript_template" + assert config.audio_field == "fake_audio_field" + + def test_generic_dataset_length_mismatch(): - mock_config = datasets.DatasetConfig( - path="mock_path", + config = datasets.DatasetConfig( + name="fake_dataset", + path="fake_path", splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], ) - ds = FakeGenericDataset(10, mock_config) + ds = FakeGenericDataset(10, config) assert len(ds) == 5 pattern = r"(has been exceeded|Mismatch between presumed length)" with pytest.warns(UserWarning, match=pattern): list(ds) - mock_config = datasets.DatasetConfig( - path="mock_path", + config = datasets.DatasetConfig( + name="fake_dataset", + path="fake_path", splits=[datasets.DatasetSplitConfig(name="fake", num_samples=10)], ) - ds = FakeGenericDataset(5, mock_config) + ds = FakeGenericDataset(5, config) assert len(ds) == 10 with pytest.warns(UserWarning, match="Mismatch between presumed length"): @@ -326,19 +366,18 @@ def test_generic_dataset_length_mismatch(): def test_generic_dataset_multiple_splits(): - mock_config = datasets.DatasetConfig( - path="mock_path", + config = datasets.DatasetConfig( + name="fake_dataset", + path="fake_path", splits=[ datasets.DatasetSplitConfig(name="train", num_samples=90), datasets.DatasetSplitConfig(name="validation", num_samples=10), ], ) - ds = FakeGenericDataset(100, mock_config) + ds = FakeGenericDataset(100, config) assert len(ds) == 90 ds = FakeGenericDataset( - 100, - mock_config, - datasets.VoiceDatasetArgs(split=datasets.DatasetSplit.VALIDATION), + 100, config, datasets.VoiceDatasetArgs(split=datasets.DatasetSplit.VALIDATION) ) assert len(ds) == 10 diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index fbdf5116..c730f208 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -5,7 +5,7 @@ import re import sys from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import simple_parsing import torch @@ -14,6 +14,13 @@ from ultravox.model import ultravox_config +@dataclasses.dataclass +class DatasetOptions: + name: str + weight: float = 1.0 + include_audio: bool = True + + @dataclasses.dataclass class TrainConfig: # Language model to use @@ -21,14 +28,24 @@ class TrainConfig: # Audio encoder model to use audio_model: str - # data-defined datasets - data_sets: Dict[str, datasets.DatasetConfig] = dataclasses.field( - default_factory=dict - ) - # training sets and weights - train_sets: Dict[str, float] = dataclasses.field(default_factory=dict) - # validation sets and 
weights - val_sets: Dict[str, float] = dataclasses.field(default_factory=dict) + # Workaround for simple_parsing not supporting lists of dataclasses; we need to + # define these as lists of dicts and convert them manually in helpers. + + # Data-defined datasets (datasets.DatasetConfig) + data_sets: List[Dict[str, Any]] = simple_parsing.list_field() + # Training sets and weights (DatasetOptions) + train_sets: List[Dict[str, Any]] = simple_parsing.list_field() + # Validation sets and weights (DatasetOptions) + val_sets: List[Dict[str, Any]] = simple_parsing.list_field() + + def get_data_sets(self) -> List[datasets.DatasetConfig]: + return [datasets.DatasetConfig.from_dict(ds) for ds in self.data_sets] + + def get_train_sets(self) -> List[DatasetOptions]: + return [DatasetOptions(**ds) for ds in self.train_sets] + + def get_val_sets(self) -> List[DatasetOptions]: + return [DatasetOptions(**ds) for ds in self.val_sets] do_train: bool = True do_eval: bool = True @@ -146,7 +163,9 @@ def fix_hyphens(arg: str): return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) -def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig: +def get_train_args( + override_sys_args: Optional[List[str]] = None, config_file="meta_config.yaml" +) -> TrainConfig: """ Parse the command line arguments and return a TrainConfig object. @@ -158,7 +177,7 @@ def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig return simple_parsing.parse( config_class=TrainConfig, - config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"), + config_path=os.path.join(os.path.dirname(__file__), "configs", config_file), add_config_path_arg=True, args=[fix_hyphens(arg) for arg in args], ) diff --git a/ultravox/training/configs/llama_whisper.yaml b/ultravox/training/configs/llama_whisper.yaml deleted file mode 100644 index 1c007d2d..00000000 --- a/ultravox/training/configs/llama_whisper.yaml +++ /dev/null @@ -1,103 +0,0 @@ -# llama3.1-8b + whisper-medium, for development - -exp_name: "llama3.1-8b-whisper" -text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct" -audio_model: "openai/whisper-medium" - -loss_config: - loss_function: "KL_Divergence" - -max_steps: 20 # x8x24 = 2,764,800 - -# This would go in a datasets.yaml file and we could either use pyyaml-include to include it -# or we could just add this logic to the training script. This file can also include its own datasets -# key, with locally defined datasets. -data_sets: - librispeech: - path: "fixie-ai/librispeech_asr" - user_template: "<|audio|>" - assistant_template: "" - transcript_template: "{{ text }}" - - # Note the inheritance here - librispeech-clean: - base: "librispeech" # this could also be done via "<<": *librispeech, although that approach is less flexible - subset: "clean" - splits: - - name: "train.100" # 28_539 samples - num_samples: 28_539 - - name: "train.360" # 104_014 samples - num_samples: 104_014 - - librispeech-other: - base: "librispeech" - subset: "other" - splits: - - name: "train.500" # 148_688 samples - num_samples: 148_688 - - covost2: - path: "fixie-ai/covost2" - user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" - assistant_template: "{{ translation }}" - transcript_template: "{{ sentence }}" - - # Note the inheritance here - covost2-es_en: - base: "covost2" - subset: "es_en" - splits: - - name: "train" - num_samples: 100000 - - name: "validation" - num_samples: 15531 - user_template_args: - target: "English" - - covost2-en_zh: - base: "covost2" - subset: "en_zh-CN" - splits: - - name: "train" - num_samples: 100000 - - name: "validation" - num_samples: 15531 - user_template_args: - target: "Chinese" - - covost-foo: - base: "covost2" - subset: "foo" - splits: - - name: "eval" - num_samples: 22222 - is_validation: true - - covost-bar: - base: "covost2" - subset: "foo" - splits: - - name: "eval2" - num_samples: 22222 - is_validation: true - - covost-small: - base: "covost2" - subset: "foo" - splits: - - name: "eval3" - num_samples: 22 - is_validation: true - -# This is the new approach to weighting, which keeps this out of the dataset config -train_sets: - librispeech-clean: 0.5 - librispeech-other: 2.0 - covost2-es_en: 1.0 - covost2-en_zh: 1.0 - -val_sets: - covost2-es_en: 1.0 - covost2-en_zh: 1.0 - covost-foo: 1.0 - covost-bar: 1.0 diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 274d4aed..2f15df14 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -2,12 +2,10 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" train_sets: - gigaspeech: 1.0 + - name: gigaspeech val_sets: - heysquad_human: 1.0 - anyinstruct: 1.0 - soda: 1.0 - peoplespeech: 1.0 + - name: gigaspeech + - weight: 0.1 stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 36b7a5f6..02928fdb 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -5,212 +5,58 @@ exp_name: "ultravox-v0_4" text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct" audio_model: "openai/whisper-medium" - loss_config: # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence" loss_function: "KL_Divergence" -# Temporarily remove heysquad_human from val_sets as it causes the training to fail. 
-val_sets: ["anyinstruct", "soda", "peoplespeech"] - -batch_size: 24 -max_steps: 14400 # x8x24 = 2,764,800 - -data_sets: ["anyinstruct"] -data_dicts: -# continuation - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - weight: 1 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - weight: 1 - - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" +train_sets: + - name: librispeech-clean-continuation + - name: librispeech-other-continuation + - name: peoplespeech-continuation weight: 8 - - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + - name: common-voice-en-continuation weight: 8 - - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-ar-continuation weight: 0.2 - - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-de-continuation weight: 4 - - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-es-continuation weight: 3 - - path: "fixie-ai/common_voice_17_0" - name: "fr" - splits: - - "train" # 558_054 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-fr-continuation weight: 4 - - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-it-continuation weight: 1.2 - - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-ja-continuation weight: 0.1 - - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" 
- assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-pt-continuation weight: 0.2 - - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ sentence }}" + - name: common-voice-ru-continuation weight: 0.2 -# ASR task - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" # 28_539 samples - - "train.360" # 104_014 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - weight: 0.1 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" # 148_688 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text }}" - transcript_template: "{{ text }}" - weight: 0.1 - - path: "fixie-ai/peoples_speech" - name: "clean" - splits: - - "train" # 1_501_271 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(text) }}" - transcript_template: "{{ text_proc.format_asr_text(text) }}" + - name: librispeech-clean-transcription + - name: librispeech-other-transcription + - name: peoplespeech-transcription weight: 0.8 - - path: "fixie-ai/common_voice_17_0" - name: "en" - splits: - - "train" # 1_101_170 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ text_proc.format_asr_text(sentence) }}" + - name: common-voice-en-transcription weight: 0.8 - - path: "fixie-ai/common_voice_17_0" - name: "ar" - splits: - - "train" # 28_369 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-ar-transcription weight: 0.02 - - path: "fixie-ai/common_voice_17_0" - name: "de" - splits: - - "train" # 589_100 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-de-transcription weight: 0.4 - - path: "fixie-ai/common_voice_17_0" - name: "es" - splits: - - "train" # 336_846 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-es-transcription weight: 0.3 - - path: "fixie-ai/common_voice_17_0" - name: "fr" - splits: - - "train" # 558_054 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-fr-transcription weight: 0.4 - - path: "fixie-ai/common_voice_17_0" - name: "it" - splits: - - "train" # 169_771 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-it-transcription weight: 0.12 - - path: "fixie-ai/common_voice_17_0" - name: "ja" - splits: - - "train" # 10_039 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-ja-transcription 
weight: 0.01 - - path: "fixie-ai/common_voice_17_0" - name: "pt" - splits: - - "train" # 21_968 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-pt-transcription weight: 0.02 - - path: "fixie-ai/common_voice_17_0" - name: "ru" - splits: - - "train" # 26_377 samples - user_template: "{{ dataset._get_transcribe_prompt() }}" - assistant_template: "{{ text_proc.format_asr_text(sentence) }}" - transcript_template: "{{ sentence }}" + - name: common-voice-ru-transcription weight: 0.02 + +# Temporarily remove heysquad_human from val_sets as it causes the training to fail. +val_sets: ["peoplespeech"] + +batch_size: 24 +max_steps: 14400 # x8x24 = 2,764,800 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index f33ceab0..a6a51acd 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -6,7 +6,7 @@ import os import subprocess from datetime import datetime -from typing import Dict, Optional +from typing import List, Optional import accelerate import datasets as hf_datasets @@ -29,13 +29,10 @@ from ultravox.training import ddp_utils from ultravox.training.helpers import prefetch_weights -INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000} -OUTPUT_EXAMPLE = {"text": "Hello, world!"} - def prepare_dataset( train_args: config_base.TrainConfig, - data_sets_and_weights: Dict[str, float], + data_opts: List[config_base.DatasetOptions], data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, @@ -43,8 +40,8 @@ def prepare_dataset( num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) ) -> datasets.SizedIterableDataset: - data_names = list(data_sets_and_weights.keys()) - data_weights = list(data_sets_and_weights.values()) + data_names = [ds.name for ds in data_opts] + data_weights = [ds.weight for ds in data_opts] data_sets = [datasets.create_dataset(ds, data_args) for ds in data_names] # If we're using epochs to train, validate the dataset length is appropriate. 
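     # (max_steps == 0 selects epoch-based training rather than step-based.)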
if train_args.max_steps == 0: @@ -200,14 +197,15 @@ def train(args: config_base.TrainConfig): logging.info(f"Using device (world_size): {model.device} ({world_size})") # Register custom datasets - datasets.register_datasets(args.data_sets) + datasets.register_datasets(args.get_data_sets()) # Prepare dataset, subsetting if needed train_dataset: datasets.SizedIterableDataset val_dataset: datasets.SizedIterableDataset + train_dataset = prepare_dataset( train_args=args, - data_sets_and_weights=args.train_sets, + data_opts=args.get_train_sets(), train_on_inputs=args.train_on_inputs, stop_strategy=args.stop_strategy, processor=processor, @@ -227,7 +225,7 @@ def train(args: config_base.TrainConfig): ) val_dataset = prepare_dataset( train_args=args, - data_sets_and_weights=args.val_sets, + data_opts=args.get_val_sets(), train_on_inputs=args.train_on_inputs, stop_strategy=args.stop_strategy, processor=processor, From 8a9fb9b486e7e46281305ad9cb9bbcd46d8a1270 Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 15:37:42 -0700 Subject: [PATCH 08/17] text-only --- ultravox/data/datasets.py | 55 ++++++++++++++++++++++++++++------ ultravox/data/datasets_test.py | 18 +++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index b3015a2b..ba82c860 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -24,6 +24,7 @@ from ultravox.data import text_proc +AUDIO_PLACEHOLDER = "<|audio|>" SAMPLE_RATE = 16000 # TODO(juberti): set these in the environment so they don't need to be hard-coded here. @@ -36,7 +37,6 @@ class DatasetSplit(str, enum.Enum): TRAIN = "train" - TEST = "test" VALIDATION = "validation" @@ -81,6 +81,9 @@ def __post_init__(self): @dataclasses.dataclass class DatasetConfig(helpers.Serializable): + # Note that subclasses can override any of these fields, but they currently can't + # extend structured fields like splits or user_template_args. + # See _merge_configs below for the current implementation. name: str """Name of the dataset.""" base: Optional[str] = None @@ -110,7 +113,7 @@ def __post_init__(self): """Set defaults only if this is a root config, so that said defaults in a subclass don't act as overrides.""" DEFAULTS = { "splits": [], - "user_template": "<|audio|>", + "user_template": AUDIO_PLACEHOLDER, "user_template_args": {}, "assistant_template": "{{text}}", "transcript_template": "{{text}}", @@ -486,7 +489,6 @@ def _get_sample(self, row) -> Optional[VoiceSample]: **row, text_proc=text_proc, dataset=self, - include_audio=self._args.include_audio, **self._config.user_template_args, ) assistant_content = jinja2.Template( @@ -504,11 +506,31 @@ def _get_sample(self, row) -> Optional[VoiceSample]: raise ValueError( "Template rendering failed. Make sure all keys in the template exist in the sample." ) from e + if not self._args.include_audio: + user_content = user_content.replace(AUDIO_PLACEHOLDER, f'"{transcript}"') + messages = _get_messages(user_content, assistant_content) + audio = self._get_audio(row, self._config.audio_field) + return self._make_sample(messages, audio, audio_transcript=transcript) + + +class LibriSpeechDummyDataset(VoiceDataset): + def __init__(self, args: VoiceDatasetArgs) -> None: + super().__init__(args) + # This dataset doesn't support streaming. 
+        dataset = self._load_hf_dataset(
+            "hf-internal-testing/librispeech_asr_dummy",
+            "clean",
+            split="validation",
+            streaming=False,
+        )
+        self._init_dataset(dataset, 73)
+
+    def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
+        text = text_proc.format_asr_text(row["text"])
         return self._make_sample(
-            _get_messages(user_content, assistant_content),
-            self._get_audio(row, self._config.audio_field),
-            audio_transcript=transcript,
+            self._make_messages(f"Transcribe\n{AUDIO_PLACEHOLDER}", text),
+            self._get_audio(row, "audio"),
+            audio_transcript=text,
         )
 
 
@@ -635,10 +657,10 @@ def __len__(self):
 
 
 CONTINUATION_USER_TEMPLATE = (
-    "Continue the following text using less than 50 words:\n\n<|audio|>"
+    f"Continue the following text using less than 50 words:\n\n{AUDIO_PLACEHOLDER}"
 )
 CONTINUATION_ASSISTANT_TEMPLATE = "{{continuation}}"
-TRANSCRIPTION_USER_TEMPLATE = "Transcribe\n<|audio|>"
+TRANSCRIPTION_USER_TEMPLATE = f"Transcribe\n{AUDIO_PLACEHOLDER}"
 
 BOOLQ_CONFIG = DatasetConfig(
     name="boolq",
@@ -647,7 +669,7 @@ def __len__(self):
         DatasetSplitConfig(name="train", num_samples=10000),
         DatasetSplitConfig(name="validation", num_samples=1000),
     ],
-    user_template="{{passage}}\n\n{{'<|audio|>' if include_audio else question}}",
+    user_template=f"{{{{passage}}}}\n\n{AUDIO_PLACEHOLDER}",
     assistant_template="{{'True' if answer else 'False'}}",
     transcript_template="{{question}}",
 )
 
@@ -770,6 +792,19 @@ def __len__(self):
     ],
 )
 
+# SODA_CONFIG = DatasetConfig(
+#     name="soda",
+#     path="fixie-ai/soda-audio",
+#     splits=[
+#         DatasetSplitConfig(name="train", num_samples=1_000_000),
+#         DatasetSplitConfig(name="validation", num_samples=10_000),
+#     ],
+#     # Need way to specify message history.
+#     audio_field="audio_second_last_turn",
+#     assistant_template="{{alt_last_turn}}",
+#     transcript_template="{{turns[-2]}}",
+# )
+
 VP_EN_CONFIG = DatasetConfig(
     name="voxpopuli-en",
     path="facebook/voxpopuli",
     subset="en",
     splits=[
@@ -991,6 +1026,8 @@ def _merge_configs(configs: List[DatasetConfig]) -> DatasetConfig:
 
 
 def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset:
+    if name == "dummy":
+        return LibriSpeechDummyDataset(args)
     assert name in DATASET_MAP, f"Unknown dataset: {name}"
     # Make a list of configs from root->base.
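     # e.g., "commonvoice-en-transcription" -> "commonvoice-en" -> "commonvoice";
     # each derived config's non-None fields then override those of its base.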
configs: List[DatasetConfig] = [] diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 193453fe..245aac31 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -308,6 +308,24 @@ def test_generic_dataset_custom_templates(): assert sample.audio_transcript == "0" +def test_generic_dataset_text_only(): + config = datasets.DatasetConfig( + name="fake_dataset", + path="fake_path", + splits=[datasets.DatasetSplitConfig(name="fake", num_samples=5)], + user_template="Transcribe\n<|audio|>", + ) + ds = FakeGenericDataset(5, config, datasets.VoiceDatasetArgs(include_audio=False)) + assert len(ds) == 5 + sample = next(iter(ds)) + assert isinstance(sample, datasets.VoiceSample) + assert sample.messages == [ + {"role": "user", "content": 'Transcribe\n"0"'}, + {"role": "assistant", "content": "0"}, + ] + assert sample.audio is None + + def test_generic_dataset_merge_configs(): base_config = datasets.DatasetConfig( name="fake_base", From 0f983af7a413665b58b39a19831c72046bef17ac Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 15:58:25 -0700 Subject: [PATCH 09/17] fix bugs --- ultravox/data/datasets.py | 4 +++- ultravox/tools/infer_api.py | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index ba82c860..0968fc60 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -527,8 +527,10 @@ def __init__(self, args: VoiceDatasetArgs) -> None: def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: text = text_proc.format_asr_text(row["text"]) + user_content = "Transcribe\n" + user_content += AUDIO_PLACEHOLDER if self._args.include_audio else f'"{text}"' return self._make_sample( - self._make_messages(f"Transcribe\n{AUDIO_PLACEHOLDER}", text), + self._make_messages(user_content, text), self._get_audio(row, "audio"), audio_transcript=text, ) diff --git a/ultravox/tools/infer_api.py b/ultravox/tools/infer_api.py index a1d217f6..d4425cdf 100644 --- a/ultravox/tools/infer_api.py +++ b/ultravox/tools/infer_api.py @@ -57,16 +57,22 @@ def infer_stream( data["temperature"] = temperature response = requests.post(url, headers=headers, json=data, stream=True) response.raise_for_status() + num_tokens = 0 + got_stats = False for line in response.iter_lines(): event = line[6:].decode("utf-8") if event and event[0] == "{": obj = json.loads(event) if obj.get("choices") and obj["choices"][0]["delta"].get("content"): + num_tokens += 1 yield base.InferenceChunk(obj["choices"][0]["delta"]["content"]) if obj.get("usage"): + got_stats = True yield base.InferenceStats( obj["usage"]["prompt_tokens"], obj["usage"]["completion_tokens"] ) + if not got_stats: + yield base.InferenceStats(-1, num_tokens) def _build_messages(self, sample: datasets.VoiceSample): """ @@ -77,7 +83,7 @@ def _build_messages(self, sample: datasets.VoiceSample): Audio is converted to a data URI and inserted into the message under an image_url type. 
""" if sample.audio is None: - return sample + return sample.messages fragments = sample.messages[-1]["content"].split("<|audio|>") assert len(fragments) == 2, "Expected one <|audio|> placeholder" From 9bd8e4bda7345ed6705f0f001f9173ec4d8f3525 Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 16:02:12 -0700 Subject: [PATCH 10/17] Fix bug --- ultravox/data/datasets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 0968fc60..51fdd4ed 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -474,6 +474,9 @@ def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: ) dsets.append(ds) total_samples += split.num_samples + assert ( + len(dsets) > 0 + ), f"The {config.name} dataset has no {self._args.split} splits." dataset = ds if len(dsets) == 1 else hf_datasets.concatenate_datasets(dsets) super()._init_dataset(dataset, total_samples) From e29f0d190ff682348f08499e6b2fd5b5c1d7e88b Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 16:29:54 -0700 Subject: [PATCH 11/17] Fix bugs --- ultravox/data/datasets.py | 4 ++-- ultravox/training/configs/meta_config.yaml | 2 +- ultravox/training/configs/release_config.yaml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 51fdd4ed..d86436a4 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -401,11 +401,11 @@ def __iter__(self): actual_length += 1 if actual_length == len(self) + 1: warnings.warn( - f"The presumed length {self._length} has been exceeded for split {self._dataset.split}. Make sure to update." + f"The presumed length {self._length} has been exceeded for {self._config.name}:{self._args.split}. Make sure to update." ) if actual_length != len(self): warnings.warn( - f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for split {self._dataset.split}. Make sure to update." + f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for {self._config.name}:{self._args.split}. Make sure to update." ) @abc.abstractmethod diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 2f15df14..515b5a1c 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -5,7 +5,7 @@ train_sets: - name: gigaspeech val_sets: - name: gigaspeech - - weight: 0.1 + - weight: 0.01 stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 02928fdb..776e238e 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -12,7 +12,7 @@ loss_config: train_sets: - name: librispeech-clean-continuation - name: librispeech-other-continuation - - name: peoplespeech-continuation + - name: peoplespeech-clean-continuation weight: 8 - name: common-voice-en-continuation weight: 8 @@ -56,7 +56,7 @@ train_sets: weight: 0.02 # Temporarily remove heysquad_human from val_sets as it causes the training to fail. 
-val_sets: ["peoplespeech"] +val_sets: ["peoplespeech-clean"] batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 From 495ab8234e4a51b78a9862808875ab33387a53b1 Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 16:41:11 -0700 Subject: [PATCH 12/17] config updates --- ultravox/training/configs/release_config.yaml | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 776e238e..6ebdd987 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -14,49 +14,50 @@ train_sets: - name: librispeech-other-continuation - name: peoplespeech-clean-continuation weight: 8 - - name: common-voice-en-continuation + - name: commonvoice-en-continuation weight: 8 - - name: common-voice-ar-continuation + - name: commonvoice-ar-continuation weight: 0.2 - - name: common-voice-de-continuation + - name: commonvoice-de-continuation weight: 4 - - name: common-voice-es-continuation + - name: commonvoice-es-continuation weight: 3 - - name: common-voice-fr-continuation + - name: commonvoice-fr-continuation weight: 4 - - name: common-voice-it-continuation + - name: commonvoice-it-continuation weight: 1.2 - - name: common-voice-ja-continuation + - name: commonvoice-ja-continuation weight: 0.1 - - name: common-voice-pt-continuation + - name: commonvoice-pt-continuation weight: 0.2 - - name: common-voice-ru-continuation + - name: commonvoice-ru-continuation weight: 0.2 - name: librispeech-clean-transcription - name: librispeech-other-transcription - - name: peoplespeech-transcription + - name: peoplespeech-clean-transcription weight: 0.8 - - name: common-voice-en-transcription + - name: commonvoice-en-transcription weight: 0.8 - - name: common-voice-ar-transcription + - name: commonvoice-ar-transcription weight: 0.02 - - name: common-voice-de-transcription + - name: commonvoice-de-transcription weight: 0.4 - - name: common-voice-es-transcription + - name: commonvoice-es-transcription weight: 0.3 - - name: common-voice-fr-transcription + - name: commonvoice-fr-transcription weight: 0.4 - - name: common-voice-it-transcription + - name: commonvoice-it-transcription weight: 0.12 - - name: common-voice-ja-transcription + - name: commonvoice-ja-transcription weight: 0.01 - - name: common-voice-pt-transcription + - name: commonvoice-pt-transcription weight: 0.02 - - name: common-voice-ru-transcription + - name: commonvoice-ru-transcription weight: 0.02 # Temporarily remove heysquad_human from val_sets as it causes the training to fail. 
-val_sets: ["peoplespeech-clean"]
+val_sets:
+  - name: peoplespeech
 
 batch_size: 24
 max_steps: 14400 # x8x24 = 2,764,800

From 73956817ebb86d703fba59530582172c83f5d9bd Mon Sep 17 00:00:00 2001
From: juberti
Date: Thu, 17 Oct 2024 17:04:30 -0700
Subject: [PATCH 13/17] val_sets -> dict

---
 ultravox/training/train.py | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index a6a51acd..8dd3628d 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -6,7 +6,7 @@
 import os
 import subprocess
 from datetime import datetime
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import accelerate
 import datasets as hf_datasets
@@ -201,7 +201,7 @@ def train(args: config_base.TrainConfig):
 
     # Prepare dataset, subsetting if needed
     train_dataset: datasets.SizedIterableDataset
-    val_dataset: datasets.SizedIterableDataset
+    val_datasets: Dict[str, datasets.SizedIterableDataset]
+
     train_dataset = prepare_dataset(
         train_args=args,
@@ -223,16 +223,19 @@ def train(args: config_base.TrainConfig):
         shuffle=False,
         max_audio_duration_secs=16,
     )
-    val_dataset = prepare_dataset(
-        train_args=args,
-        data_opts=args.get_val_sets(),
-        train_on_inputs=args.train_on_inputs,
-        stop_strategy=args.stop_strategy,
-        processor=processor,
-        num_samples=args.val_num_samples,
-        data_args=val_ds_args,
-        include_alt_fields=model.loss_config.requires_alt_fields,
-    )
+    val_datasets = {}
+    for val_opt in args.get_val_sets():
+        val_dataset = prepare_dataset(
+            train_args=args,
+            data_opts=[val_opt],
+            train_on_inputs=args.train_on_inputs,
+            stop_strategy=args.stop_strategy,
+            processor=processor,
+            num_samples=args.val_num_samples,
+            data_args=val_ds_args,
+            include_alt_fields=model.loss_config.requires_alt_fields,
+        )
+        val_datasets[val_opt.name] = val_dataset
     logging.info(
         f"Loaded {len(args.train_sets)} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})"
     )
@@ -241,7 +244,10 @@ def train(args: config_base.TrainConfig):
     # The point of this is to avoid unnecessary data processing/downloading in the workers.
     # When using epochs to train, EmptyDataset must have a length equal to the training set.
     train_dataset = datasets.EmptyDataset(len(train_dataset))
-    val_dataset = datasets.EmptyDataset(len(val_dataset))
+    val_datasets = {
+        val_set_name: datasets.EmptyDataset()
+        for val_set_name, val_dataset in val_datasets.items()
+    }
 
     # Set up the data loader
     data_collator = datasets.DataCollatorForSeq2SeqWithAudio(

From 8d3ac96253a6e389710fd96e765523d81b09d456 Mon Sep 17 00:00:00 2001
From: juberti
Date: Thu, 17 Oct 2024 17:24:25 -0700
Subject: [PATCH 14/17] fixes

---
 ultravox/data/datasets.py  | 4 ++--
 ultravox/training/train.py | 9 +++------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index d86436a4..5132e585 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -401,11 +401,11 @@ def __iter__(self):
             actual_length += 1
             if actual_length == len(self) + 1:
                 warnings.warn(
-                    f"The presumed length {self._length} has been exceeded for {self._config.name}:{self._args.split}. Make sure to update."
+                    f"The presumed length {self._length} has been exceeded for {self._config.name}:{self._args.split.value}. Make sure to update."
) if actual_length != len(self): warnings.warn( - f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for {self._config.name}:{self._args.split}. Make sure to update." + f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for {self._config.name}:{self._args.split.value}. Make sure to update." ) @abc.abstractmethod diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 8dd3628d..5966c511 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -201,7 +201,7 @@ def train(args: config_base.TrainConfig): # Prepare dataset, subsetting if needed train_dataset: datasets.SizedIterableDataset - val_datasets: Dict[str, datasets.SizedIterableDataset] + val_datasets: Dict[str, datasets.SizedIterableDataset] = {} train_dataset = prepare_dataset( train_args=args, @@ -223,7 +223,6 @@ def train(args: config_base.TrainConfig): shuffle=False, max_audio_duration_secs=16, ) - val_datasets = {} for val_opt in args.get_val_sets(): val_dataset = prepare_dataset( train_args=args, @@ -244,10 +243,8 @@ def train(args: config_base.TrainConfig): # The point of this is to avoid unnecessary data processing/downloading in the workers. # When using epochs to train, emptydataset must have a length equal to the training set train_dataset = datasets.EmptyDataset(len(train_dataset)) - val_datasets = { - val_set_name: datasets.EmptyDataset() - for val_set_name, val_dataset in val_datasets.items() - } + for val_opts in args.get_val_sets(): + val_datasets[val_opts.name] = datasets.EmptyDataset() # Set up the data loader data_collator = datasets.DataCollatorForSeq2SeqWithAudio( From f86258c30a54dc6de92d7a9367255ef579f760d2 Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 17:33:04 -0700 Subject: [PATCH 15/17] d'oh --- ultravox/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 5966c511..888ba12d 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -256,7 +256,7 @@ def train(args: config_base.TrainConfig): trainer = transformers.Seq2SeqTrainer( model, train_dataset=train_dataset, - eval_dataset=val_dataset, + eval_dataset=val_datasets, data_collator=data_collator, tokenizer=text_tokenizer, args=transformers.Seq2SeqTrainingArguments( From 9d26dfac90d14e0c9cc3c45c73086a729eb23a13 Mon Sep 17 00:00:00 2001 From: juberti Date: Thu, 17 Oct 2024 19:12:31 -0700 Subject: [PATCH 16/17] v1 --- ultravox/data/datasets.py | 55 ++++++++-------------------------- ultravox/data/datasets_test.py | 14 ++++----- 2 files changed, 19 insertions(+), 50 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 5132e585..31db98b1 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -550,77 +550,46 @@ def __len__(self): return self._length -class StopStrategy(str, enum.Enum): - FIRST_EXHAUSTED = "FIRST_EXHAUSTED" - LAST_EXHAUSTED = "LAST_EXHAUSTED" - NEVER_STOP = "NEVER_STOP" - - class InterleaveDataset(SizedIterableDataset): - """Interleaves multiple IterableDataset objects based on normalized weights.""" + """Interleaves multiple SizedIterableDataset objects based on normalized weights.""" def __init__( self, datasets: Sequence[SizedIterableDataset], weights: Optional[Sequence[float]] = None, - stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED, seed: Optional[int] = 42, - static: bool = False, ) -> None: """ Args: datasets: A list of SizedIterableDataset 
objects.
-            weights: A list of weights for each dataset.
-            stop_strategy: Strategy for stopping iteration.
+            weights: An optional list of dataset weights, i.e., how many times each dataset should be repeated.
             seed: Optional seed for reproducibility.
-            static: If true, the datasets are interleaved in a static order with equal weights.
         """
         self._datasets = datasets
         self._rng = np.random.default_rng(seed)
-        self._static = static
-        self._stop_strategy = stop_strategy
 
         if weights is None:
             weights = [1.0] * len(datasets)
-        total_weight = sum(weights)
-        self._normalized_probs = [w / total_weight for w in weights]
+        dataset_samples = [w * len(d) for w, d in zip(weights, datasets)]
+        self._total_samples = int(sum(dataset_samples))
+        self._normalized_probs = [s / self._total_samples for s in dataset_samples]
 
     def __iter__(self):
-        # If no datasets are provided, return an empty iterator
-        if not self._datasets:
-            return
-
+        # Draw _total_samples samples, at each step picking a dataset with
+        # probability proportional to its weighted size and restarting its
+        # iterator whenever it runs out.
         iters = [iter(ds) for ds in self._datasets]
-        exhausted = [False] * len(iters)
-
-        if self._static:
-            static_iter = itertools.cycle(range(len(self._datasets)))
-
-        while True:
-            if self._static:
-                iter_index = next(static_iter)
-            else:
-                iter_index = self._rng.choice(len(iters), p=self._normalized_probs)
-
+        for _ in range(self._total_samples):
+            iter_index = self._rng.choice(len(iters), p=self._normalized_probs)
             try:
                 yield next(iters[iter_index])
             except StopIteration:
-                exhausted[iter_index] = True
-
-                # Check if stopping condition is met
-                if self._stop_strategy == StopStrategy.FIRST_EXHAUSTED or (
-                    self._stop_strategy == StopStrategy.LAST_EXHAUSTED
-                    and all(exhausted)
-                ):
-                    break
-
-                # Recreate the iterator if stopping condition is not met and yield the next sample
                 iters[iter_index] = iter(self._datasets[iter_index])
                 yield next(iters[iter_index])
 
     def __len__(self):
-        # TODO: Implement the length method for different stop strategies
-        return sum(len(ds) for ds in self._datasets)
+        return self._total_samples
 
 
 class Dataproc(SizedIterableDataset):
diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py
index 245aac31..8bab22e5 100644
--- a/ultravox/data/datasets_test.py
+++ b/ultravox/data/datasets_test.py
@@ -16,7 +16,7 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset):
     def __init__(self, n, start=0, length=0):
         self.data = range(start, start + n)
-        self._length = length
+        self._length = length or n
 
     def __iter__(self):
         for sample in self.data:
@@ -103,8 +103,8 @@ def test_interleaved_first_exhausted():
     ds3 = FakeSizedIterableDataset(3)
     s = datasets.InterleaveDataset(
         [ds1, ds2, ds3],
-        stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED,
-        static=True,
+        # stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED,
+        # static=True,
     )
     # static=True disables random sampling of datasets, so the order is deterministic
    # stop_strategy=first_exhausted will stop interleave when the first dataset is exhausted
@@ -118,8 +118,8 @@ def test_interleaved_last_exhausted():
     ds2 = FakeSizedIterableDataset(2, start=10)
     s = datasets.InterleaveDataset(
         [ds1, ds2],
-        stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED,
-        static=True,
+        # stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED,
+        # static=True,
     )
     # static=True disables random sampling of datasets, so the order is deterministic
     # stop_strategy=last_exhausted will stop interleave
@@ -131,8 +131,8 @@ def test_interleaved_never_stop():
     ds2 = FakeSizedIterableDataset(2, start=10)
     s = datasets.InterleaveDataset(
         [ds1, ds2],
-        stop_strategy=datasets.StopStrategy.NEVER_STOP,
-        static=True,
+        # stop_strategy=datasets.StopStrategy.NEVER_STOP,
+        # static=True,
     )
     # static=True disables random sampling of datasets, so the order is deterministic
     # stop_strategy=never_stop will continue interleaving forever

From c27715f7eb3c104e9ddfb3a41b1a89825e205c2f Mon Sep 17 00:00:00 2001
From: juberti
Date: Fri, 18 Oct 2024 13:11:27 -0700
Subject: [PATCH 17/17] sr

---
 ultravox/data/datasets.py                  | 37 +++++-----
 ultravox/data/datasets_test.py             | 84 +++++-----------------
 ultravox/training/config_base.py           |  5 --
 ultravox/training/configs/meta_config.yaml |  1 -
 ultravox/training/train.py                 |  7 +-
 5 files changed, 38 insertions(+), 96 deletions(-)

diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index 31db98b1..8a7f3267 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -3,7 +3,6 @@
 import dataclasses
 import enum
 import io
-import itertools
 import logging
 import os
 import tempfile
@@ -557,7 +556,6 @@ def __init__(
         self,
         datasets: Sequence[SizedIterableDataset],
         weights: Optional[Sequence[float]] = None,
-        seed: Optional[int] = 42,
     ) -> None:
         """
         Args:
@@ -566,27 +564,29 @@
-            seed: Optional seed for reproducibility.
         """
         self._datasets = datasets
-        self._rng = np.random.default_rng(seed)
-
-        if weights is None:
+        if weights is not None:
+            assert len(weights) == len(datasets)
+        else:
             weights = [1.0] * len(datasets)
-        dataset_samples = [w * len(d) for w, d in zip(weights, datasets)]
-        self._total_samples = int(sum(dataset_samples))
-        self._normalized_probs = [s / self._total_samples for s in dataset_samples]
+        self._weighted_samples = [int(w * len(d)) for w, d in zip(weights, datasets)]
+        self._total_samples = sum(self._weighted_samples)

     def __iter__(self):
-        # Randomly choose which dataset to vend a sample from, weighted
-        # by each dataset's share of the total sample count. If the chosen
-        # iterator happens to be exhausted, restart it and vend its next
-        # sample.
-        iters = [iter(ds) for ds in self._datasets]
-        for _ in range(self._total_samples):
-            iter_index = self._rng.choice(len(iters), p=self._normalized_probs)
+        ds_iters = [iter(ds) for ds in self._datasets]
+        ds_pos = [0] * len(ds_iters)
+        # Find the iterator that is least far along and vend from it.
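+        # Each dataset gets a budget of int(weight * len(dataset)) samples;
+        # on every step, vend from the dataset that has consumed the
+        # smallest fraction of its budget, which interleaves the datasets
+        # deterministically, in proportion to their weighted sizes.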
+        for _ in range(self._total_samples):
+            min_fraction = 1.0
+            for j in range(len(ds_iters)):
+                iter_fraction = ds_pos[j] / self._weighted_samples[j]
+                if iter_fraction < min_fraction:
+                    min_fraction = iter_fraction
+                    iter_index = j
             try:
-                yield next(iters[iter_index])
+                yield next(ds_iters[iter_index])
             except StopIteration:
-                iters[iter_index] = iter(self._datasets[iter_index])
-                yield next(iters[iter_index])
+                ds_iters[iter_index] = iter(self._datasets[iter_index])
+                yield next(ds_iters[iter_index])
+            ds_pos[iter_index] += 1

     def __len__(self):
         return self._total_samples
diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py
index 8bab22e5..36b98770 100644
--- a/ultravox/data/datasets_test.py
+++ b/ultravox/data/datasets_test.py
@@ -1,4 +1,3 @@
-import itertools
 from typing import Optional, Union

 import datasets as hf_datasets
@@ -95,90 +94,43 @@ def test_dataproc():
     assert list(s) == [0, -1, -2, -3, -4]


-def test_interleaved_first_exhausted():
-    ds1 = FakeSizedIterableDataset(5)
-    s = datasets.InterleaveDataset([ds1])
-    assert list(s) == [0, 1, 2, 3, 4]
-    ds2 = FakeSizedIterableDataset(9)
-    ds3 = FakeSizedIterableDataset(3)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2, ds3],
-        # stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED,
-        # static=True,
-    )
-    # static=True disables random sampling of datasets, so the order is deterministic
-    # stop_strategy=first_exhausted will stop interleave when the first dataset is exhausted
-    assert list(s) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
+def test_interleaved_empty():
     s = datasets.InterleaveDataset([])
     assert list(s) == []


-def test_interleaved_last_exhausted():
+def test_interleaved_single_set():
     ds1 = FakeSizedIterableDataset(4)
-    ds2 = FakeSizedIterableDataset(2, start=10)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2],
-        # stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED,
-        # static=True,
-    )
-    # static=True disables random sampling of datasets, so the order is deterministic
-    # stop_strategy=last_exhausted will stop interleave when the last dataset is exhausted
-    assert list(s) == [0, 10, 1, 11, 2, 10, 3, 11]
+    s = datasets.InterleaveDataset([ds1])
+    assert list(s) == [0, 1, 2, 3]


-def test_interleaved_never_stop():
+def test_interleaved_normal_weights():
+    ds1 = FakeSizedIterableDataset(4)
+    ds2 = FakeSizedIterableDataset(8, start=10)
+    ds3 = FakeSizedIterableDataset(2, start=100)
+    s = datasets.InterleaveDataset([ds1, ds2, ds3])
+    assert list(s) == [0, 10, 100, 11, 1, 12, 13, 2, 14, 101, 15, 3, 16, 17]
+
+
+def test_interleaved_specific_weights():
     ds1 = FakeSizedIterableDataset(4)
     ds2 = FakeSizedIterableDataset(2, start=10)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2],
-        # stop_strategy=datasets.StopStrategy.NEVER_STOP,
-        # static=True,
-    )
-    # static=True disables random sampling of datasets, so the order is deterministic
-    # stop_strategy=never_stop will continue interleaving forever
-    assert list(itertools.islice(s, 12)) == [0, 10, 1, 11, 2, 10, 3, 11, 0, 10, 1, 11]
+    s = datasets.InterleaveDataset([ds1, ds2], [0.5, 2.0])
+    assert list(s) == [0, 10, 11, 1, 10, 11]


-def test_interleaved_random():
+def test_interleaved_zero_weights():
     ds1 = FakeSizedIterableDataset(4)
     ds2 = FakeSizedIterableDataset(2, start=10)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2],
-        [10.0, 1.0],
-    )
-    # stop_strategy=last_exhausted will stop interleaving when the last dataset is exhausted (attempted after exhaustion)
-    assert list(s) == [
-        0,
-        1,
- 2, - 3, - ] + s = datasets.InterleaveDataset([ds1, ds2], [0.0, 0.0]) + assert list(s) == [] def test_interleaved_with_multiprocessing(): ds = FakeSizedIterableDataset(5) s = datasets.InterleaveDataset([ds]) - dl = data.DataLoader(s, num_workers=1, batch_size=5) - batch = next(iter(dl)) assert torch.allclose(batch, torch.tensor([0, 1, 2, 3, 4])) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index c730f208..87b5801f 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -50,17 +50,12 @@ def get_val_sets(self) -> List[DatasetOptions]: do_train: bool = True do_eval: bool = True - # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop - stop_strategy: datasets.StopStrategy = datasets.StopStrategy.LAST_EXHAUSTED - data_dir: Optional[str] = None - mds: bool = False num_samples: Optional[int] = None val_num_samples: int = 100 eval_num_samples: int = 100 eval_max_new_tokens: Optional[int] = None eval_num_procs: int = 8 eval_text_only: bool = False - num_prompts: int = 1 # number of data loader workers num_workers: int = 8 if torch.cuda.is_available() else 1 train_on_inputs: bool = False diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 515b5a1c..bfe3626b 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -6,7 +6,6 @@ train_sets: val_sets: - name: gigaspeech - weight: 0.01 -stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False shuffle_data: True diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 888ba12d..0eca2527 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -36,7 +36,6 @@ def prepare_dataset( data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, - stop_strategy: datasets.StopStrategy, num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) ) -> datasets.SizedIterableDataset: @@ -50,9 +49,7 @@ def prepare_dataset( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" - interleave = datasets.InterleaveDataset( - data_sets, data_weights, stop_strategy=stop_strategy - ) + interleave = datasets.InterleaveDataset(data_sets, data_weights) ds_with_proc = data_processing.UltravoxDataproc( interleave, processor=processor, @@ -207,7 +204,6 @@ def train(args: config_base.TrainConfig): train_args=args, data_opts=args.get_train_sets(), train_on_inputs=args.train_on_inputs, - stop_strategy=args.stop_strategy, processor=processor, num_samples=args.num_samples, data_args=datasets.VoiceDatasetArgs( @@ -228,7 +224,6 @@ def train(args: config_base.TrainConfig): train_args=args, data_opts=[val_opt], train_on_inputs=args.train_on_inputs, - stop_strategy=args.stop_strategy, processor=processor, num_samples=args.val_num_samples, data_args=val_ds_args,
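
To sanity-check the new scheduling outside the test suite, here is a minimal
standalone sketch of the same least-far-along selection. The `interleave`
helper and the plain-list datasets are illustrative stand-ins for
InterleaveDataset and SizedIterableDataset, not part of the diff, and the
zero-budget guard is an addition: the patched loop has no such guard, so
mixing a zero weight with nonzero ones would divide by zero there.

from typing import Iterator, List, Sequence


def interleave(datasets: Sequence[List[int]], weights: Sequence[float]) -> Iterator[int]:
    # Each dataset's budget: weight * len(dataset), truncated to an int.
    budgets = [int(w * len(d)) for w, d in zip(weights, datasets)]
    iters = [iter(d) for d in datasets]
    pos = [0] * len(datasets)
    for _ in range(sum(budgets)):
        # Vend from the dataset least far through its budget, skipping
        # zero-budget datasets to avoid dividing by zero.
        idx = min(
            (j for j in range(len(datasets)) if budgets[j] > 0),
            key=lambda j: pos[j] / budgets[j],
        )
        try:
            yield next(iters[idx])
        except StopIteration:
            # A weight > 1.0 revisits a dataset after it is exhausted.
            iters[idx] = iter(datasets[idx])
            yield next(iters[idx])
        pos[idx] += 1


# Mirrors test_interleaved_specific_weights above.
assert list(interleave([[0, 1, 2, 3], [10, 11]], [0.5, 2.0])) == [0, 10, 11, 1, 10, 11]

Ties go to the lowest index, matching the strict less-than comparison in the
patched loop; that tie-breaking rule is what makes the expected orderings in
the new tests deterministic.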