Skip to content

Commit

Permalink
Fix S2S pipeline edge case error (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
ibanesh authored Dec 8, 2023
1 parent c5ad5e2 commit 85be8ba
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,7 @@ def policy(self, states: DecoderAgentStates) -> Action:
blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
decoder_features_out = None

while (
len(states.target_indices + pred_indices) < self.max_len(states)
and len(pred_indices) < self.max_consecutive_writes
):
while True:
index, prob, decoder_features = self.run_decoder(states, pred_indices)

if decoder_features_out is None:
Expand Down Expand Up @@ -361,6 +358,12 @@ def policy(self, states: DecoderAgentStates) -> Action:
if prob < self.decision_threshold and not states.source_finished:
break

if (
len(states.target_indices + pred_indices) >= self.max_len(states)
or len(pred_indices) >= self.max_consecutive_writes
):
break

pred_indices.append(index)
if self.state_bag.step_nr == 0:
self.state_bag.increment_step_nr(
Expand Down

0 comments on commit 85be8ba

Please sign in to comment.