diff --git a/src/seamless_communication/streaming/agents/online_text_decoder.py b/src/seamless_communication/streaming/agents/online_text_decoder.py index 2bb54a75..3b5fedda 100644 --- a/src/seamless_communication/streaming/agents/online_text_decoder.py +++ b/src/seamless_communication/streaming/agents/online_text_decoder.py @@ -325,10 +325,7 @@ def policy(self, states: DecoderAgentStates) -> Action: blocked_ngrams = self.get_blocked_ngrams(states.target_indices) decoder_features_out = None - while ( - len(states.target_indices + pred_indices) < self.max_len(states) - and len(pred_indices) < self.max_consecutive_writes - ): + while True: index, prob, decoder_features = self.run_decoder(states, pred_indices) if decoder_features_out is None: @@ -361,6 +358,12 @@ def policy(self, states: DecoderAgentStates) -> Action: if prob < self.decision_threshold and not states.source_finished: break + if ( + len(states.target_indices + pred_indices) >= self.max_len(states) + or len(pred_indices) >= self.max_consecutive_writes + ): + break + pred_indices.append(index) if self.state_bag.step_nr == 0: self.state_bag.increment_step_nr(