diff --git a/morl_baselines/multi_policy/linear_support/linear_support.py b/morl_baselines/multi_policy/linear_support/linear_support.py index 2979c530..2fdb5c75 100644 --- a/morl_baselines/multi_policy/linear_support/linear_support.py +++ b/morl_baselines/multi_policy/linear_support/linear_support.py @@ -50,6 +50,7 @@ def __init__( self.weight_support = [] # List of weight vectors for each value vector in the CCS self.queue = [] self.iteration = 0 + self.ols_ended = False self.verbose = verbose for w in extrema_weights(self.num_objectives): self.queue.append((float("inf"), w)) @@ -102,6 +103,7 @@ def next_weight( if len(self.queue) == 0: if self.verbose: print("There are no corner weights in the queue. Returning None.") + self.ols_ended = True return None else: next_w = self.queue.pop(0)[1] @@ -134,8 +136,14 @@ def get_corner_weights(self, top_k: Optional[int] = None) -> List[np.ndarray]: return weights def ended(self) -> bool: - """Returns True if the queue is empty.""" - return len(self.queue) == 0 + """Returns True if there are no more corner weights to test. + + Warning: Call this AFTER next_weight(); it stays False until next_weight() finds the queue empty and returns None. + Example: w = ols.next_weight() + if ols.ended(): + print("OLS ended.") + """ + return self.ols_ended def add_solution(self, value: np.ndarray, w: np.ndarray) -> List[int]: """Add new value vector optimal to weight w. 
@@ -375,10 +383,13 @@ def is_dominated(self, value: np.ndarray) -> bool: def _solve(w): return np.array(list(map(float, input().split())), dtype=np.float32) - num_objectives = 3 + num_objectives = 2 ols = LinearSupport(num_objectives=num_objectives, epsilon=0.0001, verbose=True) - while not ols.ended(): + while True: w = ols.next_weight() + if ols.ended(): + print("OLS ended.") + break print("w:", w) value = _solve(w) ols.add_solution(value, w) diff --git a/tests/test_algos.py b/tests/test_algos.py index bd51583c..77337754 100644 --- a/tests/test_algos.py +++ b/tests/test_algos.py @@ -113,8 +113,10 @@ def test_ols(): ols = LinearSupport(num_objectives=2, epsilon=0.1, verbose=False) policies = [] - while not ols.ended(): + while True: w = ols.next_weight() + if ols.ended(): + break new_policy = MOQLearning( env,