
Commit

Fixes connected with compiled models
yura-hb committed Mar 12, 2024
1 parent 3b073a0 commit 36963a9
Showing 37 changed files with 227 additions and 108 deletions.
4 changes: 4 additions & 0 deletions diploma_thesis/agents/base/__init__.py
@@ -5,3 +5,7 @@
from .state import GraphState, TensorState, Graph
from .rl_agent import RLAgent
from .marl_agent import MARLAgent

import torch

torch._dynamo.config.suppress_errors = True
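
For context, a minimal sketch of what this flag does (illustrative, not part of the diff): with suppress_errors enabled, TorchDynamo logs a compilation failure and falls back to eager execution instead of raising, so a compiled model keeps running even when parts of it cannot be traced.

import torch
import torch.nn as nn

# Illustrative sketch, not repository code.
torch._dynamo.config.suppress_errors = True

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
compiled = torch.compile(net)

x = torch.randn(2, 8)
print(compiled(x).shape)  # torch.Size([2, 4]); untraceable parts would run eagerly
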
39 changes: 29 additions & 10 deletions diploma_thesis/agents/base/marl_agent.py
@@ -8,17 +8,14 @@

class MARLAgent(Generic[Key], RLAgent[Key]):

def __init__(self,
model: DeepPolicyModel,
state_encoder: StateEncoder,
trainer: RLTrainer,
is_model_distributed: bool):
super().__init__(model, state_encoder, trainer)

self.model: DeepPolicyModel | Dict[Key, DeepPolicyModel] = model
self.trainer: RLTrainer | Dict[Key, RLTrainer] = trainer
self.is_model_distributed = is_model_distributed
def __init__(self, is_model_distributed: bool, *args, **kwargs):
self.is_configured = False

super().__init__(*args, **kwargs)

self.model: DeepPolicyModel | Dict[Key, DeepPolicyModel] = self.model
self.trainer: RLTrainer | Dict[Key, RLTrainer] = self.trainer
self.is_model_distributed = is_model_distributed
self.keys = None

@property
@@ -99,10 +96,32 @@ def schedule(self, key: Key, parameters):
if not self.trainer[key].is_configured:
self.trainer[key].configure(model.policy)

if not self.is_compiled:
self.compile()

return result

def __model_for_key__(self, key: Key):
if self.is_model_distributed:
return self.model[key]

return self.model

def compile(self):
if not self.is_configured or self.is_compiled:
return

if not self.configuration.compile:
self.is_compiled = True
return

for _, value in self.trainer.items():
value.compile()

if self.is_model_distributed:
for _, value in self.model.items():
value.compile()
else:
self.model.compile()

self.is_compiled = True
5 changes: 5 additions & 0 deletions diploma_thesis/agents/base/model.py
@@ -26,6 +26,9 @@ class Record:
def __call__(self, state: State, parameters: Input) -> Record:
pass

def compile(self):
pass


class DeepPolicyModel(Model[Input, State, Action, Result], PhaseUpdatable, metaclass=ABCMeta):

@@ -39,3 +42,5 @@ def update(self, phase: Phase):

self.policy.update(phase)

def compile(self):
self.policy.compile()
43 changes: 42 additions & 1 deletion diploma_thesis/agents/base/rl_agent.py
@@ -6,12 +6,32 @@
from .agent import *
from .model import DeepPolicyModel

from dataclasses import dataclass


@dataclass
class Configuration:
compile: bool = False

@staticmethod
def from_cli(parameters):
return Configuration(
compile=parameters.get('compile', False)
)


class RLAgent(Generic[Key], Agent[Key]):

def __init__(self, model: DeepPolicyModel, state_encoder: StateEncoder, trainer: RLTrainer):
def __init__(self,
model: DeepPolicyModel,
state_encoder: StateEncoder,
trainer: RLTrainer,
configuration: Configuration):
super().__init__(model, state_encoder)

self.is_compiled = False

self.configuration = configuration
self.model: DeepPolicyModel = model
self.trainer = trainer

@@ -52,4 +72,25 @@ def schedule(self, key, parameters):
if not self.trainer.is_configured:
self.trainer.configure(self.model.policy)

if not self.is_compiled:
self.compile()

return result

def __setstate__(self, state):
self.__dict__ = state

self.compile()

def compile(self):
if not self.configuration.compile:
self.is_compiled = True
return

if self.is_compiled:
return

self.trainer.compile()
self.model.compile()

self.is_compiled = True
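
The hunk above compiles lazily: schedule triggers compile() only once the trainer is configured, and __setstate__ calls compile() again after unpickling, presumably because the compiled module state is not carried through pickling. A condensed sketch of that pattern follows; the class name and attributes are illustrative, not repository code, and nn.Module.compile() (PyTorch 2.1+) applies torch.compile to the module in place.

import torch
import torch.nn as nn

# Condensed, illustrative sketch of the lazy-compilation pattern; not repository code.
class LazyCompiledPolicy:
    def __init__(self, net: nn.Module, compile_enabled: bool):
        self.net = net
        self.compile_enabled = compile_enabled
        self.is_compiled = False

    def compile(self):
        if self.is_compiled:
            return
        if self.compile_enabled:
            self.net.compile()  # in-place torch.compile (PyTorch 2.1+)
        self.is_compiled = True

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if not self.is_compiled:
            self.compile()  # compile on first use
        return self.net(x)

    def __setstate__(self, state):
        self.__dict__ = state
        self.is_compiled = False  # recompile lazily after unpickling
        self.compile()
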
9 changes: 7 additions & 2 deletions diploma_thesis/agents/machine/marl.py
@@ -1,6 +1,6 @@
from typing import Dict

from agents.base.marl_agent import MARLAgent
from agents.base.marl_agent import MARLAgent, Configuration
from agents.utils.rl import from_cli as rl_trainer_from_cli
from environment import MachineKey, ShopFloor
from .model import DeepPolicyMachineModel, from_cli as model_from_cli
@@ -19,9 +19,14 @@ def from_cli(parameters: Dict):
model = model_from_cli(parameters['model'])
encoder = state_encoder_from_cli(parameters['encoder'])
trainer = rl_trainer_from_cli(parameters['trainer'])
configuration = Configuration.from_cli(parameters)

is_model_distributed = parameters.get('is_model_distributed', True)

assert isinstance(model, DeepPolicyMachineModel), f"Model must conform to NNModel"

return MARLMachine(model, encoder, trainer, is_model_distributed)
return MARLMachine(is_model_distributed=is_model_distributed,
model=model,
state_encoder=encoder,
trainer=trainer,
configuration=configuration)
2 changes: 1 addition & 1 deletion diploma_thesis/agents/machine/model/deep_multi_rule.py
@@ -18,7 +18,7 @@ def __init__(self, rules: List[SchedulingRule], policy: Policy[MachineInput]):
def __call__(self, state: State, parameters: Input) -> DeepPolicyMachineModel.Record:
# No gradient descent based on the decision at the moment
with torch.no_grad():
record = self.policy(state, parameters)
record = self.policy.select(state, parameters)
result = self.rules[record.action.item()](parameters.machine, parameters.now)

return DeepPolicyMachineModel.Record(result=result, record=record, batch_size=[])
2 changes: 1 addition & 1 deletion diploma_thesis/agents/machine/model/deep_rule.py
@@ -11,7 +11,7 @@ class DeepRule(DeepPolicyMachineModel):
def __call__(self, state: State, parameters: Input) -> DeepPolicyMachineModel.Record:
# No gradient descent based on the decision at the moment
with torch.no_grad():
record = self.policy(state, parameters)
record = self.policy.select(state, parameters)
result = parameters.machine.queue[record.action.item()]

return DeepPolicyMachineModel.Record(result=result, record=record, batch_size=[])
5 changes: 3 additions & 2 deletions diploma_thesis/agents/machine/rl.py
@@ -1,6 +1,6 @@
from typing import Dict

from agents.base.rl_agent import RLAgent
from agents.base.rl_agent import RLAgent, Configuration
from agents.utils.rl import from_cli as rl_trainer_from_cli
from environment import MachineKey
from .model import DeepPolicyMachineModel, from_cli as model_from_cli
@@ -14,7 +14,8 @@ def from_cli(parameters: Dict):
model = model_from_cli(parameters['model'])
encoder = state_encoder_from_cli(parameters['encoder'])
trainer = rl_trainer_from_cli(parameters['trainer'])
configuration = Configuration.from_cli(parameters)

assert isinstance(model, DeepPolicyMachineModel), f"Model must conform to NNModel"

return RLMachine(model, encoder, trainer)
return RLMachine(model, encoder, trainer, configuration)
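
A small usage sketch of the new Configuration hook (illustrative, using the import shown above); the compile flag is read from the parameters dict and defaults to False when absent:

from agents.base.rl_agent import Configuration

# Illustrative usage, not repository code.
assert Configuration.from_cli({'compile': True}).compile is True
assert Configuration.from_cli({}).compile is False  # missing key disables compilation
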
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ def __init__(self, channels: int):
super().__init__()

self.channels = channels
self.norm = nn.InstanceNorm1d(num_features=channels)
self.norm = nn.InstanceNorm1d(num_features=1)

def forward(self, batch):
normalized = batch[:, :self.channels]
30 changes: 15 additions & 15 deletions diploma_thesis/agents/utils/policy/flexible_action.py
@@ -44,21 +44,7 @@ def __get_values__(self, state):
def __get_actions__(self, state):
return self.action_model(state)

def forward(self, state: State, parameters: Input) -> Record:
values, actions = self.predict(state)
values, actions = values.squeeze(), actions.squeeze()
action, policy = self.action_selector(actions)
action = action if torch.is_tensor(action) else torch.tensor(action, dtype=torch.long)

info = TensorDict({
"policy": policy,
"values": values.detach().clone(),
"actions": actions.detach().clone()
}, batch_size=[])

return Record(state, action, info, batch_size=[])

def predict(self, state: State):
def forward(self, state: State):
actions = torch.tensor(0, dtype=torch.long)

if self.action_model is not None:
@@ -77,6 +63,20 @@ def predict(self, state: State):
case _:
raise ValueError(f"Policy estimation method {self.policy_estimation_method} is not supported")

def select(self, state: State, parameters: Input) -> Record:
values, actions = self.__call__(state)
values, actions = values.squeeze(), actions.squeeze()
action, policy = self.action_selector(actions)
action = action if torch.is_tensor(action) else torch.tensor(action, dtype=torch.long)

info = TensorDict({
"policy": policy,
"values": values.detach().clone(),
"actions": actions.detach().clone()
}, batch_size=[])

return Record(state, action, info, batch_size=[])

def __configure__(self):
if self.noise_parameters is not None:
self.action_model.to_noisy(self.noise_parameters)
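
The policy above now splits tensor computation from action selection: forward returns raw value and action tensors, while the new select method samples the action and packs the TensorDict record. Given the compilation changes elsewhere in this commit, a plausible reading is that the tensor-only forward is what gets compiled while select stays eager; the hypothetical TinyPolicy below mirrors that shape and is not repository code.

import torch
import torch.nn as nn

# Hypothetical, simplified policy mirroring the forward/select split; not repository code.
class TinyPolicy(nn.Module):
    def __init__(self, state_dim: int, n_actions: int):
        super().__init__()
        self.value_head = nn.Linear(state_dim, n_actions)
        self.action_head = nn.Linear(state_dim, n_actions)

    def forward(self, state: torch.Tensor):
        # Pure tensor computation: the part torch.compile can optimize.
        return self.value_head(state), self.action_head(state)

    @torch.no_grad()
    def select(self, state: torch.Tensor) -> torch.Tensor:
        # Eager post-processing on top of the (possibly compiled) forward.
        _, logits = self(state)
        return torch.distributions.Categorical(logits=logits).sample()

policy = TinyPolicy(8, 4)
policy.compile()  # PyTorch 2.1+: wraps forward with torch.compile
print(policy.select(torch.randn(2, 8)))  # two sampled action indices
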
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/policy/policy.py
@@ -28,7 +28,7 @@ class Record:
class Policy(Generic[Input], nn.Module, PhaseUpdatable, metaclass=ABCMeta):

@abstractmethod
def predict(self, state: State) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
def select(self, state, parameters):
pass

def clone(self):
4 changes: 2 additions & 2 deletions diploma_thesis/agents/utils/rl/ddqn.py
@@ -9,12 +9,12 @@
class DoubleDeepQTrainer(DeepQTrainer):

def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
_, actions = model.predict(batch.next_state)
_, actions = model(batch.next_state)
orig_q = actions[range(batch.shape[0]), batch.action]

best_actions = actions.max(dim=-1).indices

target = self.target_model.predict(batch.next_state)[1][range(batch.shape[0]), best_actions]
target = self.target_model(batch.next_state)[1][range(batch.shape[0]), best_actions]

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done.int())
actions[range(batch.shape[0]), batch.action] = q
14 changes: 10 additions & 4 deletions diploma_thesis/agents/utils/rl/dqn.py
@@ -19,7 +19,7 @@ class Configuration:
def from_cli(parameters: Dict):
return DeepQTrainer.Configuration(
decay=parameters.get('decay', 0.99),
update_steps=parameters.get('update_steps', 100),
update_steps=parameters.get('update_steps', 20),
prior_eps=parameters.get('prior_eps', 1e-6)
)

@@ -43,7 +43,7 @@ def __train__(self, model: Policy):
with torch.no_grad():
q_values, td_error = self.estimate_q(model, batch)

_, actions = model.predict(batch.state)
_, actions = model(batch.state)
loss = self.loss(actions, q_values)

self.optimizer.zero_grad()
@@ -64,10 +64,10 @@ def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
# Note:
# The idea is that we compute the Q-values only for performed actions. Other actions wouldn't be updated,
# because there will be zero loss and so zero gradient
_, actions = model.predict(batch.next_state)
_, actions = model(batch.next_state)
orig_q = actions.clone()[range(batch.shape[0]), batch.action]

_, target = self.target_model.predict(batch.next_state)
_, target = self.target_model(batch.next_state)
target = target.max(dim=1).values

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done.int())
@@ -81,6 +81,12 @@ def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
def target_model(self):
return self._target_model.module

def compile(self):
if not self.is_configured:
return

self.target_model.compile()

@classmethod
def from_cli(cls,
parameters,
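
The note in estimate_q can be checked with a tiny standalone example (illustrative numbers, not repository code): writing the TD target only at the taken action's index leaves every other prediction equal to its own target, so those entries contribute zero loss and zero gradient.

import torch

# Illustrative check of the note above; not repository code.
q_pred = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
target = q_pred.detach().clone()
target[0, 1] = 2.5  # TD target written only for the action that was taken (index 1)

loss = torch.nn.functional.mse_loss(q_pred, target)
loss.backward()
print(q_pred.grad)  # non-zero only at index 1: tensor([[ 0.0000, -0.3333,  0.0000]])
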
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/rl/ppo.py
@@ -64,7 +64,7 @@ def __train__(self, model: Policy):

def __step__(self, batch: Record, model: Policy):
advantages = batch.info[Record.ADVANTAGE_KEY]
value, logits = model.predict(batch.state)
value, logits = model(batch.state)
value = value[torch.arange(batch.shape[0]), batch.action]
distribution = torch.distributions.Categorical(logits=logits)

9 changes: 8 additions & 1 deletion diploma_thesis/agents/utils/rl/reinforce.py
@@ -64,7 +64,7 @@ def __train__(self, model: Policy):
baseline = torch.squeeze(baseline)

# Perform policy step
loss = self.loss(model.predict(batch.state)[1], batch.action)
loss = self.loss(model(batch.state)[1], batch.action)

if loss.numel() == 1:
raise ValueError('Loss should not have reduction to single value')
@@ -90,6 +90,13 @@ def __train__(self, model: Policy):
critic.optimizer.step()
self.record_loss(critic_loss, key=f'critic_{index}')

def compile(self):
if not self.is_configured:
return

for critic in self.critics:
critic.neural_network.compile()

@property
def critics(self):
return self.configuration.critics
3 changes: 3 additions & 0 deletions diploma_thesis/agents/utils/rl/rl.py
@@ -90,6 +90,9 @@ def store(self, sample: TrainingSample, model: Policy):

self.__train__(model)

def compile(self):
pass

def clear(self):
self.loss_cache = []
self.storage.clear()
1 change: 0 additions & 1 deletion diploma_thesis/cli.py
@@ -2,7 +2,6 @@
import argparse
from typing import Dict

import torch._dynamo
import yaml

from workflow import Workflow, Simulation, Tournament, MultiSimulation
(Remaining changed files not shown.)
