utils_memory.py
import math
from typing import (
    Tuple,
)

import numpy as np
import torch

from utils_types import (
    BatchAction,
    BatchDone,
    BatchNext,
    BatchReward,
    BatchState,
    BatchWeight,
    TensorStack5,
    TorchDevice,
)


class treearray:
    # Binary indexed (Fenwick) tree over transition priorities.  Tree node i
    # stores the sum of the leaf values in the range (i - lowbit(i), i], so
    # prefix sums, point updates and proportional sampling all cost O(log n).
    def __init__(self, capacity):
        self.__capacity = capacity
        self.__bits = math.ceil(math.log(capacity, 2))
        self.__max_len = 1 << self.__bits
        self.__array = torch.zeros((self.__max_len + 1, 1), dtype=torch.float)

    def add(self, loc, val):
        # Point addition: add val to leaf loc (1-based) and to every tree
        # node that covers it.
        while loc <= self.__max_len:
            self.__array[loc] += val
            loc += loc & (-loc)

    def get_array(self):
        return self.__array

    def get_prefix_sum(self, loc):
        # Sum of the first loc leaf values.
        val = 0.0
        while loc != 0:
            val += float(self.__array[loc])
            loc -= loc & (-loc)
        return val

    def change(self, loc, val):
        # Point assignment: query the current leaf value first, then add the
        # difference so that all covering nodes stay consistent.
        nowval = self.get_prefix_sum(loc) - self.get_prefix_sum(loc - 1)
        self.add(loc, val - nowval)

    def get_minval(self, size):
        # Approximate minimum priority among the first size leaves, read
        # directly from the tree nodes (odd-indexed nodes hold raw leaf
        # values, even-indexed nodes hold partial sums).
        return float(self.__array[1:size + 1].min())

    def search(self):
        # Sample one leaf with probability proportional to its value by
        # descending from the root node, which holds the total mass.
        sub_val = 1 << (self.__bits - 1)
        right = self.__max_len
        right_val = float(self.__array[right])
        while sub_val != 0:
            left = right - sub_val
            left_val = float(self.__array[left])
            if np.random.rand() < left_val / right_val:
                # Descend into the range covered by the left node.
                right = left
                right_val = left_val
            else:
                # Stay on the right half; drop the mass of the left range.
                right_val -= left_val
            sub_val //= 2
        # right is the 1-based leaf index; return the 0-based buffer index
        # and the sampled priority.
        return right - 1, right_val


class ReplayMemory(object):
    def __init__(
            self,
            channels: int,
            capacity: int,
            device: TorchDevice,
            EPSILON: float = 0.01,
            ALPHA: float = 0.6,
            BETA: float = 0.4,
            BETA_INC: float = 0.001,
    ) -> None:
        self.__device = device
        self.__capacity = capacity
        self.__size = 0
        self.__pos = 0
        # Prioritized-replay hyperparameters:
        #   EPSILON  - small constant added to the agent-supplied priority so
        #              that no transition ever has zero probability
        #   ALPHA    - exponent controlling how strongly priorities skew sampling
        #   BETA     - importance-sampling exponent, annealed towards 1
        #   BETA_INC - amount BETA grows on every call to sample()
        self.__ALPHA = ALPHA
        self.__BETA = BETA
        self.__BETA_INC = BETA_INC
        self.__EPSILON = EPSILON
        self.__m_states = torch.zeros(
            (capacity, channels, 84, 84), dtype=torch.uint8)
        self.__m_actions = torch.zeros((capacity, 1), dtype=torch.long)
        self.__m_rewards = torch.zeros((capacity, 1), dtype=torch.int8)
        self.__m_dones = torch.zeros((capacity, 1), dtype=torch.bool)
        self.__treearray = treearray(capacity)

    def push(
            self,
            folded_state: TensorStack5,
            action: int,
            reward: int,
            done: bool,
            agent,
    ) -> None:
        # Store the transition at the current write position.
        self.__m_states[self.__pos] = folded_state
        self.__m_actions[self.__pos, 0] = action
        self.__m_rewards[self.__pos, 0] = reward
        self.__m_dones[self.__pos, 0] = done

        # Compute the initial priority of the new transition from the agent's
        # priority signal (typically its TD error).
        indices = [self.__pos]
        b_state = self.__m_states[indices, :4].to(self.__device).float()
        b_next = self.__m_states[indices, 1:].to(self.__device).float()
        b_action = self.__m_actions[indices].to(self.__device)
        b_reward = self.__m_rewards[indices].to(self.__device).float()
        b_done = self.__m_dones[indices].to(self.__device).float()
        pri_val = agent.get_pri_val(b_state, b_action, b_reward, b_done, b_next)
        pri_val = float(pri_val.cpu().reshape(1))
        pri_val = math.pow(abs(pri_val) + self.__EPSILON, self.__ALPHA)
        # Tree positions are 1-based.
        self.__treearray.change(self.__pos + 1, pri_val)

        self.__pos = (self.__pos + 1) % self.__capacity
        # The buffer size grows until it reaches capacity, then stays there.
        self.__size = min(self.__size + 1, self.__capacity)

    def sample(self, batch_size: int) -> Tuple[
            BatchState,
            BatchAction,
            BatchReward,
            BatchNext,
            BatchWeight,
            BatchDone,
    ]:
        # Anneal BETA towards 1 so the importance-sampling correction becomes
        # unbiased as training progresses.
        self.__BETA = min(1.0, self.__BETA + self.__BETA_INC)

        # Draw indices with probability proportional to their priorities.
        samples = [self.__treearray.search() for _ in range(batch_size)]
        indices = torch.tensor([idx for idx, _ in samples], dtype=torch.long)
        pri_vals = torch.tensor([pri for _, pri in samples], dtype=torch.float)

        b_state = self.__m_states[indices, :4].to(self.__device).float()
        b_next = self.__m_states[indices, 1:].to(self.__device).float()
        b_action = self.__m_actions[indices].to(self.__device)
        b_reward = self.__m_rewards[indices].to(self.__device).float()
        b_done = self.__m_dones[indices].to(self.__device).float()

        # Importance-sampling weights: transitions with higher priority get
        # smaller weights; the smallest stored priority is used for scaling.
        min_pri = max(self.__treearray.get_minval(self.__size), 0.0001)
        b_weight = (pri_vals / min_pri).pow(-self.__BETA)

        return b_state, b_action, b_reward, b_next, b_weight, b_done

def __len__(self) -> int:
return self.__size
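

# A minimal, hypothetical sanity check for the treearray sampler (not part of
# the original training pipeline): it stores a few known priorities via
# change() and confirms that search() draws indices roughly in proportion to
# them. Run this module directly to try it.
if __name__ == "__main__":
    tree = treearray(4)
    priorities = [1.0, 2.0, 3.0, 4.0]
    for i, p in enumerate(priorities):
        tree.change(i + 1, p)  # tree positions are 1-based
    counts = np.zeros(len(priorities))
    for _ in range(10000):
        idx, _pri = tree.search()  # idx is the 0-based buffer index
        counts[idx] += 1
    # Empirical frequencies should be close to [0.1, 0.2, 0.3, 0.4].
    print(counts / counts.sum())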