Skip to content

Commit

Permalink
add option for adding non retrieved docs
Browse files: browse the repository at this point in the history
  • Loading branch information
fschlatt committed Aug 29, 2024
1 parent 270e885 commit 7976828
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions lightning_ir/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
normalize_targets: bool = False,
add_non_retrieved_docs: bool = False,
) -> None:
self.run_path = None
if Path(run_path_or_id).is_file():
Expand All @@ -214,8 +215,6 @@ def __init__(
else:
dataset = str(run_path_or_id)
super().__init__(dataset)
if depth != -1 and sample_size == -1:
sample_size = depth
self.depth = depth
self.sample_size = sample_size
self.sampling_strategy = sampling_strategy
Expand All @@ -238,15 +237,17 @@ def __init__(
query_ids = run_query_ids.intersection(qrels_query_ids)
if len(run_query_ids.difference(qrels_query_ids)):
self.run = self.run[self.run["query_id"].isin(query_ids)]
# outer join only if docs are from ir_datasets and add_non_retrieved_docs is set; otherwise left join (only keep docs in run)
how = "left"
if self._docs is None and add_non_retrieved_docs:
how = "outer"
self.run = self.run.merge(
self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
on=["query_id", "doc_id"],
how=(
"outer" if self._docs is None else "left"
), # outer join if docs are from ir_datasets else only keep docs in run
how=how,
)

if sample_size != -1:
if self.sample_size != -1:
num_docs_per_query = self.run.groupby("query_id").transform("size")
self.run = self.run[num_docs_per_query >= self.sample_size]

Expand Down

0 comments on commit 7976828

Please sign in to comment.