Skip to content

Commit

Permalink
Remove self.output_path
Browse files Browse the repository at this point in the history
  • Loading branch information
danielz02 committed Nov 7, 2023
1 parent 7763c0e commit ae8b0e4
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def __init__(self, perspective: str, data: str, demo_name: str, description: str
self.description = TASK_DESCRIPTIONS[description]
self.seeds = SEEDS[perspective]

def _get_instances(self, data_path: str, note: str) -> List[Instance]:
def _get_instances(self, data_path: str, note: str, output_path: str) -> List[Instance]:
instances: List[Instance] = []
target_path = os.path.join(self.output_path, data_path)
target_path = os.path.join(output_path, data_path)
ensure_directory_exists(os.path.split(target_path)[0])
ensure_file_downloaded(source_url=self.source_url + data_path, target_path=target_path) # to be filled
dataset = []
Expand Down Expand Up @@ -144,9 +144,9 @@ def get_instances(self, output_path: str) -> List[Instance]:
part = self.demo_name

asr_path = f"{self.perspective}/{exp}/{part}_asr/{seed}.jsonl"
instances.extend(self._get_instances(asr_path, f"asr_{seed}"))
instances.extend(self._get_instances(asr_path, f"asr_{seed}", output_path))
cacc_path = f"{self.perspective}/{exp}/{part}_cacc/{seed}.jsonl"
instances.extend(self._get_instances(cacc_path, f"cacc_{seed}"))
instances.extend(self._get_instances(cacc_path, f"cacc_{seed}", output_path))
else:
if self.perspective == "counterfactual":
if self.demo_name.find("cf") != -1:
Expand All @@ -159,7 +159,7 @@ def get_instances(self, output_path: str) -> List[Instance]:
data_path = f"{self.perspective}/{self.data}/{part}/{seed}.jsonl"
else:
raise ValueError(f"Nonexistent {self.perspective}")
instances.extend(self._get_instances(data_path, str(seed)))
instances.extend(self._get_instances(data_path, str(seed), output_path))
if self.demo_name in ["cf", "zero"]:
break
return instances
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_references(self, label: int) -> List[Reference]:
return references

def get_instances(self, output_path: str) -> List[Instance]:
data_path: str = os.path.join(self.output_path, self.sub_scenario)
data_path: str = os.path.join(output_path, self.sub_scenario)
url: str = os.path.join(self.source_url, self.sub_scenario)

ensure_file_downloaded(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,13 @@ def get_file_instances(
return instances

def get_instances(self, output_path: str) -> List[Instance]:
target_test_path = os.path.join(self.output_path, self.test_data_file.replace("/", "_"))
target_test_path = os.path.join(output_path, self.test_data_file.replace("/", "_"))
ensure_file_downloaded(
source_url=os.path.join(self.source_url, self.test_data_file),
target_path=target_test_path,
)

target_train_path = os.path.join(self.output_path, self.train_data_file.replace("/", "_"))
target_train_path = os.path.join(output_path, self.train_data_file.replace("/", "_"))
ensure_file_downloaded(
source_url=os.path.join(self.source_url, self.train_data_file),
target_path=target_train_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_prompts(self, datasets):
return dataset

def get_instances(self, output_path: str) -> List[Instance]:
data_path: str = os.path.join(self.output_path, f"{self.ood_type}.json".replace("/", "_"))
data_path: str = os.path.join(output_path, f"{self.ood_type}.json".replace("/", "_"))

ensure_file_downloaded(
source_url=self.source_url.format(self.ood_type),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,13 @@ def get_instances(self, output_path: str) -> List[Instance]:
for download_file in download_files:
ensure_file_downloaded(
source_url=os.path.join(self.source_url, download_file),
target_path=os.path.join(self.output_path, download_file.replace("/", "_")),
target_path=os.path.join(output_path, download_file.replace("/", "_")),
)
load_data_file = self.data_file.replace("/", "_")

dataset: List[Dict] = self.load_dataset(
scenario_name=self.scenario_name,
data_file=os.path.join(self.output_path, load_data_file),
data_file=os.path.join(output_path, load_data_file),
dataset_size=self.dataset_size,
few_shot_num=self.few_shot_num,
prompt_type=self.prompt_type,
Expand Down

0 comments on commit ae8b0e4

Please sign in to comment.