forked from StanfordMSL/Neural-Network-Reach
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpendulum_controlled.jl
72 lines (58 loc) · 2.01 KB
/
pendulum_controlled.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
using Plots, FileIO, MAT
include("reach.jl")
"""
    input_constraints_pendulum(type::String)

Return the H-representation `(A, b)` of a named input set, i.e. the polytope
`{x : A*x ≤ b}` over the network's 3-dimensional input.

Supported `type` values:
- `"pendulum"`: the box |x₁| ≤ π, |x₂| ≤ π, |x₃| ≤ 5.

Any other `type` raises an `ErrorException`.
"""
function input_constraints_pendulum(type::String)
    type == "pendulum" || error("Invalid input constraint specification.")

    # Rows come in (+eᵢ, -eᵢ) pairs so each coordinate gets an upper and a lower bound.
    A = [ 1.0  0.0  0.0
         -1.0  0.0  0.0
          0.0  1.0  0.0
          0.0 -1.0  0.0
          0.0  0.0  1.0
          0.0  0.0 -1.0]
    b = Float64[π, π, π, π, 5, 5]
    return A, b
end
"""
    output_constraints_pendulum(type::String)

Return the H-representation `(A, b)` of a named output set, i.e. the polytope
`{yₒᵤₜ : A*yₒᵤₜ ≤ b}` over the 2-dimensional (unnormalized) network output.

The raw network outputs are unnormalized as `yₒᵤₜ = Aₒᵤₜ*y + bₒᵤₜ`, so the
equivalent constraint on the raw output `y` is `A*Aₒᵤₜ*y ≤ b - A*bₒᵤₜ`.
This function does NOT apply that transformation; it only returns `(A, b)`.

Supported `type` values:
- `"origin"`: |y₁| ≤ 5 and |y₂| ≤ 2, with the bounds given in degrees and
  converted to radians via `deg2rad`.

Any other `type` raises an `ErrorException`.
"""
function output_constraints_pendulum(type::String)
    if type == "origin"
        A = [1. 0; -1 0; 0 1; 0 -1]
        b = deg2rad.([5., 5, 2, 2])
    else
        # This constructor builds OUTPUT sets; the message previously said "input",
        # which pointed users at the wrong function.
        error("Invalid output constraint specification.")
    end
    return A, b
end
"""
    plot_hrep(inv_set)

Draw every polytope of `inv_set` onto one new figure and return that figure.

`inv_set` is iterated as a sequence of `(A, b)` pairs, each describing the
H-polytope `{x : A*x ≤ b}`. Every polytope is scaled by 180/π before it is
drawn so that it matches the degree-based axis labels. This assumes the stored
sets are 2-D and expressed in radians; confirm against the data file.

If a polytope is empty, it is printed with `@show` and an error is raised.
"""
function plot_hrep(inv_set)
    rad2deg_scale = 180/π
    fig = plot(reuse = false, legend=false,
               xlabel="Angle (deg.)", ylabel="Angular Velocity (deg./s.)")
    for constraints in inv_set
        A, b = constraints
        # The name `reg` is kept deliberately: `@show` prints the variable name verbatim.
        reg = rad2deg_scale * HPolytope(constraints_list(A, b))
        if isempty(reg)
            @show reg
            error("Empty polyhedron.")
        end
        # NOTE(review): Plots documents `fontfamily` as a family-name string;
        # passing a `font(...)` object here may be unintended — confirm.
        plot!(fig, reg;
              fontfamily=font(40, "Computer Modern"),
              yguidefont=14, xguidefont=14, tickfont=12)
    end
    return fig
end
###########################
######## SCRIPTING ########
###########################
# Load network
# The weight and normalization-parameter .npz files are handed to `pytorch_net`
# (presumably defined in reach.jl — verify), whose result is used as `weights` below.
copies = 1 # copies = 1 is original network
nn_weights = "models/Pendulum/weights_controlled.npz"
nn_params = "models/Pendulum/norm_params_controlled.npz"
weights = pytorch_net(nn_weights, nn_params, copies)
# Input set Aᵢ*x ≤ bᵢ and target output set Aₒ*y ≤ bₒ.
Aᵢ, bᵢ = input_constraints_pendulum("pendulum")
Aₒ, bₒ = output_constraints_pendulum("origin")
# Run algorithm
# The output set is passed wrapped in one-element vectors, presumably because
# compute_reach accepts a list of output sets — TODO confirm against reach.jl.
# The four `ap2*` results look like maps keyed by activation pattern (to input
# region / output region / affine map / backward info) — assumption, verify.
# Note: @time on the first run also includes JIT compilation time.
@time begin
ap2input, ap2output, ap2map, ap2backward = compute_reach(weights, Aᵢ, bᵢ, [Aₒ], [bₒ])
end
# Report how many entries ap2input holds (presumably the number of regions found).
@show length(ap2input)
# Load and plot control invariant set
# FileIO `load` on the .jld2 file yields a dictionary; the set is stored under "inv_set"
# and is expected to iterate as (A, b) pairs, which is what plot_hrep consumes.
inv_dict = load("models/Pendulum/pendulum_controlled_inv_set.jld2",)
inv_set = inv_dict["inv_set"]
# `plt` is the last expression, so `include`-ing this file returns the plot (shown in the REPL);
# a non-interactive `julia` run would need an explicit `display(plt)` to show it.
plt = plot_hrep(inv_set)