Skip to content

Commit

Permalink
Log horizons with log_all_multi_policy_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
vaidas-sl committed Dec 9, 2023
1 parent 1716de2 commit 86eefab
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions morl_baselines/common/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def log_all_multi_policy_metrics(
global_step: int,
n_sample_weights: int = 50,
ref_front: Optional[List[np.ndarray]] = None,
horizons: Optional[List[np.ndarray]] = None,
):
"""Logs all metrics for multi-policy training.
Expand Down Expand Up @@ -178,10 +179,23 @@ def log_all_multi_policy_metrics(
},
commit=False,
)
front = wandb.Table(
columns=[f"objective_{i}" for i in range(1, reward_dim + 1)],
data=[p.tolist() for p in filtered_front],
)
columns = [f"objective_{i}" for i in range(1, reward_dim + 1)]
data = [p.tolist() for p in filtered_front]

# Filter the horizons array using filtered_front so that filtered horizons would contain horizons for the filtered front
if horizons is not None:
filtered_indices = []
for i, item in enumerate(current_front):
for filtered_item in filtered_front:
if np.array_equal(item, filtered_item):
filtered_indices.append(i)
break

filtered_horizons = [horizons[i] for i in filtered_indices]
columns.append("horizons")
data = [p.tolist() + [h] for p, h in zip(filtered_front, filtered_horizons)]

front = wandb.Table(columns=columns, data=data)
wandb.log({"eval/front": front})

# If PF is known, log the additional metrics
Expand Down

0 comments on commit 86eefab

Please sign in to comment.