replay_buffer.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from collections import deque

import numpy as np

class EpisodeExperience(object):
    """Stores one episode of transitions, up to a fixed maximum length."""

    def __init__(self, episode_len):
        self.max_len = episode_len
        self.episode_state = []
        self.episode_actions = []
        self.episode_reward = []
        self.episode_terminated = []
        self.episode_obs = []
        self.episode_available_actions = []
        self.episode_filled = []

    @property
    def count(self):
        """Number of timesteps currently stored in the episode."""
        return len(self.episode_state)

    def add(self, state, actions, reward, terminated, obs, available_actions,
            filled):
        """Append a single timestep; asserts once max_len steps are stored."""
        assert self.count < self.max_len
        self.episode_state.append(state)
        self.episode_actions.append(actions)
        self.episode_reward.append(reward)
        self.episode_terminated.append(terminated)
        self.episode_obs.append(obs)
        self.episode_available_actions.append(available_actions)
        self.episode_filled.append(filled)

    def get_data(self):
        """Return the episode as numpy arrays; the episode must be full."""
        assert self.count == self.max_len
        return np.array(self.episode_state), np.array(self.episode_actions),\
            np.array(self.episode_reward), np.array(self.episode_terminated),\
            np.array(self.episode_obs),\
            np.array(self.episode_available_actions),\
            np.array(self.episode_filled)

class EpisodeReplayBuffer(object):
    """A FIFO buffer of complete episodes; the oldest are evicted first."""

    def __init__(self, max_buffer_size):
        self.max_buffer_size = max_buffer_size
        self.buffer = deque(maxlen=max_buffer_size)

    def add(self, episode_experience):
        self.buffer.append(episode_experience)

    @property
    def count(self):
        """Number of episodes currently stored in the buffer."""
        return len(self.buffer)

    def sample_batch(self, batch_size):
        """Sample up to batch_size episodes uniformly without replacement.

        If fewer than batch_size episodes are stored, all of them are
        returned.
        """
        batch = random.sample(self.buffer, min(self.count, batch_size))

        s_batch, a_batch, r_batch, t_batch, obs_batch,\
            available_actions_batch, filled_batch = [], [], [], [], [], [], []
        for episode in batch:
            s, a, r, t, obs, available_actions, filled = episode.get_data()
            s_batch.append(s)
            a_batch.append(a)
            r_batch.append(r)
            t_batch.append(t)
            obs_batch.append(obs)
            available_actions_batch.append(available_actions)
            filled_batch.append(filled)

        # Stack the per-episode arrays along a leading batch dimension; the
        # states are stacked too, for consistency with the other returned
        # arrays. Integer fields use 'int64' rather than the original 'long',
        # whose size is platform-dependent in numpy (C long is 32-bit on
        # Windows).
        s_batch = np.array(s_batch)
        filled_batch = np.array(filled_batch)
        r_batch = np.array(r_batch)
        t_batch = np.array(t_batch)
        a_batch = np.array(a_batch).astype('int64')
        obs_batch = np.array(obs_batch)
        available_actions_batch = np.array(available_actions_batch).astype(
            'int64')
        return s_batch, a_batch, r_batch, t_batch, obs_batch,\
            available_actions_batch, filled_batch
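
# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal example of how the two classes fit together, under assumed toy
# shapes: 2 agents, a 4-dim global state, 3-dim per-agent observations, and
# 5 actions per agent. All names and shapes below are hypothetical; real
# training code would fill episodes from environment rollouts instead.
if __name__ == '__main__':
    episode_len, n_agents = 8, 2
    buffer = EpisodeReplayBuffer(max_buffer_size=100)

    for _ in range(10):
        episode = EpisodeExperience(episode_len)
        for step in range(episode_len):
            episode.add(
                state=np.zeros(4, dtype='float32'),
                actions=np.zeros(n_agents, dtype='int64'),
                reward=0.0,
                terminated=(step == episode_len - 1),
                obs=np.zeros((n_agents, 3), dtype='float32'),
                available_actions=np.ones((n_agents, 5), dtype='int64'),
                filled=1)
        buffer.add(episode)

    s, a, r, t, obs, avail, filled = buffer.sample_batch(batch_size=4)
    # Each returned array is stacked along a leading episode dimension,
    # e.g. s.shape == (4, episode_len, 4) and a.shape == (4, episode_len, 2).
    print(s.shape, a.shape, obs.shape, filled.shape)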