Skip to content

Commit

Permalink
test: add tests for changing data in ParametrizedInterpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianM-C committed Aug 16, 2024
1 parent fa1c3fd commit 8e4b786
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations"]
test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "Optimization", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations", "SciMLStructures", "SymbolicIndexingInterface"]
42 changes: 40 additions & 2 deletions test/Blocks/sources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ using ModelingToolkitStandardLibrary.Blocks: smooth_sin, smooth_cos, smooth_damp
smooth_triangular, triangular, square
using OrdinaryDiffEq: ReturnCode.Success
using DataInterpolations
using SymbolicIndexingInterface
using SciMLStructures: SciMLStructures, Tunable
using Optimization

@testset "Constant" begin
@named src = Constant(k = 2)
Expand Down Expand Up @@ -479,8 +482,8 @@ end

@testset "ParametrizedInterpolation" begin
@variables y(t) = 0
@parameters u[1:15] = rand(15)
@parameters x[1:15] = 0:14.0
u = rand(15)
x = 0:14.0

@testset "LinearInterpolation" begin
@named i = ParametrizedInterpolation(LinearInterpolation, u, x)
Expand All @@ -493,6 +496,41 @@ end
sol = solve(prob)

@test SciMLBase.successful_retcode(sol)

prob2 = remake(prob, p=[i.data => ones(15)])
sol2 = solve(prob2)

@test SciMLBase.successful_retcode(sol2)
@test all(only.(sol2.u) .≈ sol2.t) # the solution for y' = 1 is y(t) = t

set_data! = setp(prob2, i.data)
set_data!(prob2, zeros(15))
sol3 = solve(prob2)
@test SciMLBase.successful_retcode(sol3)
@test iszero(sol3)

function loss(x, p)
prob0, set_data! = p
ps = parameter_values(prob0)
arr, repack, alias = SciMLStructures.canonicalize(Tunable(), ps)
T = promote_type(eltype(x), eltype(arr))
promoted_ps = SciMLStructures.replace(Tunable(), ps, T.(arr))
prob = remake(prob0; p = promoted_ps)

set_data!(prob, x)
sol = solve(prob)
sum(abs2.(only.(sol.u) .- sol.t))
end

set_data! = setp(prob, i.data)
of = OptimizationFunction(loss, AutoForwardDiff())
op = OptimizationProblem(of, u, (prob, set_data!), lb = zeros(15), ub = fill(2.0, 15))

# check that type changing works
@test length(ForwardDiff.gradient(x -> of(x, (prob, set_data!)), u)) == 15

r = solve(op, Optimization.LBFGS(), maxiters = 1000)
@test of(r.u, (prob, set_data!)) < of(u, (prob, set_data!))
end

@testset "BSplineInterpolation" begin
Expand Down

0 comments on commit 8e4b786

Please sign in to comment.