diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cac96abe..9b5d87fb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,6 @@ jobs: version: - '1.6' - '1' - - 'nightly' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index e5d1e9a8..63ba7cfc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJModels" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" authors = ["Anthony D. Blaom "] -version = "0.16.15" +version = "0.16.16" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/builtins/ThresholdPredictors.jl b/src/builtins/ThresholdPredictors.jl index c8c85593..a64e08bf 100644 --- a/src/builtins/ThresholdPredictors.jl +++ b/src/builtins/ThresholdPredictors.jl @@ -305,6 +305,20 @@ function _predict_threshold(yhat::UnivariateFiniteArray{S,V,R,P,N}, end +## SERIALIZATION + +function MMI.save(model::ThresholdUnion, fitresult) + atomic_fitresult, threshold = fitresult + atom = model.model + return MMI.save(atom, atomic_fitresult), threshold +end +function MMI.restore(model::ThresholdUnion, serializable_fitresult) + atomic_serializable_fitresult, threshold = serializable_fitresult + atom = model.model + return MMI.restore(atom, atomic_serializable_fitresult), threshold +end + + ## TRAITS # Note: input traits are inherited from the wrapped model diff --git a/test/builtins/ThresholdPredictors.jl b/test/builtins/ThresholdPredictors.jl index 069add61..16a81f25 100644 --- a/test/builtins/ThresholdPredictors.jl +++ b/test/builtins/ThresholdPredictors.jl @@ -268,6 +268,61 @@ MMI.input_scitype(::Type{<:NaiveClassifier}) = Table(Continuous) mode(Distributions.fit(MLJBase.UnivariateFinite, y[I])), MLJBase.nrows(X) ) end + + +# define a probabilistic classifier with non-persistent `fitresult`, but which addresses +# this by overloading `save`/`restore`: +thing = [] +struct EphemeralClassifier <: MLJBase.Probabilistic end +function MLJBase.fit(::EphemeralClassifier, verbosity, X, y) + # if I serialize/deserialized `thing` then `id` below changes: + id = objectid(thing) + p = Distributions.fit(UnivariateFinite, y) + fitresult = (thing, id, p) + return fitresult, nothing, NamedTuple() +end +function MLJBase.predict(::EphemeralClassifier, fitresult, X) + thing, id, p = fitresult + return id == objectid(thing) ? fill(p, MLJBase.nrows(X)) : + throw(ErrorException("dead fitresult")) +end +MLJBase.target_scitype(::Type{<:EphemeralClassifier}) = AbstractVector{OrderedFactor{2}} +function MLJBase.save(::EphemeralClassifier, fitresult) + thing, _, p = fitresult + return (thing, p) +end +function MLJBase.restore(::EphemeralClassifier, serialized_fitresult) + thing, p = serialized_fitresult + id = objectid(thing) + return (thing, id, p) +end + +# X, y = (; x = rand(8)), categorical(collect("OXXXXOOX"), ordered=true) +# mach = machine(EphemeralClassifier(), X, y) |> fit! +# io = IOBuffer() +# MLJBase.save(io, mach) +# seekstart(io) +# mach2 = machine(io) +# predict(mach2, X) + +@testset "serialization for atomic models with non-persistent fitresults" begin + # https://github.com/alan-turing-institute/MLJ.jl/issues/1099 + X, y = (; x = rand(8)), categorical(collect("OXXXXOOX"), ordered=true) + deterministic_classifier = BinaryThresholdPredictor( + EphemeralClassifier(), + threshold=0.5, + ) + mach = MLJBase.machine(deterministic_classifier, X, y) + MLJBase.fit!(mach, verbosity=0) + yhat = MLJBase.predict(mach, MLJBase.selectrows(X, 1:2)) + io = IOBuffer() + MLJBase.save(io, mach) + seekstart(io) + mach2 = MLJBase.machine(io) + close(io) + @test MLJBase.predict(mach2, (; x = rand(2))) == yhat +end + end # module true