-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqmix_agent.py
126 lines (114 loc) · 5.57 KB
/
qmix_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import parl
import numpy as np
from utils import AvailableActionsSampler
import os
import time
class QMixAgent(parl.Agent):
    """Agent wrapper around a QMIX algorithm.

    Handles epsilon-greedy action selection, periodic target-network
    synchronisation, checkpoint save/restore, and conversion of numpy batches
    into paddle tensors before handing them to the algorithm.

    Args:
        algorithm: QMIX algorithm object. This class reads/uses its
            ``agent_model``, ``qmixer_model``, ``hidden_states``,
            ``_init_hidden_states``, ``predict_local_q``, ``sync_target`` and
            ``learn`` members.
        exploration_start (float): initial exploration probability.
        min_exploration (float): lower bound for the exploration probability.
        exploration_decay (float): amount subtracted from the exploration
            probability after every call to :meth:`sample`.
        update_target_interval (int): ``sync_target`` is called whenever the
            number of completed :meth:`learn` calls is a multiple of this.
    """

    def __init__(self, algorithm, exploration_start, min_exploration,
                 exploration_decay, update_target_interval):
        self.alg = algorithm
        # Number of learn() calls so far; drives target-network syncing.
        self.global_step = 0
        self.exploration = exploration_start
        self.min_exploration = min_exploration
        self.exploration_decay = exploration_decay
        # Number of times sync_target() has been invoked.
        self.target_update_count = 0
        self.update_target_interval = update_target_interval

    def save(self, save_dir, agent_model_name, qmixer_model_name):
        """Save agent-network and mixer-network parameters.

        Args:
            save_dir (str): directory to write into; created (including any
                missing parent directories) if it does not exist.
            agent_model_name (str): file name for the agent network params.
            qmixer_model_name (str): file name for the mixer network params.
        """
        # makedirs(exist_ok=True) handles nested paths and avoids the
        # check-then-create race of exists()/mkdir().
        os.makedirs(save_dir, exist_ok=True)
        agent_model_path = os.path.join(save_dir, agent_model_name)
        qmixer_model_path = os.path.join(save_dir, qmixer_model_name)
        paddle.save(self.alg.agent_model.state_dict(), agent_model_path)
        paddle.save(self.alg.qmixer_model.state_dict(), qmixer_model_path)
        print('save model successfully!')

    def restore(self, save_dir, agent_model_name, qmixer_model_name):
        """Load agent-network and mixer-network parameters written by save().

        Args:
            save_dir (str): directory containing the checkpoint files.
            agent_model_name (str): file name of the agent network params.
            qmixer_model_name (str): file name of the mixer network params.
        """
        # No directory is created here: a freshly created directory cannot
        # contain a checkpoint, so paddle.load is left to report the error.
        agent_model_path = os.path.join(save_dir, agent_model_name)
        qmixer_model_path = os.path.join(save_dir, qmixer_model_name)
        self.alg.agent_model.set_state_dict(paddle.load(agent_model_path))
        self.alg.qmixer_model.set_state_dict(paddle.load(qmixer_model_path))
        print('restore model successfully!')

    def reset_agent(self, batch_size=1):
        """Re-initialise the algorithm's recurrent hidden states.

        Args:
            batch_size (int): batch size passed to the algorithm's
                ``_init_hidden_states``.
        """
        self.alg._init_hidden_states(batch_size)

    def sample(self, obs, available_actions):
        """Sample actions via epsilon-greedy exploration.

        The greedy forward pass is always executed so that the recurrent
        hidden state advances on every timestep, whether or not the random
        actions end up being used.

        Args:
            obs (np.ndarray): (n_agents, obs_shape)
            available_actions (np.ndarray): (n_agents, n_actions)

        Returns:
            actions (np.ndarray): (n_agents, )
        """
        # Uniform draw in [0, 1): explore when it does not exceed the
        # current exploration probability.
        draw = np.random.random()
        # Always run predict(): it updates self.alg.hidden_states, which must
        # track the full observation sequence even on exploratory steps.
        actions = self.predict(obs, available_actions)
        if draw <= self.exploration:
            actions = AvailableActionsSampler(available_actions).sample()
        # Linear decay, clipped at min_exploration; applied on every call.
        self.exploration = max(self.min_exploration,
                               self.exploration - self.exploration_decay)
        return actions

    def predict(self, obs, available_actions):
        """Take greedy actions and advance the recurrent hidden state.

        Args:
            obs (np.ndarray): (n_agents, obs_shape)
            available_actions (np.ndarray): (n_agents, n_actions)

        Returns:
            actions (np.ndarray): (n_agents, )
        """
        obs = paddle.to_tensor(obs, dtype='float32')
        available_actions = paddle.to_tensor(available_actions, dtype='int32')
        agents_q, self.alg.hidden_states = self.alg.predict_local_q(
            obs, self.alg.hidden_states)
        # Push Q-values of unavailable actions far below any realistic value
        # so argmax never selects them.
        unavailable_actions_mask = (available_actions == 0).cast('float32')
        agents_q = agents_q - 1e8 * unavailable_actions_mask
        actions = paddle.argmax(agents_q, axis=-1).detach().cpu().numpy()
        return actions

    def learn(self, state_batch, actions_batch, reward_batch, terminated_batch,
              available_actions_batch, filled_batch, local_qs,
              target_local_qs):
        """Run one training step, syncing the target network when due.

        Args:
            state_batch (np.ndarray): (batch_size, T, state_shape)
            actions_batch (np.ndarray): (batch_size, T, n_agents)
            reward_batch (np.ndarray): (batch_size, T, 1)
            terminated_batch (np.ndarray): (batch_size, T, 1)
            available_actions_batch (np.ndarray):
                (batch_size, T, n_agents, n_actions)
            filled_batch (np.ndarray): (batch_size, T, 1)
            local_qs: per-agent Q-values, forwarded unchanged to
                ``self.alg.learn`` (shape/type set by the caller; not
                converted here).
            target_local_qs: target-network per-agent Q-values, forwarded
                unchanged to ``self.alg.learn``.

        Returns:
            mean_loss (float): train loss
            mean_td_error (float): train TD error
        """
        # Sync before the step whose index is a multiple of the interval,
        # including the very first call (global_step == 0).
        if self.global_step % self.update_target_interval == 0:
            self.alg.sync_target()
            self.target_update_count += 1
        self.global_step += 1

        state_batch = paddle.to_tensor(state_batch, dtype='float32')
        actions_batch = paddle.to_tensor(actions_batch, dtype='int64')
        reward_batch = paddle.to_tensor(reward_batch, dtype='float32')
        terminated_batch = paddle.to_tensor(terminated_batch, dtype='float32')
        available_actions_batch = paddle.to_tensor(
            available_actions_batch, dtype='int64')
        filled_batch = paddle.to_tensor(filled_batch, dtype='float32')

        mean_loss, mean_td_error = self.alg.learn(
            state_batch, actions_batch, reward_batch, terminated_batch,
            available_actions_batch, filled_batch, local_qs, target_local_qs)
        return mean_loss, mean_td_error
#def localQ(self, state_batch, obs_batch):
#state_batch = paddle.to_tensor(state_batch, dtype='float32')
#obs_batch = paddle.to_tensor(obs_batch, dtype='float32')
#return self.alg.localQ(state_batch, obs_batch)