diff --git a/diploma_thesis/agents/utils/policy/action_policy.py b/diploma_thesis/agents/utils/policy/action_policy.py
index 6e66d40..b4db0ce 100644
--- a/diploma_thesis/agents/utils/policy/action_policy.py
+++ b/diploma_thesis/agents/utils/policy/action_policy.py
@@ -1,7 +1,8 @@
-from enum import StrEnum
+import logging
+import torch
 from typing import Dict
 
-from agents.utils import NeuralNetwork, Phase
+from agents.utils import NeuralNetwork
 from agents.utils.action import ActionSelector, from_cli as action_selector_from_cli
 
 from .policy import *
@@ -32,6 +33,15 @@ def __post_init__(self):
         if self.noise_parameters is not None:
             self.model.to_noisy(self.noise_parameters)
 
+    def with_logger(self, logger: logging.Logger):
+        super().with_logger(logger)
+
+        for module in [self.model, self.action_selector]:
+            if isinstance(module, Loggable):
+                module.with_logger(logger)
+
+        return self
+
     def configure(self, configuration: RunConfiguration):
         self.run_configuration = configuration
 
@@ -44,7 +54,7 @@ def configure(self, configuration: RunConfiguration):
         self.model = self.model.to(configuration.device)
 
     def update(self, phase: Phase):
-        self.phase = phase
+        super().update(phase)
 
         for module in [self.model, self.action_selector]:
             if isinstance(module, PhaseUpdatable):
diff --git a/diploma_thesis/agents/utils/policy/policy.py b/diploma_thesis/agents/utils/policy/policy.py
index 38d157c..f3ac8dc 100644
--- a/diploma_thesis/agents/utils/policy/policy.py
+++ b/diploma_thesis/agents/utils/policy/policy.py
@@ -5,16 +5,15 @@
 from enum import StrEnum
 from typing import TypeVar, Generic
 
-import torch
-
 from tensordict import TensorDict
 from tensordict.prototype import tensorclass
 from torch import nn
 
 from agents.base.state import State
-from agents.utils import PhaseUpdatable
+from agents.utils import PhaseUpdatable, Phase, EvaluationPhase
 from agents.utils.nn.layers.linear import Linear
 from agents.utils.run_configuration import RunConfiguration
+from utils import Loggable
 
 Action = TypeVar('Action')
 Rule = TypeVar('Rule')
@@ -37,13 +36,22 @@ class Keys(StrEnum):
     POLICY = 'policy'
 
 
-class Policy(Generic[Input], nn.Module, PhaseUpdatable, metaclass=ABCMeta):
+class Policy(nn.Module, Loggable, Generic[Input], PhaseUpdatable, metaclass=ABCMeta):
 
     def __init__(self):
         super().__init__()
 
         self.noise_parameters = None
 
+    def update(self, phase: Phase):
+        super().update(phase)
+
+        if isinstance(phase, EvaluationPhase):
+            self.eval()
+        else:
+            self.train()
+
+
     @abstractmethod
     def forward(self, state: State):
         """
diff --git a/diploma_thesis/agents/workcenter/model/deep_multi_rule.py b/diploma_thesis/agents/workcenter/model/deep_multi_rule.py
index 77ac94e..bd07e5b 100644
--- a/diploma_thesis/agents/workcenter/model/deep_multi_rule.py
+++ b/diploma_thesis/agents/workcenter/model/deep_multi_rule.py
@@ -8,7 +8,7 @@
 from .rule import ALL_ROUTING_RULES, RoutingRule, IdleRoutingRule
 
 
-class DeepMultiRule(DeepPolicyWorkCenterModel, DiscreteAction):
+class DeepMultiRule(DeepPolicyWorkCenterModel):
 
     def __init__(self, rules: List[RoutingRule], policy: Policy[WorkCenterInput]):
         super().__init__(policy)