diff --git a/src/Collocation/cyclic.jl b/src/Collocation/cyclic.jl index 9470a0b5..e655a0af 100644 --- a/src/Collocation/cyclic.jl +++ b/src/Collocation/cyclic.jl @@ -127,9 +127,10 @@ function LinearAlgebra.ldiv!(x::AbstractVector, A::CyclicTridiagonalMatrix, d::A γ = sqrt(ac) # value required to perturb b[1] and b[n] with the same offset u = x # alias to avoid allocations - fill!(u, 0) - u[1] = γ - u[n] = cₙ + T = eltype(u) + fill!(u, zero(T)) + u[1] = convert_scalar(T, γ) # if T is e.g. a SVector{2}, this is the SVector [γ, γ] + u[n] = convert_scalar(T, cₙ) v1 = 1 vn = a₁ / γ @@ -151,14 +152,17 @@ function LinearAlgebra.ldiv!(x::AbstractVector, A::CyclicTridiagonalMatrix, d::A vy = v1 * y[1] + vn * y[n] vq = v1 * q[1] + vn * q[n] - α = vy / (1 + vq) + α = @. vy / (1 + vq) # broadcast in case T <: StaticArray for i ∈ eachindex(x) - @inbounds x[i] = y[i] - α * q[i] + @inbounds x[i] = y[i] - α .* q[i] # broadcast in case T <: StaticArray end x end +convert_scalar(::Type{T}, x::Number) where {T <: Number} = T(x) +convert_scalar(::Type{T}, x::Number) where {T <: AbstractArray} = fill(x, T) # typically when T is a StaticArray + # Simultaneously solve M (non-cyclic) tridiagonal linear systems using Thomas algorithm. # Note that xs[i] and ds[i] can be aliased. @fastmath function solve_thomas!( diff --git a/test/periodic.jl b/test/periodic.jl index ac8dc714..e5e3811b 100644 --- a/test/periodic.jl +++ b/test/periodic.jl @@ -1,8 +1,29 @@ using BSplineKit +using StaticArrays using LinearAlgebra using Random using Test +# For now this only works for cubic splines. +function test_parametric_curve(ord::BSplineOrder{4}) + L = 2π + θs = range(0, L; length = 6)[1:5] # input angles in [0, 2π) + points = [SVector(cos(θ), sin(θ)) for θ ∈ θs] + S = interpolate(θs, copy(points), ord, Periodic(L)) + a, b = boundaries(basis(S)) + @test a == 0 + @test b == L + # Verify continuity of the curve + S′ = Derivative(1) * S + S″ = Derivative(2) * S + S‴ = Derivative(3) * S + @test S(a) == S(b) + @test S′(a) == S′(b) + @test S″(a) == S″(b) + @test S‴(a) == S‴(b) + nothing +end + function test_periodic_splines(ord::BSplineOrder) k = order(ord) L = 1 # period @@ -171,6 +192,10 @@ function test_periodic_splines(ord::BSplineOrder) end end + if k == 4 + @testset "Parametric closed curve" test_parametric_curve(ord) + end + nothing end