Correct return estimators
yura-hb committed Mar 3, 2024
1 parent 4e45123 commit 0188553
Showing 5 changed files with 148 additions and 58 deletions.
16 changes: 15 additions & 1 deletion diploma_thesis/agents/utils/return_estimator/__init__.py
@@ -1,3 +1,17 @@

from .return_estimator import Return as ReturnEstimator
from .estimator import Estimator as ReturnEstimator

from .no import No
from .gae import GAE
from .n_step import NStep

from utils import from_cli as _from_cli
from functools import partial

key_to_class = {
'no': No,
'gae': GAE,
'n_step': NStep
}

from_cli = partial(_from_cli, key_to_class=key_to_class)
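
The package now exposes the three estimators through the repository's usual key_to_class / from_cli registry. A hedged usage sketch follows; the configuration schema ('kind' / 'parameters') is an assumption, since utils.from_cli itself is not part of this diff.

    # Hypothetical usage sketch. The exact schema expected by utils.from_cli is
    # not shown here; the assumption is that it dispatches on a key from
    # key_to_class and forwards the remaining parameters to that class's from_cli.
    from agents.utils.return_estimator import from_cli

    estimator = from_cli({
        'kind': 'n_step',
        'parameters': {'discount_factor': 0.99, 'n': 3},
    })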
64 changes: 7 additions & 57 deletions diploma_thesis/agents/utils/return_estimator/estimator.py
@@ -1,67 +1,17 @@


from dataclasses import dataclass
from abc import ABCMeta, abstractmethod
from typing import List
from agents.utils.memory import Record

from agents.utils.memory import Record

class Return:

@dataclass
class Configuration:
discount_factor: float
lambda_factor: float
n: int
vtrace_clip: float
trace_lambda: float
class Estimator(metaclass=ABCMeta):

def __init__(self, configuration: Configuration):
@property
@abstractmethod
def discount_factor(self) -> float:
pass

@abstractmethod
def update_returns(self, records: List[Record]) -> List[Record]:

pass


# def recursive(trajectory, V, done, args, compute_target_policy):
# require_policy = args.off_policy
# end = trajectory[-1]
#
# for _ in (range(args.n) if done else [0]):
# transition = trajectory[0]
#
# policy = None
#
# if require_policy:
# policy = compute_target_policy(V)
#
# G = 0
#
# if not done:
# G = V[end.next_state]
#
# for index in reversed(range(args.n)):
# tmp = trajectory[index]
#
# if tmp.terminal:
# continue
#
# weight = 1
#
# if require_policy:
# weight = policy[tmp.state, tmp.action] / tmp.action_prob
#
# if args.vtrace_clip is not None:
# weight = np.clip(weight, -args.vtrace_clip, args.vtrace_clip)
#
# G = weight * (tmp.reward + args.gamma * G) + (1 - weight) * V[tmp.state]
#
# if args.trace_lambda is not None:
# G_1 = tmp.reward + (1 - tmp.done) * args.gamma * V[tmp.next_state]
#
# G = (1 - args.trace_lambda) * G_1 + args.trace_lambda * G
#
# if done:
# trajectory.append(Transition(0.0, 0.0, 0.0, 0.0, 0.0, False, 0.0, True))
#
# V[transition.state] += args.alpha * (G - V[transition.state])
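
Read without the deleted lines, the new estimator.py reduces to a small abstract interface. A sketch of how the resulting file appears after this commit (the comments are added here for description; the diff shows bare pass bodies):

    from abc import ABCMeta, abstractmethod
    from typing import List

    from agents.utils.memory import Record


    class Estimator(metaclass=ABCMeta):

        @property
        @abstractmethod
        def discount_factor(self) -> float:
            # Per-step discount applied by the concrete estimator.
            pass

        @abstractmethod
        def update_returns(self, records: List[Record]) -> List[Record]:
            # Annotate a trajectory of records with return / advantage estimates.
            pass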
27 changes: 27 additions & 0 deletions diploma_thesis/agents/utils/return_estimator/gae.py
@@ -0,0 +1,27 @@

from typing import Dict
from .estimator import *


class GAE(Estimator):

def __init__(self, discount_factor: float, lambda_: float):
super().__init__()
self._discount_factor = discount_factor
self._lambda = lambda_

@property
def discount_factor(self) -> float:
return self._discount_factor

def update_returns(self, records: List[Record]) -> List[Record]:
for i in range(len(records) - 1):
records[i].info['advantage'] = ...
records[i].info['return'] = ...

return []


@staticmethod
def from_cli(parameters: Dict):
return GAE(parameters['discount_factor'], parameters['lambda'])
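
The committed update_returns leaves the advantage and return expressions as ... placeholders and returns an empty list rather than the records. A hedged sketch of the standard GAE backward recursion (Schulman et al., 2016) that such a method typically fills in follows; the reward attribute and the 'value' / 'done' entries on Record.info are assumptions, not fields confirmed by this diff.

    def update_returns(self, records: List[Record]) -> List[Record]:
        # Generalized Advantage Estimation:
        #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * A_{t+1}, computed backwards in time.
        advantage = 0.0

        for i in reversed(range(len(records))):
            value = records[i].info['value']
            next_value = records[i + 1].info['value'] if i + 1 < len(records) else 0.0
            not_done = 0.0 if records[i].info.get('done', False) else 1.0

            delta = records[i].reward + not_done * self._discount_factor * next_value - value
            advantage = delta + not_done * self._discount_factor * self._lambda * advantage

            records[i].info['advantage'] = advantage
            records[i].info['return'] = advantage + value

        return records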
78 changes: 78 additions & 0 deletions diploma_thesis/agents/utils/return_estimator/n_step.py
@@ -0,0 +1,78 @@

from typing import Dict
from dataclasses import dataclass
from .estimator import *


class NStep(Estimator):

@dataclass
class Configuration:
discount_factor: float
lambda_factor: float
n: int
trace_lambda: float | None

@staticmethod
def from_cli(parameters: Dict):
return NStep.Configuration(
discount_factor=parameters.get('discount_factor', 0.99),
lambda_factor=parameters.get('lambda_factor', 0.95),
n=parameters.get('n', 1),
# vtrace_clip=parameters.get('vtrace_clip', None),
trace_lambda=parameters.get('trace_lambda', None)
)

def __init__(self, configuration: Configuration):
super().__init__()

self.configuration = configuration

@property
def discount_factor(self) -> float:
return self.configuration.discount_factor ** self.configuration.n

def update_returns(self, records: List[Record]) -> List[Record]:
pass

# def recursive(trajectory, V, done, args, compute_target_policy):
# require_policy = args.off_policy
# end = trajectory[-1]
#
# for _ in (range(args.n) if done else [0]):
# transition = trajectory[0]
#
# policy = None
#
# if require_policy:
# policy = compute_target_policy(V)
#
# G = 0
#
# if not done:
# G = V[end.next_state]
#
# for index in reversed(range(args.n)):
# tmp = trajectory[index]
#
# if tmp.terminal:
# continue
#
# weight = 1
#
# if require_policy:
# weight = policy[tmp.state, tmp.action] / tmp.action_prob
#
# if args.vtrace_clip is not None:
# weight = np.clip(weight, -args.vtrace_clip, args.vtrace_clip)
#
# G = weight * (tmp.reward + args.gamma * G) + (1 - weight) * V[tmp.state]
#
# if args.trace_lambda is not None:
# G_1 = tmp.reward + (1 - tmp.done) * args.gamma * V[tmp.next_state]
#
# G = (1 - args.trace_lambda) * G_1 + args.trace_lambda * G
#
# if done:
# trajectory.append(Transition(0.0, 0.0, 0.0, 0.0, 0.0, False, 0.0, True))
#
# V[transition.state] += args.alpha * (G - V[transition.state])
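
Here update_returns is committed as a stub, with the earlier recursive pseudocode carried over as a comment (including the vtrace_clip importance-weight clipping that the configuration above leaves commented out). A hedged sketch of the plain n-step return, ignoring the off-policy weighting and trace mixing from that block, under the same assumed Record fields as in the GAE sketch:

    def update_returns(self, records: List[Record]) -> List[Record]:
        # G_t = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * V(s_{t+n})
        gamma = self.configuration.discount_factor
        n = self.configuration.n

        for t in range(len(records)):
            g = 0.0

            for k in range(n):
                if t + k >= len(records):
                    break

                g += (gamma ** k) * records[t + k].reward

            # Bootstrap from the value estimate n steps ahead when it exists.
            if t + n < len(records):
                g += (gamma ** n) * records[t + n].info['value']

            records[t].info['return'] = g

        return records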
21 changes: 21 additions & 0 deletions diploma_thesis/agents/utils/return_estimator/no.py
@@ -0,0 +1,21 @@

from typing import Dict
from .estimator import *


class No(Estimator):

def __init__(self, discount_factor: float):
super().__init__()
self._discount_factor = discount_factor

@property
def discount_factor(self) -> float:
return self._discount_factor

def update_returns(self, records: List[Record]) -> List[Record]:
return records

@staticmethod
def from_cli(parameters: Dict):
return No(parameters['discount_factor'])
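
No is the identity estimator: records pass through unchanged and only the configured discount factor is exposed, which is all a one-step TD target needs. A short usage sketch (the records variable stands in for any list of Record instances):

    estimator = No.from_cli({'discount_factor': 0.99})

    gamma = estimator.discount_factor             # 0.99
    records = estimator.update_returns(records)   # returned as-is, no annotation added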
