Skip to content

Commit

Permalink
Add passthrough of logger from policy to logger. Set policy in eval()…
Browse files Browse the repository at this point in the history
… mode during evaluation phase
  • Loading branch information
yura-hb committed Mar 24, 2024
1 parent 3bac0e5 commit 265c91e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
16 changes: 13 additions & 3 deletions diploma_thesis/agents/utils/policy/action_policy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import StrEnum
import logging
import torch
from typing import Dict

from agents.utils import NeuralNetwork, Phase
from agents.utils import NeuralNetwork
from agents.utils.action import ActionSelector, from_cli as action_selector_from_cli
from .policy import *

Expand Down Expand Up @@ -32,6 +33,15 @@ def __post_init__(self):
if self.noise_parameters is not None:
self.model.to_noisy(self.noise_parameters)

def with_logger(self, logger: logging.Logger):
super().with_logger(logger)

for module in [self.model, self.action_selector]:
if isinstance(module, Loggable):
module.with_logger(logger)

return self

def configure(self, configuration: RunConfiguration):
self.run_configuration = configuration

Expand All @@ -44,7 +54,7 @@ def configure(self, configuration: RunConfiguration):
self.model = self.model.to(configuration.device)

def update(self, phase: Phase):
self.phase = phase
super().update(phase)

for module in [self.model, self.action_selector]:
if isinstance(module, PhaseUpdatable):
Expand Down
16 changes: 12 additions & 4 deletions diploma_thesis/agents/utils/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
from enum import StrEnum
from typing import TypeVar, Generic

import torch

from tensordict import TensorDict
from tensordict.prototype import tensorclass
from torch import nn

from agents.base.state import State
from agents.utils import PhaseUpdatable
from agents.utils import PhaseUpdatable, Phase, EvaluationPhase
from agents.utils.nn.layers.linear import Linear
from agents.utils.run_configuration import RunConfiguration
from utils import Loggable

Action = TypeVar('Action')
Rule = TypeVar('Rule')
Expand All @@ -37,13 +36,22 @@ class Keys(StrEnum):
POLICY = 'policy'


class Policy(Generic[Input], nn.Module, PhaseUpdatable, metaclass=ABCMeta):
class Policy(nn.Module, Loggable, Generic[Input], PhaseUpdatable, metaclass=ABCMeta):

def __init__(self):
super().__init__()

self.noise_parameters = None

def update(self, phase: Phase):
super().update(phase)

if isinstance(phase, EvaluationPhase):
self.eval()
else:
self.train()


@abstractmethod
def forward(self, state: State):
"""
Expand Down
2 changes: 1 addition & 1 deletion diploma_thesis/agents/workcenter/model/deep_multi_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .rule import ALL_ROUTING_RULES, RoutingRule, IdleRoutingRule


class DeepMultiRule(DeepPolicyWorkCenterModel, DiscreteAction):
class DeepMultiRule(DeepPolicyWorkCenterModel):

def __init__(self, rules: List[RoutingRule], policy: Policy[WorkCenterInput]):
super().__init__(policy)
Expand Down

0 comments on commit 265c91e

Please sign in to comment.