generated from rochacbruno/python-project-template
-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
148 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,17 @@ | ||
|
||
from .return_estimator import Return as ReturnEstimator | ||
from .estimator import Estimator as ReturnEstimator | ||
|
||
from .no import No | ||
from .gae import GAE | ||
from .n_step import NStep | ||
|
||
from utils import from_cli as _from_cli | ||
from functools import partial | ||
|
||
key_to_class = { | ||
'no': No, | ||
'gae': GAE, | ||
'n_step': NStep | ||
} | ||
|
||
from_cli = partial(_from_cli, key_to_class=key_to_class) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,17 @@ | ||
|
||
|
||
from dataclasses import dataclass | ||
from abc import ABCMeta, abstractmethod | ||
from typing import List | ||
from agents.utils.memory import Record | ||
|
||
from agents.utils.memory import Record | ||
|
||
class Return: | ||
|
||
@dataclass | ||
class Configuration: | ||
discount_factor: float | ||
lambda_factor: float | ||
n: int | ||
vtrace_clip: float | ||
trace_lambda: float | ||
class Estimator(metaclass=ABCMeta): | ||
|
||
def __init__(self, configuration: Configuration): | ||
@property | ||
@abstractmethod | ||
def discount_factor(self) -> float: | ||
pass | ||
|
||
@abstractmethod | ||
def update_returns(self, records: List[Record]) -> List[Record]: | ||
|
||
pass | ||
|
||
|
||
# def recursive(trajectory, V, done, args, compute_target_policy): | ||
# require_policy = args.off_policy | ||
# end = trajectory[-1] | ||
# | ||
# for _ in (range(args.n) if done else [0]): | ||
# transition = trajectory[0] | ||
# | ||
# policy = None | ||
# | ||
# if require_policy: | ||
# policy = compute_target_policy(V) | ||
# | ||
# G = 0 | ||
# | ||
# if not done: | ||
# G = V[end.next_state] | ||
# | ||
# for index in reversed(range(args.n)): | ||
# tmp = trajectory[index] | ||
# | ||
# if tmp.terminal: | ||
# continue | ||
# | ||
# weight = 1 | ||
# | ||
# if require_policy: | ||
# weight = policy[tmp.state, tmp.action] / tmp.action_prob | ||
# | ||
# if args.vtrace_clip is not None: | ||
# weight = np.clip(weight, -args.vtrace_clip, args.vtrace_clip) | ||
# | ||
# G = weight * (tmp.reward + args.gamma * G) + (1 - weight) * V[tmp.state] | ||
# | ||
# if args.trace_lambda is not None: | ||
# G_1 = tmp.reward + (1 - tmp.done) * args.gamma * V[tmp.next_state] | ||
# | ||
# G = (1 - args.trace_lambda) * G_1 + args.trace_lambda * G | ||
# | ||
# if done: | ||
# trajectory.append(Transition(0.0, 0.0, 0.0, 0.0, 0.0, False, 0.0, True)) | ||
# | ||
# V[transition.state] += args.alpha * (G - V[transition.state]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
|
||
from typing import Dict | ||
from .estimator import * | ||
|
||
|
||
class GAE(Estimator): | ||
|
||
def __init__(self, discount_factor: float, lambda_: float): | ||
super().__init__() | ||
self._discount_factor = discount_factor | ||
self._lambda = lambda_ | ||
|
||
@property | ||
def discount_factor(self) -> float: | ||
return self._discount_factor | ||
|
||
def update_returns(self, records: List[Record]) -> List[Record]: | ||
for i in range(len(records) - 1): | ||
records[i].info['advantage'] = ... | ||
records[i].info['return'] = ... | ||
|
||
return [] | ||
|
||
|
||
@staticmethod | ||
def from_cli(parameters: Dict): | ||
return GAE(parameters['discount_factor'], parameters['lambda']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
|
||
from typing import Dict | ||
from dataclasses import dataclass | ||
from .estimator import * | ||
|
||
|
||
class NStep(Estimator): | ||
|
||
@dataclass | ||
class Configuration: | ||
discount_factor: float | ||
lambda_factor: float | ||
n: int | ||
trace_lambda: float | None | ||
|
||
@staticmethod | ||
def from_cli(parameters: Dict): | ||
return NStep.Configuration( | ||
discount_factor=parameters.get('discount_factor', 0.99), | ||
lambda_factor=parameters.get('lambda_factor', 0.95), | ||
n=parameters.get('n', 1), | ||
# vtrace_clip=parameters.get('vtrace_clip', None), | ||
trace_lambda=parameters.get('trace_lambda', None) | ||
) | ||
|
||
def __init__(self, configuration: Configuration): | ||
super().__init__() | ||
|
||
self.configuration = configuration | ||
|
||
def discount_factor(self) -> float: | ||
return self.configuration.discount_factor ** self.configuration.n | ||
|
||
def update_returns(self, records: List[Record]) -> List[Record]: | ||
pass | ||
|
||
# def recursive(trajectory, V, done, args, compute_target_policy): | ||
# require_policy = args.off_policy | ||
# end = trajectory[-1] | ||
# | ||
# for _ in (range(args.n) if done else [0]): | ||
# transition = trajectory[0] | ||
# | ||
# policy = None | ||
# | ||
# if require_policy: | ||
# policy = compute_target_policy(V) | ||
# | ||
# G = 0 | ||
# | ||
# if not done: | ||
# G = V[end.next_state] | ||
# | ||
# for index in reversed(range(args.n)): | ||
# tmp = trajectory[index] | ||
# | ||
# if tmp.terminal: | ||
# continue | ||
# | ||
# weight = 1 | ||
# | ||
# if require_policy: | ||
# weight = policy[tmp.state, tmp.action] / tmp.action_prob | ||
# | ||
# if args.vtrace_clip is not None: | ||
# weight = np.clip(weight, -args.vtrace_clip, args.vtrace_clip) | ||
# | ||
# G = weight * (tmp.reward + args.gamma * G) + (1 - weight) * V[tmp.state] | ||
# | ||
# if args.trace_lambda is not None: | ||
# G_1 = tmp.reward + (1 - tmp.done) * args.gamma * V[tmp.next_state] | ||
# | ||
# G = (1 - args.trace_lambda) * G_1 + args.trace_lambda * G | ||
# | ||
# if done: | ||
# trajectory.append(Transition(0.0, 0.0, 0.0, 0.0, 0.0, False, 0.0, True)) | ||
# | ||
# V[transition.state] += args.alpha * (G - V[transition.state]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
from typing import Dict | ||
from .estimator import * | ||
|
||
|
||
class No(Estimator): | ||
|
||
def __init__(self, discount_factor: float): | ||
super().__init__() | ||
self._discount_factor = discount_factor | ||
|
||
@property | ||
def discount_factor(self) -> float: | ||
return self._discount_factor | ||
|
||
def update_returns(self, records: List[Record]) -> List[Record]: | ||
return records | ||
|
||
@staticmethod | ||
def from_cli(parameters: Dict): | ||
return No(parameters['discount_factor']) |