diff --git a/morl_baselines/common/evaluation.py b/morl_baselines/common/evaluation.py index 07aca007..168c4400 100644 --- a/morl_baselines/common/evaluation.py +++ b/morl_baselines/common/evaluation.py @@ -145,6 +145,7 @@ def log_all_multi_policy_metrics( global_step: int, n_sample_weights: int = 50, ref_front: Optional[List[np.ndarray]] = None, + horizons: Optional[List[np.ndarray]] = None, ): """Logs all metrics for multi-policy training. @@ -178,10 +179,23 @@ def log_all_multi_policy_metrics( }, commit=False, ) - front = wandb.Table( - columns=[f"objective_{i}" for i in range(1, reward_dim + 1)], - data=[p.tolist() for p in filtered_front], - ) + columns = [f"objective_{i}" for i in range(1, reward_dim + 1)] + data = [p.tolist() for p in filtered_front] + + # Filter the horizons array using filtered_front so that filtered horizons would contain horizons for the filtered front + if horizons is not None: + filtered_indices = [] + for i, item in enumerate(current_front): + for filtered_item in filtered_front: + if np.array_equal(item, filtered_item): + filtered_indices.append(i) + break + + filtered_horizons = [horizons[i] for i in filtered_indices] + columns.append("horizons") + data = [p.tolist() + [h] for p, h in zip(filtered_front, filtered_horizons)] + + front = wandb.Table(columns=columns, data=data) wandb.log({"eval/front": front}) # If PF is known, log the additional metrics