generated from rochacbruno/python-project-template
-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
440 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
|
||
from .machine import * | ||
from typing import Dict | ||
|
||
|
||
class DeepQAgent(Machine):
    """Machine agent that reports itself as trainable; training is still a stub."""

    def is_trainable(self):
        """Always report this agent as trainable."""
        return True

    def train_step(self):
        """Do nothing: no training logic has been written for this agent yet."""
        return None

    @staticmethod
    def from_cli(parameters: Dict):
        """Assemble a DeepQAgent from a configuration mapping.

        Args:
            parameters: Mapping with a 'model' entry and an 'encoder' entry;
                each is handed to the matching ``*_from_cli`` factory.

        Returns:
            A DeepQAgent created with the built model and state encoder and
            with ``memory=None``.
        """
        return DeepQAgent(
            model=model_from_cli(parameters['model']),
            state_encoder=state_encoder_from_cli(parameters['encoder']),
            memory=None,
        )
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
|
||
from typing import List, Dict | ||
|
||
import torch | ||
|
||
from agents.utils import NNCLI, Phase, PhaseUpdatable | ||
from agents.utils.action import ActionSelector, from_cli as action_selector_from_cli | ||
from .model import * | ||
from .rule import ALL_SCHEDULING_RULES | ||
from .rule import SchedulingRule | ||
|
||
|
||
class MultiRuleLinear(Model, PhaseUpdatable):
    """Model that picks one scheduling rule per call.

    The wrapped network scores the available rules, the action selector turns
    those scores into an index, and the rule at that index produces the result.
    """

    def __init__(self, rules: List[SchedulingRule], model: NNCLI, action_selector: ActionSelector):
        """Store the candidate rules, the scoring network and the selector."""
        super().__init__()
        self.rules = rules
        self.model = model
        self.action_selector = action_selector

    def update(self, phase: Phase):
        """Remember the phase and forward it to every phase-aware component."""
        self.phase = phase

        components = (self.model, self.action_selector)
        phase_aware = (item for item in components if isinstance(item, PhaseUpdatable))

        for item in phase_aware:
            item.update(phase)

    def __call__(self, state: State, parameters: MachineModel.Input) -> MachineModel.Record:
        """Score the rules for ``state``, select one and apply it.

        Returns:
            A record holding the selected rule's result, the input state and
            the index of the selected rule.
        """
        # The network is connected on first use, once the state shape is known.
        if not self.model.is_connected:
            self.__connect__(state.shape)

        chosen, _ = self.action_selector(self.model(state, parameters))
        selected_rule = self.rules[chosen]
        outcome = selected_rule(parameters.machine, parameters.now)

        return MachineModel.Record(result=outcome, state=state, action=chosen)

    def __connect__(self, input_shape: torch.Size):
        """Connect the network with a linear output layer of one unit per rule."""
        head = NNCLI.Configuration.Linear(dim=len(self.rules), activation='none', dropout=0)

        self.model.connect(input_shape, head)

    @staticmethod
    def from_cli(parameters: Dict):
        """Build a MultiRuleLinear from a configuration mapping.

        Args:
            parameters: Mapping with 'rules' (either the string "all" or an
                iterable of keys of ALL_SCHEDULING_RULES), 'nn' and
                'action_selector' entries.
        """
        requested = parameters['rules']

        if requested == "all":
            rule_classes = ALL_SCHEDULING_RULES.values()
        else:
            rule_classes = (ALL_SCHEDULING_RULES[name] for name in requested)

        rules = [rule_cls() for rule_cls in rule_classes]
        network = NNCLI.Configuration.from_cli(parameters['nn'])
        selector = action_selector_from_cli(parameters['action_selector'])

        return MultiRuleLinear(rules, network, selector)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
import logging | ||
from abc import ABCMeta | ||
from typing import List, TypeVar | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
|
||
from .phase import WarmUpPhase, TrainingPhase, EvaluationPhase, Phase | ||
from .loggable import Loggable | ||
from .loggable import Loggable | ||
from .phase_updatable import PhaseUpdatable | ||
from .nn import NNCLI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .action_selector import ActionSelector | ||
from .epsilon_greedy import EpsilonGreedy | ||
from .upper_confidence_bound import UpperConfidenceBound | ||
from .sample import Sample | ||
from .uniform import Uniform | ||
from .phase_selector import PhaseSelector | ||
from .cli import from_cli |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
from abc import ABCMeta, abstractmethod | ||
from typing import Dict, Tuple | ||
|
||
import torch | ||
|
||
|
||
class ActionSelector(metaclass=ABCMeta):
    """Abstract callable that picks one action from a tensor of action scores."""

    def __init__(self):
        """Nothing to set up at the base level."""
        return None

    @abstractmethod
    def __call__(self, distribution: torch.FloatTensor) -> Tuple[int, torch.FloatTensor]:
        """Choose an action.

        Args:
            distribution: Distribution over possible actions, given either as
                q-values or as probabilities.

        Returns:
            The index of the selected action together with the probability of
            that action.
        """
        ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
|
||
from .action_selector import ActionSelector | ||
from .epsilon_greedy import EpsilonGreedy | ||
from .upper_confidence_bound import UpperConfidenceBound | ||
from .sample import Sample | ||
from .uniform import Uniform | ||
from .phase_selector import PhaseSelector | ||
|
||
|
||
# Registry mapping the 'kind' value of a configuration to the selector class.
# NOTE(review): the Greedy selector that appears elsewhere in this change set
# has no entry here — confirm whether that omission is intentional.
key_to_cls = {
    'epsilon_greedy': EpsilonGreedy,
    'upper_confidence_bound': UpperConfidenceBound,
    'sample': Sample,
    'uniform': Uniform,
    'phase_selector': PhaseSelector,
}


def from_cli(parameters) -> ActionSelector:
    """Build the action selector described by a configuration mapping.

    Args:
        parameters: Mapping with a required 'kind' entry (one of the keys of
            ``key_to_cls``) and an optional 'parameters' entry that is passed
            to the selected class's ``from_cli``; defaults to ``{}``.

    Returns:
        The selector created by the selected class's ``from_cli``.

    Raises:
        KeyError: If 'kind' is missing, or names a selector that is not
            registered (the message then lists the valid kinds).
    """
    kind = parameters['kind']

    try:
        cls = key_to_cls[kind]
    except KeyError:
        known = ', '.join(sorted(key_to_cls))
        # Still a KeyError for existing callers, but with an actionable message.
        raise KeyError(
            f"Unknown action selector kind '{kind}'; expected one of: {known}"
        ) from None

    return cls.from_cli(parameters.get('parameters', {}))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
from .action_selector import * | ||
|
||
|
||
class EpsilonGreedy(ActionSelector):
    """Epsilon-greedy selection.

    With probability ``epsilon`` an action is drawn uniformly at random,
    otherwise the highest-scoring action is taken.
    """

    def __init__(self, epsilon: float):
        """
        Args:
            epsilon: Probability of taking a uniformly random action.
        """
        super().__init__()
        self.epsilon = epsilon

    def __call__(self, distribution: torch.FloatTensor) -> Tuple[int, torch.FloatTensor]:
        """Select an action and report its probability under the policy.

        Args:
            distribution: One-dimensional tensor of action scores.

        Returns:
            The selected action index and, as a tensor, the probability that
            the epsilon-greedy policy assigns to that action.
        """
        n_actions = distribution.size(0)
        greedy_action = torch.argmax(distribution).item()

        if torch.rand(1) < self.epsilon:
            action = torch.randint(0, n_actions, (1,)).item()
        else:
            action = greedy_action

        # Every action can be reached through exploration (epsilon / n); the
        # greedy action additionally receives the exploitation mass. This also
        # covers the case where exploration happens to draw the greedy action.
        probability = self.epsilon / n_actions

        if action == greedy_action:
            probability += 1 - self.epsilon

        return action, torch.tensor(probability)

    @staticmethod
    def from_cli(parameters: Dict):
        """Build an EpsilonGreedy selector from a mapping with an 'epsilon' entry."""
        return EpsilonGreedy(parameters['epsilon'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .action_selector import * | ||
|
||
|
||
class Greedy(ActionSelector):
    """Selector that always takes the highest-scoring action."""

    def __call__(self, distribution: torch.FloatTensor) -> Tuple[int, torch.FloatTensor]:
        """Return the argmax index of ``distribution`` and a probability of 1."""
        best = distribution.argmax()

        return best.item(), torch.tensor(1.0)

    @staticmethod
    def from_cli(parameters: Dict):
        """Build a Greedy selector; ``parameters`` is not consulted."""
        return Greedy()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
|
||
from .action_selector import *
from agents.utils import Phase, PhaseUpdatable


class PhaseSelector(ActionSelector, PhaseUpdatable):
    """Selector that delegates to a different selector depending on the phase.

    The class derives from PhaseUpdatable so that owners which forward phase
    changes only to ``PhaseUpdatable`` instances (via an ``isinstance`` check)
    actually reach this selector.
    NOTE(review): assumes PhaseUpdatable is a nominal base class compatible
    with ActionSelector's ABCMeta metaclass — confirm against its definition.
    """

    def __init__(self, default: ActionSelector, phase_to_action_selector: Dict[Phase, ActionSelector]):
        """
        Args:
            default: Selector used when the current phase has no dedicated entry.
            phase_to_action_selector: Selectors keyed by the phase they apply to.
        """
        super().__init__()
        self.phase = None
        self.default = default
        self.phase_to_action_selector = phase_to_action_selector

    def update(self, phase: Phase):
        """Record the new phase and pass it on to phase-aware nested selectors."""
        self.phase = phase

        # Nested selectors (e.g. another PhaseSelector) would otherwise never
        # learn about the phase change.
        for selector in (self.default, *self.phase_to_action_selector.values()):
            if isinstance(selector, PhaseUpdatable):
                selector.update(phase)

    def __call__(self, distribution: torch.FloatTensor) -> Tuple[int, torch.FloatTensor]:
        """Delegate to the selector registered for the current phase.

        Falls back to the default selector when the current phase (initially
        ``None``) has no dedicated entry.
        """
        selector = self.phase_to_action_selector.get(self.phase, self.default)

        return selector(distribution)

    @staticmethod
    def from_cli(parameters: Dict):
        """Build a PhaseSelector from a configuration mapping.

        Args:
            parameters: Mapping with a 'default' selector configuration and a
                'phases' iterable whose items each carry a 'phase' and an
                'action_selector' configuration.
        """
        # Imported here rather than at module level: the cli module itself
        # imports PhaseSelector.
        from .cli import from_cli
        from agents.utils.phase import from_cli as phase_from_cli

        default = from_cli(parameters['default'])

        phase_to_action_selector = {
            phase_from_cli(info['phase']): from_cli(info['action_selector'])
            for info in parameters['phases']
        }

        return PhaseSelector(default, phase_to_action_selector)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import torch.distributions | ||
|
||
from .action_selector import * | ||
|
||
|
||
class Sample(ActionSelector):
    """Selector that samples an action from a categorical distribution."""

    def __init__(self, is_distribution: bool = True):
        """
        Args:
            is_distribution: When true the input is treated as probabilities,
                otherwise as logits.
        """
        super().__init__()
        self.is_distribution = is_distribution

    def __call__(self, distribution: torch.FloatTensor) -> Tuple[int, torch.FloatTensor]:
        """Sample an action.

        Args:
            distribution: One-dimensional tensor of probabilities or logits,
                depending on ``is_distribution``.

        Returns:
            The sampled action index as a plain int and the probability of
            that action under the categorical distribution.
        """
        if self.is_distribution:
            categorical = torch.distributions.Categorical(probs=distribution)
        else:
            categorical = torch.distributions.Categorical(logits=distribution)

        # Convert to a Python int so the result matches the declared return
        # type and the plain ints produced by `.item()` in sibling selectors.
        action = categorical.sample().item()

        return action, categorical.probs[action]

    @staticmethod
    def from_cli(parameters: Dict):
        """Build a Sample selector from a configuration mapping.

        A missing 'is_distribution' entry falls back to the constructor's
        default (True), so an empty mapping is accepted.
        """
        return Sample(is_distribution=parameters.get('is_distribution', True))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
# TODO: - Implement |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from .action_selector import * | ||
|
||
|
||
class Uniform(ActionSelector):
    """Selector that draws an action uniformly at random, ignoring the scores."""

    def __call__(self, distribution: torch.FloatTensor) -> Tuple[int, torch.FloatTensor]:
        """Draw a uniformly random action.

        Args:
            distribution: Tensor whose first dimension gives the number of
                actions; its values are not used.

        Returns:
            The drawn action index and, as a tensor, the probability
            ``1 / number_of_actions``.
        """
        n_actions = distribution.size(0)
        action = torch.randint(n_actions, (1,)).item()

        # Wrapped in a tensor to match the declared return type.
        return action, torch.tensor(1.0 / n_actions)

    @staticmethod
    def from_cli(parameters: Dict):
        """Build a Uniform selector; ``parameters`` is not consulted."""
        return Uniform()
27 changes: 27 additions & 0 deletions
27
diploma_thesis/agents/utils/action/upper_confidence_bound.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from .action_selector import *


class UpperConfidenceBound(ActionSelector):
    """Selector that adds an upper-confidence exploration bonus to the scores.

    Selection counts are kept on the instance and accumulate across calls.
    """

    def __init__(self, parameter: float):
        """
        Args:
            parameter: Weight of the exploration bonus.
        """
        super().__init__()
        self.parameter = parameter
        self.counts = None

    def __call__(self, distribution: torch.FloatTensor) -> Tuple[int, torch.FloatTensor]:
        """Select the action with the highest score plus exploration bonus.

        Args:
            distribution: One-dimensional tensor of action scores.

        Returns:
            The selected action index and a placeholder probability of 1.
        """
        # Counts are allocated lazily, once the number of actions is known.
        if self.counts is None:
            self.counts = torch.zeros_like(distribution)

        total = self.counts.sum()
        tried = self.counts > 0

        # Clamping keeps log() and the division finite; entries for untried
        # actions are overwritten just below, so their clamped value is unused.
        bonus = torch.sqrt(torch.log(total.clamp(min=1.0)) / self.counts.clamp(min=1.0))

        # Untried actions get an infinite score so each is selected at least once.
        ucb = torch.where(
            tried,
            distribution + self.parameter * bonus,
            torch.full_like(distribution, float('inf')),
        )

        action = torch.argmax(ucb).item()

        self.counts[action] += 1

        # TODO: Derive correct probability

        return action, torch.tensor(1.0)

    @staticmethod
    def from_cli(parameters: Dict):
        """Build the selector from a mapping with a 'parameter' entry."""
        return UpperConfidenceBound(parameter=parameters['parameter'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
from .nn_cli import NNCLI |
Oops, something went wrong.