Implement TD memory buffer
yura-hb committed Mar 4, 2024
1 parent 198648b commit 375231a
Showing 14 changed files with 165 additions and 69 deletions.
1 change: 1 addition & 0 deletions diploma_thesis/agents/utils/memory/memory.py
@@ -18,6 +18,7 @@ class Record:
POLICY_KEY = "policy"
VALUES_KEY = "values"
REWARD_KEY = "reward"
ACTION_KEY = "actions"
ADVANTAGE_KEY = "advantage"

state: State
5 changes: 3 additions & 2 deletions diploma_thesis/agents/utils/nn/layers/__init__.py
@@ -2,9 +2,10 @@
from .layer import Layer

from .linear import Linear
from .common import Flatten, InstanceNorm
from .common import Flatten, InstanceNorm, LayerNorm
from .activation import Activation
from .merge import Merge
from .noisy_linear import NoisyLinear
from .graph import GraphLayer
from .partial_instance_norm_1d import PartialInstanceNorm1d

@@ -15,7 +16,7 @@
'linear': Linear,
'flatten': Flatten,
'activation': Activation,
'layer_norm': ...,
'layer_norm': LayerNorm,
'instance_norm': InstanceNorm,
'partial_instance_norm': PartialInstanceNorm1d,
'noisy_linear': ...,
30 changes: 15 additions & 15 deletions diploma_thesis/agents/utils/nn/layers/common.py
@@ -1,6 +1,7 @@
import torch

from .layer import *
from typing import List


class Flatten(Layer):
@@ -18,21 +19,20 @@ def from_cli(cls, parameters: dict) -> 'Layer':
return Flatten()


# class LayerNorm(Layer):
#
# def __init__(self):
# super().__init__()
#
# self.layer = nn.LayerNorm()
#
#
# def __call__(self, batch: torch.FloatTensor) -> torch.FloatTensor:
# return self.layer(batch)
#
#
# @classmethod
# def from_cli(cls, parameters: dict) -> 'Layer':
# return LayerNorm(dimension=parameters['dimension'])
class LayerNorm(Layer):

def __init__(self, normalized_shape: int | List[int]):
super().__init__()

self.layer = nn.LayerNorm(normalized_shape=normalized_shape)

def __call__(self, batch: torch.FloatTensor) -> torch.FloatTensor:
return self.layer(batch)

@classmethod
def from_cli(cls, parameters: dict) -> 'Layer':
return LayerNorm(normalized_shape=parameters['normalized_shape'])


class InstanceNorm(Layer):

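For orientation, a minimal usage sketch of the new LayerNorm wrapper; the import path and the feature size of 64 are assumptions for illustration, not part of this commit:

```python
# Hypothetical usage of the new LayerNorm layer; the import path and shapes
# are assumptions for illustration only.
import torch

from agents.utils.nn.layers import LayerNorm

layer = LayerNorm.from_cli({'normalized_shape': 64})

batch = torch.randn(32, 64)
output = layer(batch)  # normalized over the last dimension, same shape as the input
```

With the registration added in layers/__init__.py, the same layer should also be reachable from configuration files through the layer_norm key.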
1 change: 1 addition & 0 deletions diploma_thesis/agents/utils/nn/layers/noisy.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn


# Taken from https://github.com/thomashirtz/noisy-networks/blob/main/noisynetworks.py


8 changes: 5 additions & 3 deletions diploma_thesis/agents/utils/policy/discrete_action.py
@@ -20,14 +20,14 @@ def __init__(self,
value_model: NeuralNetwork,
action_model: NeuralNetwork,
action_selector: ActionSelector,
value_estimation_method: PolicyEstimationMethod = PolicyEstimationMethod.INDEPENDENT):
policy_method: PolicyEstimationMethod = PolicyEstimationMethod.INDEPENDENT):
super().__init__()

self.n_actions = n_actions
self.value_model = value_model
self.action_model = action_model
self.action_selector = action_selector
self.policy_estimation_method = value_estimation_method
self.policy_estimation_method = policy_method

self.__configure__()

@@ -87,5 +87,7 @@ def from_cli(parameters: Dict) -> 'Policy':
value_model = NeuralNetwork.from_cli(parameters['value_model']) if parameters.get('value_model') else None
action_model = NeuralNetwork.from_cli(parameters['action_model']) if parameters.get('action_model') else None
action_selector = action_selector_from_cli(parameters['action_selector'])
policy_method = PolicyEstimationMethod(parameters['policy_method']) \
if parameters.get('policy_method') else PolicyEstimationMethod.INDEPENDENT

return DiscreteAction(n_actions, value_model, action_model, action_selector)
return DiscreteAction(n_actions, value_model, action_model, action_selector, policy_method)
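A rough sketch of the parameters dictionary DiscreteAction.from_cli now accepts; the nested model and selector configurations are elided, and the string value of PolicyEstimationMethod is an assumption. Only the optional policy_method key itself comes from this change:

```python
# Hypothetical CLI parameters for DiscreteAction.from_cli; nested configs are
# omitted and the enum value string is an assumption.
parameters = {
    'n_actions': 10,
    'value_model': None,             # or a NeuralNetwork configuration dict
    'action_model': None,            # or a NeuralNetwork configuration dict
    'action_selector': {},           # ActionSelector configuration
    'policy_method': 'independent',  # optional; falls back to PolicyEstimationMethod.INDEPENDENT
}
```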
4 changes: 2 additions & 2 deletions diploma_thesis/agents/utils/policy/policy.py
@@ -1,7 +1,7 @@

from abc import ABCMeta, abstractmethod
from dataclasses import field
from typing import TypeVar, Generic
from typing import TypeVar, Generic, Tuple

import torch
from torch import nn
@@ -30,7 +30,7 @@ def __call__(self, state: State, parameters: Input) -> Record:
pass

@abstractmethod
def predict(self, state: State) -> torch.FloatTensor:
def predict(self, state: State) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
pass

@abstractmethod
6 changes: 3 additions & 3 deletions diploma_thesis/agents/utils/return_estimator/n_step.py
@@ -56,13 +56,13 @@ def update_returns(self, records: List[Record]) -> List[Record]:
for i in range(len(records)):
action = records[i].action

next_state_value = records[i + 1].info[Record.VALUES_KEY][action] if i + 1 < len(records) else 0
next_state_value = records[i + 1].info[Record.VALUES_KEY] if i + 1 < len(records) else 0
next_state_value *= self.configuration.discount_factor

td_errors += [records[i].reward + next_state_value - records[i].info[Record.VALUES_KEY][action]]
td_errors += [records[i].reward + next_state_value - records[i].info[Record.VALUES_KEY]]

if self.configuration.off_policy:
action_probs = torch.nn.functional.softmax(records[i].info[Record.VALUES_KEY], dim=0)
action_probs = torch.nn.functional.softmax(records[i].info[Record.ACTION_KEY], dim=0)
weight = action_probs[action] / (records[i].info[Record.POLICY_KEY][action] + 1e-10)
weight = torch.nan_to_num(weight, nan=1.0, posinf=1.0, neginf=1.0)

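Read together, the changed lines treat Record.VALUES_KEY as a scalar state-value estimate instead of a per-action vector, and take the action preferences from the new Record.ACTION_KEY. Under that reading (an interpretation of the diff, not wording from the commit), the per-step TD error and the off-policy importance weight are approximately

$$\delta_t = r_t + \gamma\, V(s_{t+1}) - V(s_t), \qquad w_t = \frac{\operatorname{softmax}(\text{ACTION}_t)_{a_t}}{\mu(a_t \mid s_t) + 10^{-10}},$$

where $\mu$ is the behaviour policy stored under Record.POLICY_KEY, $\gamma$ is the discount factor, and the weight is clamped to 1 whenever it comes out NaN or infinite.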
14 changes: 7 additions & 7 deletions diploma_thesis/agents/utils/rl/ddqn.py
@@ -3,22 +3,22 @@

from .dqn import DeepQTrainer
from agents.utils.memory import Record
from agents.base.model import DeepPolicyModel
from agents.utils.policy import Policy


class DoubleDeepQTrainer(DeepQTrainer):

def estimate_q(self, model: DeepPolicyModel, batch: Record | tensordict.TensorDictBase):
q_values = model.predict(batch.next_state)
orig_q = q_values[range(batch.shape[0]), batch.action]
def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
_, actions = model.predict(batch.next_state)
orig_q = actions[range(batch.shape[0]), batch.action]

best_actions = q_values.max(dim=-1).indices
best_actions = actions.max(dim=-1).indices

target = self.target_model.predict(batch.next_state)[range(batch.shape[0]), best_actions]

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done)
q_values[range(batch.shape[0]), batch.action] = q
actions[range(batch.shape[0]), batch.action] = q

td_error = torch.square(orig_q - q)

return q_values, td_error
return actions, td_error
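For reference, the target that estimate_q builds above is the standard double-DQN target: the greedy action is selected with the online network and evaluated with the target network,

$$y = r + \gamma \, Q_{\text{target}}\!\bigl(s',\, \arg\max_{a} Q_{\text{online}}(s', a)\bigr)\,(1 - d),$$

with $d$ the done flag and $\gamma$ the return estimator's discount factor; the squared difference between the previously predicted value of the taken action and $y$ is returned as the TD error.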
12 changes: 6 additions & 6 deletions diploma_thesis/agents/utils/rl/dqn.py
@@ -49,8 +49,8 @@ def train_step(self, model: Policy):
with torch.no_grad():
q_values, td_error = self.estimate_q(model, batch)

values = model.predict(batch.state)
loss = self.loss(values, q_values)
_, actions = model.predict(batch.state)
loss = self.loss(actions, q_values)

self.optimizer.zero_grad()

@@ -72,18 +72,18 @@ def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
# Note:
# The idea is that we compute the Q-values only for performed actions. Other actions wouldn't be updated,
# because there will be zero loss and so zero gradient
q_values = model.predict(batch.next_state)
orig_q = q_values.clone()[range(batch.shape[0]), batch.action]
_, actions = model.predict(batch.next_state)
orig_q = actions.clone()[range(batch.shape[0]), batch.action]

target = self.target_model.predict(batch.next_state)
target = target.max(dim=1).values

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done)
q_values[range(batch.shape[0]), batch.action] = q
actions[range(batch.shape[0]), batch.action] = q

td_error = torch.square(orig_q - q)

return q_values, td_error
return actions, td_error

def store(self, record: Record | List[Record]):
if isinstance(record, Record):
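The note above is the central trick: the predicted Q-values are copied, only the entry of the action that was actually taken is overwritten with the TD target, and the loss against that tensor is therefore zero (and gradient-free) for every other action. A self-contained toy illustration, with made-up shapes and numbers:

```python
# Toy illustration of the "update only the taken action" trick described in
# the note above; all values are made up for demonstration.
import torch

q_pred = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.1, 0.2]], requires_grad=True)  # online Q-values
actions = torch.tensor([2, 0])                                # actions actually taken
td_target = torch.tensor([2.5, 1.0])                          # r + discounted bootstrapped target

q_target = q_pred.detach().clone()
q_target[range(2), actions] = td_target   # overwrite only the taken actions

loss = torch.nn.functional.mse_loss(q_pred, q_target)
loss.backward()

print(q_pred.grad)  # non-zero only at the (row, taken action) entries
```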
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/rl/ppo.py
@@ -53,7 +53,7 @@ def _train_step(self, model: Policy):
return

advantages = batch.info[Record.ADVANTAGE_KEY]
logits = model.predict(batch.state)
value, logits = model.predict(batch.state)
distribution = torch.distributions.Categorical(logits=logits)

loss = 0
Expand Down
30 changes: 3 additions & 27 deletions diploma_thesis/simulator/episodic.py
@@ -2,43 +2,19 @@
from agents.base import Agent
from functools import reduce
from typing import Dict
from .utils import Queue


class EpisodicSimulator(Simulator):
"""
A simulator which stores the trajectories of agents and emits them for training after the simulation finishes
"""

class Queue:

def __init__(self, is_distributed: bool):
self.is_distributed = is_distributed
self.queue = dict()

def store(self, shop_floor_id, key, moment, record):
self.queue[shop_floor_id] = self.queue.get(shop_floor_id, dict())

if self.is_distributed:
self.queue[shop_floor_id][key] = self.queue[shop_floor_id].get(key, dict())
self.queue[shop_floor_id][key][moment] = self.queue[shop_floor_id][key].get(moment, []) + [record]
else:
self.queue[shop_floor_id][moment] = self.queue[shop_floor_id].get(moment, []) + [record]

def pop(self, shop_floor_id):
if shop_floor_id not in self.queue:
return None

values = self.queue[shop_floor_id]

del self.queue[shop_floor_id]

return values

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.machine_queue = self.Queue(self.machine.is_distributed)
self.work_center_queue = self.Queue(self.work_center.is_distributed)
self.machine_queue = Queue(self.machine.is_distributed)
self.work_center_queue = Queue(self.work_center.is_distributed)

def did_prepare_machine_record(self, context: Context, machine: Machine, record: Record):
super().did_prepare_machine_record(context, machine, record)
33 changes: 30 additions & 3 deletions diploma_thesis/simulator/td.py
@@ -1,22 +1,49 @@
from .simulator import *
from .utils import Queue

from agents.base import Agent

class TDSimulator(Simulator):
"""
A simulator which estimates returns in a temporal-difference manner and sends information for training as soon as possible
"""

def __init__(self, memory: int = 100, *args, **kwargs):
super().__init__(*args, **kwargs)

self.memory = memory

self.machine_queue = Queue(self.machine.is_distributed)
self.work_center_queue = Queue(self.work_center.is_distributed)

def did_prepare_machine_record(self, context: Context, machine: Machine, record: Record):
super().did_prepare_machine_record(context, machine, record)

self.machine.store(machine.key, record)
self.__store_or_forward_td__(context, self.machine_queue, self.machine, machine.key, record)

def did_prepare_work_center_record(self, context: Context, work_center: WorkCenter, record: Record):
super().did_prepare_work_center_record(context, work_center, record)

self.work_center.store(work_center.key, record)
self.__store_or_forward_td__(context, self.work_center_queue, self.work_center, work_center.key, record)

@staticmethod
def from_cli(parameters, *args, **kwargs) -> Simulator:
return TDSimulator(*args, **kwargs)
return TDSimulator(parameters.get('memory', 1), *args, **kwargs)

def __store_or_forward_td__(self, context: Context, queue: Queue, agent: Agent, key, record):
if self.memory <= 1:
agent.store(key, record)
return

# Implement the idea of n-step memory
queue.store(context.shop_floor.id, key, context.moment, record)

if queue.group_len(context.shop_floor.id, key) > self.memory:
records = queue.pop_group(context.shop_floor.id, key)

agent.store(key, records)

queue.store_group(context.shop_floor.id, key, records[1:])

return
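With memory > 1, __store_or_forward_td__ behaves as a sliding window: records are buffered per shop floor and agent key, and as soon as more than memory of them have accumulated, the whole group is handed to the agent and re-stored without its oldest element. A toy trace of that behaviour, with a plain list standing in for the Queue (illustrative only, not the committed code):

```python
# Toy trace of the sliding-window behaviour implemented above; a plain list
# stands in for the Queue and integers stand in for records.
memory = 3
buffer = []
forwarded = []

for step in range(6):                   # pretend each step produces a record
    buffer.append(step)

    if len(buffer) > memory:
        forwarded.append(list(buffer))  # agent.store(key, records)
        buffer = buffer[1:]             # queue.store_group(..., records[1:])

print(forwarded)  # [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
```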
2 changes: 2 additions & 0 deletions diploma_thesis/simulator/utils/__init__.py
@@ -0,0 +1,2 @@

from .queue import Queue
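The body of the new simulator/utils/queue.py is not rendered above, so the following is only an inferred sketch of the interface needed to satisfy the calls made from td.py; the internals are assumptions, not the committed implementation:

```python
# Inferred interface sketch for simulator/utils/queue.py; NOT the committed
# code. Internals are assumptions based on how Queue is used in td.py.
from collections import defaultdict


class Queue:

    def __init__(self, is_distributed: bool):
        self.is_distributed = is_distributed
        # The real class additionally keys records by moment and honours
        # is_distributed; a flat list per (shop floor, key) group is enough here.
        self.queue = defaultdict(list)

    def store(self, shop_floor_id, key, moment, record):
        self.queue[shop_floor_id, key].append(record)

    def store_group(self, shop_floor_id, key, records):
        self.queue[shop_floor_id, key] = list(records)

    def group_len(self, shop_floor_id, key):
        return len(self.queue[shop_floor_id, key])

    def pop_group(self, shop_floor_id, key):
        return self.queue.pop((shop_floor_id, key), [])
```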