diff --git a/src/helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py b/src/helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py index 371e719a9fb..da97f944408 100644 --- a/src/helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +++ b/src/helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py @@ -83,9 +83,9 @@ def __init__(self, perspective: str, data: str, demo_name: str, description: str self.description = TASK_DESCRIPTIONS[description] self.seeds = SEEDS[perspective] - def _get_instances(self, data_path: str, note: str) -> List[Instance]: + def _get_instances(self, data_path: str, note: str, output_path: str) -> List[Instance]: instances: List[Instance] = [] - target_path = os.path.join(self.output_path, data_path) + target_path = os.path.join(output_path, data_path) ensure_directory_exists(os.path.split(target_path)[0]) ensure_file_downloaded(source_url=self.source_url + data_path, target_path=target_path) # to be filled dataset = [] @@ -144,9 +144,9 @@ def get_instances(self, output_path: str) -> List[Instance]: part = self.demo_name asr_path = f"{self.perspective}/{exp}/{part}_asr/{seed}.jsonl" - instances.extend(self._get_instances(asr_path, f"asr_{seed}")) + instances.extend(self._get_instances(asr_path, f"asr_{seed}", output_path)) cacc_path = f"{self.perspective}/{exp}/{part}_cacc/{seed}.jsonl" - instances.extend(self._get_instances(cacc_path, f"cacc_{seed}")) + instances.extend(self._get_instances(cacc_path, f"cacc_{seed}", output_path)) else: if self.perspective == "counterfactual": if self.demo_name.find("cf") != -1: @@ -159,7 +159,7 @@ def get_instances(self, output_path: str) -> List[Instance]: data_path = f"{self.perspective}/{self.data}/{part}/{seed}.jsonl" else: raise ValueError(f"Nonexistent {self.perspective}") - instances.extend(self._get_instances(data_path, str(seed))) + instances.extend(self._get_instances(data_path, str(seed), output_path)) if self.demo_name in ["cf", "zero"]: break return instances diff --git a/src/helm/benchmark/scenarios/decodingtrust_fairness_scenario.py b/src/helm/benchmark/scenarios/decodingtrust_fairness_scenario.py index 429836c92c0..8161dad36b1 100644 --- a/src/helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +++ b/src/helm/benchmark/scenarios/decodingtrust_fairness_scenario.py @@ -40,7 +40,7 @@ def get_references(self, label: int) -> List[Reference]: return references def get_instances(self, output_path: str) -> List[Instance]: - data_path: str = os.path.join(self.output_path, self.sub_scenario) + data_path: str = os.path.join(output_path, self.sub_scenario) url: str = os.path.join(self.source_url, self.sub_scenario) ensure_file_downloaded( diff --git a/src/helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py b/src/helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py index 2d6c20decb7..1a9b584c748 100644 --- a/src/helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +++ b/src/helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py @@ -299,13 +299,13 @@ def get_file_instances( return instances def get_instances(self, output_path: str) -> List[Instance]: - target_test_path = os.path.join(self.output_path, self.test_data_file.replace("/", "_")) + target_test_path = os.path.join(output_path, self.test_data_file.replace("/", "_")) ensure_file_downloaded( source_url=os.path.join(self.source_url, self.test_data_file), target_path=target_test_path, ) - target_train_path = os.path.join(self.output_path, self.train_data_file.replace("/", "_")) + target_train_path = os.path.join(output_path, self.train_data_file.replace("/", "_")) ensure_file_downloaded( source_url=os.path.join(self.source_url, self.train_data_file), target_path=target_train_path, diff --git a/src/helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py b/src/helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py index dcb38fe48d0..5be691c53a9 100644 --- a/src/helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +++ b/src/helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py @@ -164,7 +164,7 @@ def get_prompts(self, datasets): return dataset def get_instances(self, output_path: str) -> List[Instance]: - data_path: str = os.path.join(self.output_path, f"{self.ood_type}.json".replace("/", "_")) + data_path: str = os.path.join(output_path, f"{self.ood_type}.json".replace("/", "_")) ensure_file_downloaded( source_url=self.source_url.format(self.ood_type), diff --git a/src/helm/benchmark/scenarios/decodingtrust_privacy_scenario.py b/src/helm/benchmark/scenarios/decodingtrust_privacy_scenario.py index c2cf49dc567..7835a1787f6 100644 --- a/src/helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +++ b/src/helm/benchmark/scenarios/decodingtrust_privacy_scenario.py @@ -188,13 +188,13 @@ def get_instances(self, output_path: str) -> List[Instance]: for download_file in download_files: ensure_file_downloaded( source_url=os.path.join(self.source_url, download_file), - target_path=os.path.join(self.output_path, download_file.replace("/", "_")), + target_path=os.path.join(output_path, download_file.replace("/", "_")), ) load_data_file = self.data_file.replace("/", "_") dataset: List[Dict] = self.load_dataset( scenario_name=self.scenario_name, - data_file=os.path.join(self.output_path, load_data_file), + data_file=os.path.join(output_path, load_data_file), dataset_size=self.dataset_size, few_shot_num=self.few_shot_num, prompt_type=self.prompt_type,