Fixes after launching the experiment. Implement new reward strategies
yura-hb committed Mar 6, 2024
1 parent cbb627d commit 42cdc67
Showing 21 changed files with 504 additions and 57 deletions.
5 changes: 3 additions & 2 deletions diploma_thesis/agents/base/agent.py
@@ -17,16 +17,17 @@
 @dataclass
 class TrainingSample:
     episode_id: int
+    records: List[Record]


 @dataclass
 class Slice(TrainingSample):
-    records: List[Record]
+    pass


 @dataclass
 class Trajectory(TrainingSample):
-    records: List[Record]
+    pass


class Agent(Generic[Key], Loggable, PhaseUpdatable, metaclass=ABCMeta):
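The `records` field moves onto the base `TrainingSample` dataclass, and `Slice` and `Trajectory` become plain subclasses. A minimal standalone sketch (stand-in field type, not the repository's Record) of why this matters for the positional patterns such as `Trajectory(_, records)` used later in rl.py: `@dataclass` generates `__match_args__` from the fields in declaration order, inherited fields included, so both `episode_id` and `records` must be declared on the class hierarchy for positional capture to work (Python 3.10+).

from dataclasses import dataclass
from typing import List


@dataclass
class TrainingSample:
    episode_id: int
    records: List[int]  # stand-in for List[Record]


@dataclass
class Trajectory(TrainingSample):
    pass


# Positional capture works because __match_args__ is ('episode_id', 'records').
match Trajectory(episode_id=1, records=[1, 2, 3]):
    case Trajectory(_, records):
        print(records)  # [1, 2, 3]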
9 changes: 6 additions & 3 deletions diploma_thesis/agents/utils/policy/discrete_action.py
@@ -56,12 +56,15 @@ def predict(self, state: State):
         values = torch.tensor(0, dtype=torch.long)
         actions = torch.tensor(0, dtype=torch.long)

-        if self.value_model is not None:
-            values = self.value_model(state)
-
         if self.action_model is not None:
             actions = self.action_model(state)

+        if self.value_model is not None:
+            values = self.value_model(state)
+            values = values.expand(-1, self.n_actions)
+        else:
+            values = actions
+
         match self.policy_estimation_method:
             case PolicyEstimationMethod.INDEPENDENT:
                 return values, actions
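For reference, a minimal sketch (toy shapes, not the repository's models) of what `values.expand(-1, self.n_actions)` does: a per-state value column of shape [batch, 1] is broadcast across the action dimension so it lines up with the per-action output of the action model, without copying memory.

import torch

batch, n_actions = 4, 3
values = torch.randn(batch, 1)              # state values V(s), one per state
per_action = values.expand(-1, n_actions)   # shape [batch, n_actions], no data copy
assert per_action.shape == (batch, n_actions)
assert torch.equal(per_action[:, 2], values[:, 0])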
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/return_estimator/gae.py
@@ -17,7 +17,7 @@ def discount_factor(self) -> float:
     def update_returns(self, records: List[Record]) -> List[Record]:
         coef = self._discount_factor * self._lambda

-        for i in reversed(range(records.batch_size[0])):
+        for i in reversed(range(len(records))):
             next_value = 0 if i == len(records) - 1 else records[i + 1].info[Record.ADVANTAGE_KEY]
             value = records[i].info[Record.VALUES_KEY]
             advantage = records[i].reward + self._discount_factor * next_value - value
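The fix replaces `records.batch_size[0]` with `len(records)` when walking the trajectory backwards. For context, a minimal self-contained sketch of the textbook generalized advantage estimation recursion this estimator is based on, written over plain floats rather than the repository's Record objects (the terminal bootstrap value is assumed to be zero; the repository's loop stores intermediate results in each record's info dictionary instead).

# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
# A_t     = delta_t + gamma * lambda * A_{t+1}
def gae(rewards, values, gamma=0.99, lam=0.95):
    advantages = [0.0] * len(rewards)
    next_advantage, next_value = 0.0, 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        advantages[t] = delta + gamma * lam * next_advantage
        next_advantage, next_value = advantages[t], values[t]
    return advantages


print(gae(rewards=[1.0, 1.0, 1.0], values=[0.5, 0.5, 0.5]))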
8 changes: 4 additions & 4 deletions diploma_thesis/agents/utils/return_estimator/n_step.py
@@ -53,13 +53,13 @@ def update_returns(self, records: List[Record]) -> List[Record]:

         lambdas = torch.cumprod(lambdas, dim=0)

-        for i in range(records.batch_size[0]):
+        for i in range(len(records)):
             action = records[i].action

-            next_state_value = records[i + 1].info[Record.VALUES_KEY] if i + 1 < len(records) else 0
+            next_state_value = records[i + 1].info[Record.VALUES_KEY][action] if i + 1 < len(records) else 0
             next_state_value *= self.configuration.discount_factor

-            td_errors += [records[i].reward + next_state_value - records[i].info[Record.VALUES_KEY]]
+            td_errors += [records[i].reward + next_state_value - records[i].info[Record.VALUES_KEY][action]]

             if self.configuration.off_policy:
                 action_probs = torch.nn.functional.softmax(records[i].info[Record.ACTION_KEY], dim=0)
@@ -73,7 +73,7 @@ def update_returns(self, records: List[Record]) -> List[Record]:
             else:
                 off_policy_weights += [1]

-        for i in range(records.batch_size[0]):
+        for i in range(len(records)):
             g = records[i].info[Record.VALUES_KEY][records[i].action]
             n = min(self.configuration.n, len(records) - i)

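The key fix here is indexing the stored value estimates by the taken action, i.e. treating Record.VALUES_KEY as per-action Q-values rather than a scalar state value. As a point of reference, a simplified sketch of a plain n-step return built from such per-action values; this deliberately omits the lambda weighting and off-policy correction that the estimator above applies and is not the repository's code.

def n_step_return(rewards, q_values, actions, n, gamma=0.99):
    """q_values[t] is a list of per-action values; actions[t] is the taken action."""
    returns = []
    steps = len(rewards)
    for i in range(steps):
        horizon = min(n, steps - i)
        g = sum(gamma ** k * rewards[i + k] for k in range(horizon))
        if i + horizon < steps:  # bootstrap with Q(s, a) if the episode continues
            g += gamma ** horizon * q_values[i + horizon][actions[i + horizon]]
        returns.append(g)
    return returns


print(n_step_return([1.0, 1.0, 1.0], [[0.5, 0.2]] * 3, [0, 1, 0], n=2))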
6 changes: 3 additions & 3 deletions diploma_thesis/agents/utils/rl/rl.py
@@ -73,14 +73,13 @@ def store(self, sample: TrainingSample):
         records = self.__prepare__(sample)
         records.info['episode'] = torch.full(records.reward.shape, sample.episode_id, device=records.reward.device)

-        self.memory.store(records.view(-1))
+        self.memory.store(records)

         if self.train_schedule != TrainSchedule.ON_STORE:
             return

         self.__train__(sample.model)

-
     def clear(self):
         self.loss_cache = []
         self.memory.clear()
@@ -102,6 +101,7 @@ def __prepare__(self, sample: TrainingSample) -> Record:
         match sample:
             case Trajectory(_, records):
                 updated = self.return_estimator.update_returns(records)
+                updated = [record.view(-1) for record in updated]
                 updated = torch.cat(updated, dim=0)

                 return updated
@@ -111,6 +111,6 @@ def __prepare__(self, sample: TrainingSample) -> Record:

                 updated = self.return_estimator.update_returns(records)

-                return updated[0]
+                return updated[0].view(-1)
             case _:
                 raise ValueError(f'Unknown sample type: {type(sample)}')
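Rather than flattening the whole batch at store time, each record in a trajectory is now flattened with view(-1) before being concatenated in __prepare__. A rough sketch of that flatten-then-concatenate step, assuming the records behave like tensordict.TensorDict objects (an assumption; the repository's Record type is not shown here):

import torch
from tensordict import TensorDict

# Three single-step records, each with a [1, 1] batch shape (assumed layout).
steps = [TensorDict({'reward': torch.ones(1, 1)}, batch_size=[1, 1]) for _ in range(3)]

flat = [record.view(-1) for record in steps]  # each becomes batch_size [1]
batch = torch.cat(flat, dim=0)                # one structure with batch_size [3]
print(batch.batch_size)                       # torch.Size([3])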