From aa11358bb2354ea4a5489dd5ed2f8a505e2e5f84 Mon Sep 17 00:00:00 2001 From: Abinesh Ramakrishnan <3632454+ibanesh@users.noreply.github.com> Date: Fri, 8 Dec 2023 15:15:44 -0800 Subject: [PATCH] Fix S2S pipeline edge case error (#243) --- .../streaming/agents/online_text_decoder.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/seamless_communication/streaming/agents/online_text_decoder.py b/src/seamless_communication/streaming/agents/online_text_decoder.py index 2bb54a75..3b5fedda 100644 --- a/src/seamless_communication/streaming/agents/online_text_decoder.py +++ b/src/seamless_communication/streaming/agents/online_text_decoder.py @@ -325,10 +325,7 @@ def policy(self, states: DecoderAgentStates) -> Action: blocked_ngrams = self.get_blocked_ngrams(states.target_indices) decoder_features_out = None - while ( - len(states.target_indices + pred_indices) < self.max_len(states) - and len(pred_indices) < self.max_consecutive_writes - ): + while True: index, prob, decoder_features = self.run_decoder(states, pred_indices) if decoder_features_out is None: @@ -361,6 +358,12 @@ def policy(self, states: DecoderAgentStates) -> Action: if prob < self.decision_threshold and not states.source_finished: break + if ( + len(states.target_indices + pred_indices) >= self.max_len(states) + or len(pred_indices) >= self.max_consecutive_writes + ): + break + pred_indices.append(index) if self.state_bag.step_nr == 0: self.state_bag.increment_step_nr(