-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
41 lines (33 loc) · 1.17 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import sys
import gym
import numpy as np
from network_model.atari_network_model import AtariNetworkModel
from env.fire_to_start_env import FireToStartEnv
from env.no_op_start_env import NoOpStartEnv
from env.frame_skip_env import FrameSkipEnv
from env.preprocessed_frame_env import PreprocessedFrameEnv
from env.frame_stack_env import FrameStackEnv
from agent.trainer_agent import TrainerAgent
from agent.tester_agent import TesterAgent
def main(argv):
    """Evaluate a trained Atari model in a Gym environment.

    Builds the named Gym environment wrapped in the preprocessing chain
    (fire-to-start, no-op start, frame skip, frame preprocessing, frame
    stacking), loads the model weights from the given file, and runs a
    ``TesterAgent`` on it.

    Args:
        argv: Command-line arguments in the form
            ``[program, environment_name, weights_file]``.

    Returns:
        0 after the test run completes, or 2 if the arguments are
        malformed (a usage message is printed to stderr in that case).
    """
    if len(argv) != 3:
        # Guard against an empty argv so the usage path itself cannot crash.
        prog = argv[0] if argv else 'test.py'
        # Usage errors go to stderr so they are not mixed with normal output.
        print('Usage: {} <environment-name> <weights-file>'.format(prog),
              file=sys.stderr)
        return 2

    env_name, model_file_name = argv[1], argv[2]

    # Seed NumPy's global RNG for repeatability. Note that the Gym
    # environment itself is not seeded here.
    np.random.seed(0)

    # Create the environment, wrapped per the Nature paper. The wrapper
    # stack presumably must match the one used during training so that
    # observations have the shape the model expects. TODO: confirm.
    base_env = gym.make(env_name)
    env = FireToStartEnv(base_env)
    env = NoOpStartEnv(env)
    env = FrameSkipEnv(env)
    env = PreprocessedFrameEnv(env)
    env = FrameStackEnv(env)
    # To record a video of episode 40, additionally wrap the env with
    # gym.wrappers.Monitor(env, './',
    #                      video_callable=lambda episode_id: episode_id == 40)

    try:
        # Load the model weights and evaluate them with a tester agent.
        model = AtariNetworkModel(model_file_name, env, 'model').create_network()
        agent = TesterAgent(env, model)
        agent.run()
    finally:
        # Release the environment's resources even if loading or the run
        # fails. Close the raw env returned by gym.make: it always provides
        # close(), whether or not the project wrappers forward it.
        base_env.close()
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell callers can detect usage errors.
    sys.exit(main(sys.argv))