Skip to content

Commit

Permalink
fix default args
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Nov 21, 2024
1 parent 0eaa265 commit 0164c53
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions lightning_ir/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,14 +338,12 @@ def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[tor

class ScoreBasedInBatchLossFunction(InBatchLossFunction):

def __init__(
self,
min_target_diff: float,
pos_sampling_technique: Literal["first"] = "first",
neg_sampling_technique: Literal["all_and_non_first"] = "all_and_non_first",
max_num_neg_samples: int | None = None,
):
super().__init__(pos_sampling_technique, neg_sampling_technique, max_num_neg_samples)
def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None):
super().__init__(
pos_sampling_technique="first",
neg_sampling_technique="all_and_non_first",
max_num_neg_samples=max_num_neg_samples,
)
self.min_target_diff = min_target_diff

def _sort_mask(self, mask: torch.Tensor, num_queries: int, num_docs: int, batch: TrainBatch) -> torch.Tensor:
Expand Down

0 comments on commit 0164c53

Please sign in to comment.