Skip to content

Commit

Permalink
Merge pull request #550 from JuliaAI/save-restore
Browse files Browse the repository at this point in the history
Overload `save` and `restore` to address serialization issue for `BinaryThresholdPredictor`
  • Loading branch information
ablaom authored Mar 7, 2024
2 parents b0029df + 5dc7eb2 commit 55e10b5
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
version:
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
arch:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModels"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.16.15"
version = "0.16.16"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
14 changes: 14 additions & 0 deletions src/builtins/ThresholdPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,20 @@ function _predict_threshold(yhat::UnivariateFiniteArray{S,V,R,P,N},
end


## SERIALIZATION

function MMI.save(model::ThresholdUnion, fitresult)
atomic_fitresult, threshold = fitresult
atom = model.model
return MMI.save(atom, atomic_fitresult), threshold
end
function MMI.restore(model::ThresholdUnion, serializable_fitresult)
atomic_serializable_fitresult, threshold = serializable_fitresult
atom = model.model
return MMI.restore(atom, atomic_serializable_fitresult), threshold
end


## TRAITS

# Note: input traits are inherited from the wrapped model
Expand Down
55 changes: 55 additions & 0 deletions test/builtins/ThresholdPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,61 @@ MMI.input_scitype(::Type{<:NaiveClassifier}) = Table(Continuous)
mode(Distributions.fit(MLJBase.UnivariateFinite, y[I])), MLJBase.nrows(X)
)
end


# define a probabilistic classifier with non-persistent `fitresult`, but which addresses
# this by overloading `save`/`restore`:
thing = []
struct EphemeralClassifier <: MLJBase.Probabilistic end
function MLJBase.fit(::EphemeralClassifier, verbosity, X, y)
# if I serialize/deserialized `thing` then `id` below changes:
id = objectid(thing)
p = Distributions.fit(UnivariateFinite, y)
fitresult = (thing, id, p)
return fitresult, nothing, NamedTuple()
end
function MLJBase.predict(::EphemeralClassifier, fitresult, X)
thing, id, p = fitresult
return id == objectid(thing) ? fill(p, MLJBase.nrows(X)) :
throw(ErrorException("dead fitresult"))
end
MLJBase.target_scitype(::Type{<:EphemeralClassifier}) = AbstractVector{OrderedFactor{2}}
function MLJBase.save(::EphemeralClassifier, fitresult)
thing, _, p = fitresult
return (thing, p)
end
function MLJBase.restore(::EphemeralClassifier, serialized_fitresult)
thing, p = serialized_fitresult
id = objectid(thing)
return (thing, id, p)
end

# X, y = (; x = rand(8)), categorical(collect("OXXXXOOX"), ordered=true)
# mach = machine(EphemeralClassifier(), X, y) |> fit!
# io = IOBuffer()
# MLJBase.save(io, mach)
# seekstart(io)
# mach2 = machine(io)
# predict(mach2, X)

@testset "serialization for atomic models with non-persistent fitresults" begin
# https://github.com/alan-turing-institute/MLJ.jl/issues/1099
X, y = (; x = rand(8)), categorical(collect("OXXXXOOX"), ordered=true)
deterministic_classifier = BinaryThresholdPredictor(
EphemeralClassifier(),
threshold=0.5,
)
mach = MLJBase.machine(deterministic_classifier, X, y)
MLJBase.fit!(mach, verbosity=0)
yhat = MLJBase.predict(mach, MLJBase.selectrows(X, 1:2))
io = IOBuffer()
MLJBase.save(io, mach)
seekstart(io)
mach2 = MLJBase.machine(io)
close(io)
@test MLJBase.predict(mach2, (; x = rand(2))) == yhat
end

end # module

true

0 comments on commit 55e10b5

Please sign in to comment.