Skip to content

Commit

Permalink
Add support for interpolating StaticArray data (#37)
Browse files Browse the repository at this point in the history
* Allow interpolation of StaticArrays

* Add comments

* Make derivation and integration work for SArray

* Update tests

* Bump version

* Update comment

[skip ci]
  • Loading branch information
jipolanco authored Apr 12, 2022
1 parent 3af291c commit 8d851d2
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 17 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BSplineKit"
uuid = "093aae92-e908-43d7-9660-e50ee39d5a0a"
authors = ["Juan Ignacio Polanco <[email protected]>"]
version = "0.10.0"
version = "0.11.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -16,8 +16,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ArrayLayouts = "0.7, 0.8"
BandedMatrices = "0.16, 0.17"
ArrayLayouts = "0.8"
BandedMatrices = "0.17"
FastGaussQuadrature = "0.4"
Interpolations = "0.13"
LazyArrays = "0.22"
Expand Down
16 changes: 10 additions & 6 deletions src/SplineInterpolations/SplineInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,30 +61,30 @@ struct SplineInterpolation{

function SplineInterpolation(
B::AbstractBSplineBasis, C::Factorization, x::AbstractVector,
)
::Type{T},
) where {T}
N = length(B)
size(C) == (N, N) ||
throw(DimensionMismatch("collocation matrix has wrong dimensions"))
length(x) == N ||
throw(DimensionMismatch("wrong number of collocation points"))
T = eltype(C)
s = Spline(undef, B, T) # uninitialised spline
new{typeof(s), typeof(C), typeof(x)}(s, C, x)
end
end

# Construct SplineInterpolation from basis and collocation points.
function SplineInterpolation(
::UndefInitializer, B, x::AbstractVector, ::Type{T},
) where {T}
::UndefInitializer, B, x::AbstractVector{Tx}, ::Type{T},
) where {Tx <: Real, T}
# Here we construct the collocation matrix and its LU factorisation.
N = length(B)
if length(x) != N
throw(DimensionMismatch(
"incompatible lengths of B-spline basis and collocation points"))
end
C = collocation_matrix(B, x, CollocationMatrix{T})
SplineInterpolation(B, lu!(C), x)
C = collocation_matrix(B, x, CollocationMatrix{Tx})
SplineInterpolation(B, lu!(C), x, T)
end

interpolation_points(S::SplineInterpolation) = S.x
Expand Down Expand Up @@ -185,7 +185,11 @@ function interpolate(
)
t = make_knots(x, order(k))
B = BSplineBasis(k, t; augment = Val(false)) # it's already augmented!

# If input data is integer, convert the spline element type to float.
# This also does the right thing when eltype(y) <: StaticArray.
T = float(eltype(y))

itp = SplineInterpolation(undef, B, x, T)
interpolate!(itp, y)
end
Expand Down
8 changes: 4 additions & 4 deletions src/Splines/spline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ function spline_kernel(
d = ntuple(Val(k)) do j
if j r
α = (x - t[j + n - k]) / (t[j + n - r + 1] - t[j + n - k])
(1 - α) * w[j - 1] + α * w[j]
T((1 - α) * w[j - 1] + α * w[j])
else
w[j]
end
Expand Down Expand Up @@ -263,7 +263,7 @@ function _diff(
# In this case, the B-spline that this coefficient is
# multiplying is zero everywhere, so we can set this to zero.
# From de Boor (2001, p. 117): "anything times zero is zero".
du[i] = 0
du[i] = zero(eltype(du))
else
du[i] = (k - m) * (du[i] - du[i - 1]) / dt
end
Expand Down Expand Up @@ -310,11 +310,11 @@ function _integral(::BSplineBasis, S::Spline)
t_int[end] = t_int[end - 1]

β = similar(u, N + 1)
β[1] = 0
β[1] = zero(eltype(β))

@inbounds for i in eachindex(u)
m = i + 1
β[m] = 0
β[m] = zero(eltype(β))
for j = 1:i
β[m] += u[j] * (t[j + k] - t[j]) / k
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23 changes: 19 additions & 4 deletions test/interpolation.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
using Random
using StaticArrays: SVector

function test_interpolation(ord::BSplineOrder)
function test_interpolation(ord::BSplineOrder, ::Type{Ty} = Float64) where {Ty}
rng = MersenneTwister(42)
ndata = 40
xs = sort(randn(rng, ndata))
ys = randn(rng, ndata)

# This is Int if Ty = Int
# Int32 if Ty = SVector{N, Int32}
Tdata = eltype(Ty)

ys = if Tdata <: Integer
[rand(rng, Ty) .% Ty(16) for _ = 1:ndata] # values in -15:15
else
randn(rng, Ty, ndata)
end

k = order(ord)

@testset "No BCs" begin
Expand All @@ -30,7 +41,7 @@ function test_interpolation(ord::BSplineOrder)
for n = 2:(k ÷ 2)
Sder = Derivative(n) * S
for x boundaries(basis(S))
@test abs(Sder(x)) < abs(S(x)) * 1e-7
@test norm(Sder(x)) < norm(S(x)) * 1e-7
end
end
end
Expand All @@ -41,6 +52,10 @@ end

@testset "Interpolation" begin
@testset "k = $k" for k (3, 4, 6, 8)
test_interpolation(BSplineOrder(k))
test_interpolation(BSplineOrder(k), Float64)
end
types = (Int32, ComplexF32, SVector{2, Float32})
@testset "T = $T" for T types
test_interpolation(BSplineOrder(4), T)
end
end

2 comments on commit 8d851d2

@jipolanco
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/58384

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.11.0 -m "<description of version>" 8d851d2691c9b596c93c0bc6962437b472513d61
git push origin v0.11.0

Please sign in to comment.