From 8e4b78609c1efce0c469be11531be8064dada737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Fri, 16 Aug 2024 04:54:30 +0300 Subject: [PATCH] test: add tests for changing data in ParametrizedInterpolation --- Project.toml | 5 ++++- test/Blocks/sources.jl | 42 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 7af74bc07..f19c02ee7 100644 --- a/Project.toml +++ b/Project.toml @@ -39,9 +39,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations"] +test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "Optimization", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations", "SciMLStructures", "SymbolicIndexingInterface"] diff --git a/test/Blocks/sources.jl b/test/Blocks/sources.jl index 2f8951b6a..c7f300e4d 100644 --- a/test/Blocks/sources.jl +++ b/test/Blocks/sources.jl @@ -6,6 +6,9 @@ using ModelingToolkitStandardLibrary.Blocks: smooth_sin, smooth_cos, smooth_damp smooth_triangular, triangular, square using OrdinaryDiffEq: ReturnCode.Success using DataInterpolations +using SymbolicIndexingInterface +using SciMLStructures: SciMLStructures, Tunable +using Optimization @testset "Constant" begin @named src = Constant(k = 2) @@ -479,8 +482,8 @@ end @testset "ParametrizedInterpolation" begin @variables y(t) = 0 - @parameters u[1:15] = rand(15) - @parameters x[1:15] = 0:14.0 + u = rand(15) + x = 0:14.0 @testset "LinearInterpolation" begin @named i = ParametrizedInterpolation(LinearInterpolation, u, x) @@ -493,6 +496,41 @@ end sol = solve(prob) @test SciMLBase.successful_retcode(sol) + + prob2 = remake(prob, p=[i.data => ones(15)]) + sol2 = solve(prob2) + + @test SciMLBase.successful_retcode(sol2) + @test all(only.(sol2.u) .≈ sol2.t) # the solution for y' = 1 is y(t) = t + + set_data! = setp(prob2, i.data) + set_data!(prob2, zeros(15)) + sol3 = solve(prob2) + @test SciMLBase.successful_retcode(sol3) + @test iszero(sol3) + + function loss(x, p) + prob0, set_data! = p + ps = parameter_values(prob0) + arr, repack, alias = SciMLStructures.canonicalize(Tunable(), ps) + T = promote_type(eltype(x), eltype(arr)) + promoted_ps = SciMLStructures.replace(Tunable(), ps, T.(arr)) + prob = remake(prob0; p = promoted_ps) + + set_data!(prob, x) + sol = solve(prob) + sum(abs2.(only.(sol.u) .- sol.t)) + end + + set_data! = setp(prob, i.data) + of = OptimizationFunction(loss, AutoForwardDiff()) + op = OptimizationProblem(of, u, (prob, set_data!), lb = zeros(15), ub = fill(2.0, 15)) + + # check that type changing works + @test length(ForwardDiff.gradient(x -> of(x, (prob, set_data!)), u)) == 15 + + r = solve(op, Optimization.LBFGS(), maxiters = 1000) + @test of(r.u, (prob, set_data!)) < of(u, (prob, set_data!)) end @testset "BSplineInterpolation" begin