Skip to content

Commit

Permalink
add sub query batching for large models
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Sep 4, 2024
1 parent 76a3b66 commit 080ff18
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion lightning_ir/cross_encoder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,43 @@ class CrossEncoderOutput(LightningIROutput):
class CrossEncoderModel(LightningIRModel):
config_class: Type[CrossEncoderConfig] = CrossEncoderConfig

ALLOW_BATCHING = True

def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
    """Initialize the cross-encoder with a scoring head on top of the backbone.

    :param config: Cross-encoder configuration; ``hidden_size`` and
        ``linear_bias`` configure the final scoring projection.
    """
    super().__init__(config, *args, **kwargs)
    # Narrow the attribute type for static checkers; value is set by the base class.
    self.config: CrossEncoderConfig
    # Projects a pooled hidden state to a single relevance score.
    self.linear = torch.nn.Linear(
        in_features=config.hidden_size, out_features=1, bias=config.linear_bias
    )

def batched_backbone_forward(self, encoding: BatchEncoding) -> torch.Tensor:
    """Run the backbone on ``encoding``, splitting the batch on CUDA OOM.

    The full batch is attempted first. On a CUDA out-of-memory error the
    batch size is repeatedly halved until a sub-batch fits (re-raising once
    it would drop to zero); the remaining rows are then processed at the
    discovered batch size and all hidden states are concatenated.

    :param encoding: Tokenized inputs; all values are assumed to share the
        same leading (batch) dimension — TODO confirm against callers.
    :return: ``last_hidden_state`` for the entire input batch.
    :raises RuntimeError: If even a single sample does not fit in memory, or
        on any non-OOM runtime error from the backbone.
    """
    if not self.ALLOW_BATCHING:
        return self.backbone_forward(encoding)
    total = encoding["input_ids"].shape[0]
    batch_size = total
    outputs = []
    sub_encoding = encoding
    while True:
        try:
            outputs.append(self.backbone_forward(sub_encoding).last_hidden_state)
            break
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise
            batch_size = batch_size // 2
            if batch_size == 0:
                # Cannot shrink further; even one sample does not fit.
                raise
            # BatchEncoding takes its data as a single dict, not keyword args.
            sub_encoding = BatchEncoding(
                {key: value[:batch_size] for key, value in sub_encoding.items()}
            )
    if batch_size == total:
        return outputs[0]
    # Ceil-divide so the trailing (possibly partial) sub-batch is included;
    # the previous floor-divide-minus-one silently dropped the last rows.
    num_batches = (total + batch_size - 1) // batch_size
    for i in range(1, num_batches):
        sub_encoding = BatchEncoding(
            {key: value[batch_size * i : batch_size * (i + 1)] for key, value in encoding.items()}
        )
        outputs.append(self.backbone_forward(sub_encoding).last_hidden_state)
    return torch.cat(outputs)

def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
embeddings = self.backbone_forward(**encoding).last_hidden_state
embeddings = self.batched_backbone_forward(**encoding)
embeddings = self._pooling(
embeddings, encoding.get("attention_mask", None), pooling_strategy=self.config.pooling_strategy
)
Expand Down

0 comments on commit 080ff18

Please sign in to comment.