forked from carpedm20/SPIRAL-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
118 lines (92 loc) · 3.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import sys, signal
import tensorflow as tf
import trainer
import utils as ut
from envs import create_env
# Module-level logger, obtained from the project's `utils.logging` helper.
logger = ut.logging.get_logger()
def _build_queue_shapes(args, env):
    """Build the ``(name, shape)`` entries describing one queued trajectory.

    Every per-step field gets a leading time dimension of
    ``env.episode_length``; the ``'states'`` field gets one extra step.
    A final scalar ``'r'`` entry (shape ``[]``) is appended without a time
    dimension.

    Args:
        args: parsed configuration; reads ``lstm_size``, ``conditional`` and
            ``z_dim``.
        env: environment; reads ``action_sizes``, ``observation_shape`` and
            ``episode_length``.

    Returns:
        A list of ``[name, shape]`` pairs followed by the ``('r', [])`` tuple.
    """
    queue_shapes = [
        ['actions', [len(env.action_sizes)]],
        ['states', env.observation_shape],
        ['rewards', []],
        ['values', [1]],
        ['features', [2, args.lstm_size]],
    ]
    if args.conditional:
        queue_shapes.append(['conditions', env.observation_shape])
    else:
        queue_shapes.append(['z', [args.z_dim]])

    for entry in queue_shapes:
        name, shape = entry
        length = env.episode_length
        if name == 'states':
            # One more state than steps -- presumably the observation that
            # follows the last action (TODO confirm against the trainer).
            length += 1
        # list(...) so a tuple-valued shape also works; this also builds a
        # fresh list, so entries never alias `env.observation_shape`.
        entry[1] = [length] + list(shape)

    queue_shapes.extend([
        ('r', []),
    ])
    return queue_shapes


def main(_):
    """Start one process of the distributed job: a worker or a parameter server.

    Reads the configuration via ``config.get_args()``, seeds the process with
    ``args.seed + args.task``, builds the cluster definition, installs signal
    handlers and then, depending on ``args.job_name``:

    * ``"worker"``: starts a worker server and hands over to
      ``trainer.train``.
    * anything else: starts a ``"ps"`` server, creates the shared
      ``'queue'`` and ``'replay'`` FIFO queues on it and sleeps forever.

    Args:
        _: unused positional argument (supplied by the caller of ``main``).
    """
    from config import get_args
    args = get_args()

    ut.train.set_global_seed(args.seed + args.task)

    spec = ut.tf.cluster_spec(
        args.num_workers, 1, args.start_port)
    cluster = tf.train.ClusterSpec(spec)
    cluster_def = cluster.as_cluster_def()

    def shutdown(signum, frame):
        # NOTE(review): `warn` is a deprecated alias on standard
        # `logging.Logger`; kept because the logger type is not visible here.
        logger.warn('Received signal %s: exiting', signum)
        # 128 + signal number as the process exit status.
        sys.exit(128 + signum)

    # Not every platform defines every signal (e.g. SIGHUP is missing on
    # Windows), so only register the ones the `signal` module exposes.
    for sig_name in ('SIGHUP', 'SIGINT', 'SIGTERM'):
        signum = getattr(signal, sig_name, None)
        if signum is not None:
            signal.signal(signum, shutdown)

    #############################
    # Prepare common envs
    #############################

    env = create_env(args)
    queue_shapes = _build_queue_shapes(args, env)

    trajectory_queue_size = \
        args.policy_batch_size * max(5, args.num_workers)
    replay_queue_size = \
        args.disc_batch_size * max(5, args.num_workers)

    #############################
    # Run
    #############################

    if args.task == 0:
        ut.train.save_args(args)

    if args.job_name == "worker":
        gpu_options = tf.GPUOptions(allow_growth=True)

        tf_config = tf.ConfigProto(
            allow_soft_placement=True,
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=2,
            gpu_options=gpu_options)

        server = tf.train.Server(
            cluster_def,
            job_name="worker",
            task_index=args.task,
            config=tf_config)

        trainer.train(args, server, cluster, env, queue_shapes,
                      trajectory_queue_size, replay_queue_size)
    else:
        # The parameter server does not use the environment.
        del env

        server = tf.train.Server(
            cluster_def, job_name="ps", task_index=args.task,
            config=tf.ConfigProto(device_filters=["/job:ps"]))

        with tf.device("/job:ps/task:{}".format(args.task)):
            # Created for their effect on the graph; presumably other tasks
            # reach them through the same `shared_name` (TODO confirm).
            queue = tf.FIFOQueue(
                trajectory_queue_size,
                [tf.float32] * len(queue_shapes),
                shapes=[shape for _name, shape in queue_shapes],
                shared_name='queue')

            # Replay elements use the 'states' shape with its leading time
            # dimension dropped.
            replay = tf.FIFOQueue(
                replay_queue_size,
                tf.float32,
                shapes=dict(queue_shapes)['states'][1:],
                shared_name='replay')

        # Keep the process (and with it the server) alive indefinitely.
        while True:
            time.sleep(1000)
# Script entry point: hand `main` to TensorFlow's app runner.
if __name__ == '__main__':
    tf.app.run(main=main)