-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathepisode_runner_subgoal.py
132 lines (117 loc) · 6.49 KB
/
episode_runner_subgoal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import numpy as np
class EpisodeRunnerSubgoal:
    """Play subgoal episodes and compute segment costs aggregated over levels.

    An episode is a sequence of waypoints from a start state to a goal state.
    For ``top_level > 0`` the intermediate waypoints come from
    ``policy_function``. ``game.test_predictions`` evaluates every consecutive
    pair of waypoints. The answers are turned into per-segment costs, which are
    then summed bottom-up over ``top_level`` levels of a binary split tree.
    """

    def __init__(self, config, game, policy_function):
        """Store collaborators and read the cost settings from ``config``.

        Args:
            config: mapping whose ``'cost'`` entry provides ``collision_cost``,
                ``is_constant_collision_cost``, ``free_cost``,
                ``is_constant_free_cost`` and ``huber_loss_delta``.
                ``config['cost']['type']`` is read lazily, only when a
                distance-relative cost is computed. The optional
                ``config['model']['repeat_train_trajectories']`` defaults to 0.
            game: object exposing ``test_predictions(queries)``; see
                ``play_episodes`` for the query/response layout.
            policy_function: callable ``(starts, goals, top_level, is_train)``
                returning the intermediate waypoints.
        """
        self.config = config
        self.game = game
        self.policy_function = policy_function
        cost_config = config['cost']
        self.collision_cost = cost_config['collision_cost']
        self.is_constant_collision_cost = cost_config['is_constant_collision_cost']
        self.free_cost = cost_config['free_cost']
        self.is_constant_free_cost = cost_config['is_constant_free_cost']
        self.huber_loss_delta = cost_config['huber_loss_delta']
        # During training the pair list is replaced by this many copies of
        # itself; 0 disables repetition.
        self.repeat_train_trajectories = 0
        if 'model' in config and 'repeat_train_trajectories' in config['model']:
            self.repeat_train_trajectories = config['model']['repeat_train_trajectories']

    def play_episodes(self, start_goal_pairs, top_level, is_train):
        """Build one episode per (start, goal) pair and evaluate its costs.

        Args:
            start_goal_pairs: iterable of ``(start, goal)`` pairs. States must
                support ``.copy()`` when training-time repetition is active.
            top_level: number of subgoal levels. 0 means every episode is the
                single segment ``start -> goal``. ``_process_endpoints``
                indexes waypoints up to ``2 ** top_level``, so the policy is
                expected to supply ``2 ** top_level - 1`` middle waypoints.
            is_train: forwarded to ``policy_function``. When true and
                ``repeat_train_trajectories`` is non-zero, the pairs are
                replaced by that many copies of themselves.

        Returns:
            dict mapping a path id (the pair's position in the possibly
            repeated list) to the ``(endpoints, splits, base_costs,
            is_valid_episode)`` tuple from ``_process_endpoints``. An empty
            input yields ``{}`` and ``game`` is not queried.
        """
        # Materialize once so the input can be walked repeatedly below.
        start_goal_pairs = list(start_goal_pairs)
        if is_train and self.repeat_train_trajectories:
            # Fresh copies so repeated episodes never share state objects.
            start_goal_pairs = [
                (start.copy(), goal.copy())
                for _ in range(self.repeat_train_trajectories)
                for start, goal in start_goal_pairs
            ]
        if not start_goal_pairs:
            # zip(*[]) cannot be unpacked into (starts, goals).
            return {}

        starts, goals = zip(*start_goal_pairs)
        waypoints = [np.array(starts)]
        if top_level > 0:
            # NOTE(review): presumably a sequence with one array of shape
            # (num_pairs, ...) per middle waypoint -- confirm against the policy.
            waypoints.extend(self.policy_function(starts, goals, top_level, is_train))
        waypoints.append(np.array(goals))

        # (num_waypoints, num_pairs, ...) -> (num_pairs, num_waypoints, ...),
        # then one episode (a view) per pair.
        episodes = list(np.swapaxes(np.array(waypoints), 0, 1))

        # One query per consecutive waypoint pair: (segment index, start, end).
        all_cost_queries = {
            path_id: [(i, episode[i], episode[i + 1]) for i in range(len(episode) - 1)]
            for path_id, episode in enumerate(episodes)
        }
        all_cost_responses = self.game.test_predictions(all_cost_queries)
        return {
            path_id: self._process_endpoints(episode, all_cost_responses[path_id], top_level)
            for path_id, episode in enumerate(episodes)
        }

    def _process_endpoints(self, endpoints, cost_responses, top_level):
        """Turn one episode's segment responses into costs for every level.

        Args:
            endpoints: waypoints of the episode, indexable by position; needs at
                least ``2 ** top_level + 1`` entries for the level aggregation.
            cost_responses: one response per consecutive waypoint pair, laid
                out as ``(start, end, is_start_valid, is_goal_valid,
                free_length, collision_length)``.
            top_level: number of aggregation levels above the base segments.

        Returns:
            ``(endpoints, splits, base_costs, is_valid_episode)``:

            - ``base_costs[(i, i + 1)]`` is ``(start, end, is_start_valid,
              is_goal_valid, cost)``.
            - ``splits[0] is base_costs``.
            - For ``level >= 1``, ``splits[level][(s, e)]`` is ``(start, end,
              middle, is_start_valid, is_goal_valid, cost)``. The last three
              are all ``None`` when either half has a ``None`` cost.
            - ``is_valid_episode`` is true iff every base segment has zero
              collision length.
        """
        is_valid_episode = True
        base_costs = {}
        for i in range(len(endpoints) - 1):
            start, end = endpoints[i], endpoints[i + 1]
            cost_response = cost_responses[i]
            # The game must answer for exactly the segment that was queried.
            # array_equal also handles scalar states and rejects shape
            # mismatches instead of broadcasting them.
            assert np.array_equal(start, cost_response[0]), 'i {} start {} cost_response[0] {} endpoints {}'.format(
                i, start, cost_response[0], endpoints)
            assert np.array_equal(end, cost_response[1]), 'i {} end {} cost_response[1] {} endpoints {}'.format(
                i, end, cost_response[1], endpoints)
            is_start_valid, is_goal_valid, free_length, collision_length = cost_response[2:]
            is_segment_valid = collision_length == 0.0
            cost = self._get_cost(free_length, collision_length)
            base_costs[(i, i + 1)] = (start, end, is_start_valid, is_goal_valid, cost)
            is_valid_episode = is_valid_episode and is_segment_valid

        # Merge adjacent pairs of level-(l-1) segments into level-l segments.
        splits = {0: base_costs}
        for level in range(1, top_level + 1):
            span = 2 ** level
            lower = splits[level - 1]
            level_costs = {}
            for start_index in range(0, 2 ** top_level, span):
                end_index = start_index + span
                middle_index = start_index + span // 2
                start, middle, end = endpoints[start_index], endpoints[middle_index], endpoints[end_index]
                first_is_start_valid, _, first_cost = lower[(start_index, middle_index)][-3:]
                _, second_is_goal_valid, second_cost = lower[(middle_index, end_index)][-3:]
                if first_cost is None or second_cost is None:
                    # A half without a cost leaves the merged segment without one.
                    merged = (None, None, None)
                else:
                    # NOTE: whether the first half's goal validity agrees with
                    # the second half's start validity is deliberately not checked.
                    merged = (first_is_start_valid, second_is_goal_valid, first_cost + second_cost)
                level_costs[(start_index, end_index)] = (start, end, middle) + merged
            splits[level] = level_costs
        return endpoints, splits, base_costs, is_valid_episode

    def _get_cost(self, segment_free, segment_collision):
        """Return the cost of one segment.

        Args:
            segment_free: length of the segment reported as collision free.
            segment_collision: length of the segment reported in collision.

        A segment with zero collision length pays ``free_cost``: either as a
        constant, or scaled by the distance cost of ``segment_free``. Otherwise
        it pays ``collision_cost`` as a constant, or the distance-scaled free
        part plus the distance-scaled collision part.
        """
        if segment_collision == 0.0:
            if self.is_constant_free_cost:
                return self.free_cost
            return self._get_distance_cost(segment_free) * self.free_cost
        if self.is_constant_collision_cost:
            return self.collision_cost
        free_part = self._get_distance_cost(segment_free) * self.free_cost
        collision_part = self._get_distance_cost(segment_collision) * self.collision_cost
        return free_part + collision_part

    def _get_distance_cost(self, distance):
        """Map a distance through ``config['cost']['type']``.

        The supported types are ``'linear'``, ``'huber'`` and ``'square'``.

        Raises:
            ValueError: for any other cost type. Previously ``None`` was
                returned, which failed later with an opaque TypeError.
        """
        cost_type = self.config['cost']['type']
        if cost_type == 'linear':
            return distance
        if cost_type == 'huber':
            return self._get_huber_loss(distance)
        if cost_type == 'square':
            return np.square(distance)
        raise ValueError(
            "unknown cost type {!r}; expected 'linear', 'huber' or 'square'".format(cost_type))

    def _get_huber_loss(self, distance):
        """Huber loss with threshold ``huber_loss_delta``.

        Quadratic below the threshold and linear above it.
        """
        if distance < self.huber_loss_delta:
            return 0.5 * distance * distance
        return self.huber_loss_delta * (distance - 0.5 * self.huber_loss_delta)