diff --git a/scripts/lang_adapt/madx_run_clm.py b/scripts/lang_adapt/madx_run_clm.py
index d0c59b6..25f9d27 100644
--- a/scripts/lang_adapt/madx_run_clm.py
+++ b/scripts/lang_adapt/madx_run_clm.py
@@ -651,7 +651,29 @@ def load_data(data_args, model_args):
     #
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
-    if data_args.dataset_name is not None:
+    if data_args.dataset_name is not None and data_args.train_file is not None:
+        # Create a dataset from the provided file
+        data_files = {}
+        dataset_args = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        extension = (
+            data_args.train_file.split(".")[-1]
+            if data_args.train_file is not None
+            else data_args.validation_file.split(".")[-1]
+        )
+        if extension == "txt":
+            extension = "text"
+            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
+
+        # Downloading and loading a dataset from the hub.
+        raw_hf_datasets = load_dataset(
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+        )
+        raw_datasets['train'] = datasets.concatenate_datasets([raw_datasets['train'], raw_hf_datasets['train']]).shuffle(seed=42)
+
+    elif data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         raw_datasets = load_dataset(
             data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
@@ -680,7 +702,7 @@ def load_data(data_args, model_args):
     elif data_args.max_eval_samples is not None :
         raw_datasets = raw_datasets['train'].train_test_split(test_size = data_args.max_eval_samples)
     else:
-        raw_datasets = raw_datasets['train'].train_test_split(test_size = data.args.validation_split_percentage/100.0)
+        raw_datasets = raw_datasets['train'].train_test_split(test_size = data_args.validation_split_percentage/100.0)
     raw_datasets['validation'] = raw_datasets['test']
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
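
For context, the new branch mixes a local text corpus with a Hugging Face Hub dataset before training. The sketch below reproduces that merge-and-split flow outside the script, assuming only that the `datasets` library is installed; the file path "my_corpus.txt" and the wikitext dataset/config are placeholders standing in for whatever --train_file, --dataset_name, and --dataset_config_name would be passed to madx_run_clm.py.

    # Standalone sketch of the behaviour introduced above (placeholders noted in comments).
    import datasets
    from datasets import load_dataset

    # Local plain-text corpus, loaded with the "text" builder (keep_linebreaks as in the script).
    local = load_dataset("text", data_files={"train": "my_corpus.txt"}, keep_linebreaks=True)

    # Hub dataset standing in for data_args.dataset_name / data_args.dataset_config_name.
    hub = load_dataset("wikitext", "wikitext-2-raw-v1")

    # Concatenate the two train splits and shuffle, mirroring the new branch.
    mixed = datasets.concatenate_datasets([local["train"], hub["train"]]).shuffle(seed=42)

    # Fallback validation split fixed in the second hunk: carve a percentage off train.
    validation_split_percentage = 5  # assumed value of data_args.validation_split_percentage
    splits = mixed.train_test_split(test_size=validation_split_percentage / 100.0)
    splits["validation"] = splits["test"]
    print(len(splits["train"]), len(splits["validation"]))

Note that datasets.concatenate_datasets expects a list of datasets, which is why the call in the first hunk wraps the two train splits in a list before shuffling.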