Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert Result Type, Proof of ResultsIndependent, Reduce Brittleness in FisherYates #161

Merged
merged 14 commits into from
Mar 6, 2024
6 changes: 3 additions & 3 deletions audit.log
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
src/Distributions/Uniform/Correctness.dfy(31,17): UniformFullCorrectness: Declaration has explicit `{:axiom}` attribute.
src/Distributions/Uniform/Correctness.dfy(36,17): SampleCoin: Declaration has explicit `{:axiom}` attribute.
src/Distributions/Uniform/Model.dfy(19,33): Sample: Declaration has explicit `{:axiom}` attribute.
src/Distributions/Uniform/Model.dfy(46,17): IntervalSampleIsMeasurePreserving: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(30,27): IsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(55,6): ResultsIndependent: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Independence.dfy(60,17): IsIndepImpliesIsIndepFunction: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(64,17): MapIsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(70,17): IsIndepImpliesIsIndepFunction: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(74,17): MapIsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/RandomSource.dfy(50,17): ProbIsProbabilityMeasure: Declaration has explicit `{:axiom}` attribute.
8 changes: 5 additions & 3 deletions src/Distributions/Uniform/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ module Uniform.Model {
opaque ghost function {:axiom} Sample(n: nat): (h: Monad.Hurd<nat>)
requires n > 0
ensures Independence.IsIndep(h)
ensures forall s :: h(s).Result? ==> 0 <= h(s).value < n
ensures forall s :: 0 <= h(s).value < n

ghost function IntervalSample(a: int, b: int): (f: Monad.Hurd<int>)
requires a < b
Expand All @@ -33,16 +33,18 @@ module Uniform.Model {

lemma SampleBound(n: nat, s: Rand.Bitstream)
requires n > 0
requires Sample(n)(s).Result?
ensures 0 <= Sample(n)(s).value < n
{}

lemma IntervalSampleBound(a: int, b: int, s: Rand.Bitstream)
requires a < b
requires IntervalSample(a, b)(s).Result?
ensures a <= IntervalSample(a, b)(s).value < b
{
SampleBound(b-a, s);
}

lemma {:axiom} IntervalSampleIsMeasurePreserving(a: int, b: int)
requires a < b
ensures Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => IntervalSample(a, b)(s).rest)

}
12 changes: 11 additions & 1 deletion src/ProbabilisticProgramming/Independence.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ module Independence {
)
requires hIndep: IsIndepFunction(h)
requires bMeasurable: bSeeds in Rand.eventSpace
requires hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest)
ensures Rand.prob(Monad.BitstreamsWithValueIn(h, aSet) * Monad.BitstreamsWithRestIn(h, bSeeds)) == Rand.prob(Monad.BitstreamsWithValueIn(h, aSet)) * Rand.prob(bSeeds)
{
var aSeeds := Monad.BitstreamsWithValueIn(h, aSet);
Expand All @@ -52,7 +53,16 @@ module Independence {
assert Measures.AreIndepEvents(Rand.eventSpace, Rand.prob, aSeeds, restBSeeds);
}
assert Rand.prob(restBSeeds) == Rand.prob(bSeeds) by {
assume {:axiom} false; // TODO
calc {
Rand.prob(restBSeeds);
Rand.prob(Monad.BitstreamsWithRestIn(h, bSeeds));
{ assert Monad.BitstreamsWithRestIn(h, bSeeds) == iset s | h(s).rest in bSeeds; }
Rand.prob(iset s | h(s).rest in bSeeds);
{ assert (iset s | h(s).rest in bSeeds) == Measures.PreImage(s => h(s).rest, bSeeds); }
Rand.prob(Measures.PreImage(s => h(s).rest, bSeeds));
{ reveal bMeasurable; reveal hIsMeasurePreserving; }
Rand.prob(bSeeds);
}
}
}

Expand Down
59 changes: 11 additions & 48 deletions src/ProbabilisticProgramming/Monad.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,14 @@ module Monad {
// The result of a probabilistic computation on a bitstream.
// It either consists of the computed value and the (unconsumed) rest of the bitstream or indicates nontermination.
// It differs from Hurd's definition in that the result can be nontermination, which Hurd does not model explicitly.
datatype Result<A> =
| Result(value: A, rest: Rand.Bitstream)
| Diverging
datatype Result<A> = Result(value: A, rest: Rand.Bitstream)
{
function Map<B>(f: A -> B): Result<B> {
match this
case Diverging => Diverging
case Result(value, rest) => Result(f(value), rest)
Result(f(value), rest)
}

function Bind<B>(f: A -> Hurd<B>): Result<B> {
match this
case Diverging => Diverging
case Result(value, rest) => f(value)(rest)
f(value)(rest)
}

ghost predicate In(s: iset<A>) {
Expand All @@ -43,44 +37,30 @@ module Monad {
}

predicate Satisfies(property: A -> bool) {
match this
case Diverging => false
case Result(value, _) => property(value)
property(value)
}

ghost predicate RestIn(s: iset<Rand.Bitstream>) {
RestSatisfies(r => r in s)
}

predicate RestSatisfies(property: Rand.Bitstream -> bool) {
match this
case Diverging => false
case Result(_, rest) => property(rest)
property(rest)
}

predicate IsFailure() {
Diverging?
}

function PropagateFailure<B>(): Result<B>
requires Diverging?
{
Diverging
}

function Extract(): (A, Rand.Bitstream)
requires Result?
function Extract(): (x: (A, Rand.Bitstream))
ensures this == Result(x.0, x.1)
{
(this.value, this.rest)
}
}

ghost function Values<A>(results: iset<Result<A>>): iset<A> {
iset r <- results | r.Result? :: r.value
iset r <- results :: r.value
}

ghost function Rests<A>(results: iset<Result<A>>): iset<Rand.Bitstream> {
iset r <- results | r.Result? :: r.rest
iset r <- results :: r.rest
}

ghost function ResultEventSpace<A(!new)>(eventSpace: iset<iset<A>>): iset<iset<Result<A>>> {
Expand All @@ -92,11 +72,11 @@ module Monad {
ghost const natResultEventSpace: iset<iset<Result<nat>>> := ResultEventSpace(Measures.natEventSpace)

ghost function ResultsWithValueIn<A(!new)>(values: iset<A>): iset<Result<A>> {
iset result: Result<A> | result.Result? && result.value in values
iset result: Result<A> | result.value in values
}

ghost function ResultsWithRestIn<A(!new)>(rests: iset<Rand.Bitstream>): iset<Result<A>> {
iset result: Result<A> | result.Result? && result.rest in rests
iset result: Result<A> | result.rest in rests
}

ghost function BitstreamsWithValueIn<A(!new)>(h: Hurd<A>, aSet: iset<A>): iset<Rand.Bitstream> {
Expand All @@ -112,30 +92,13 @@ module Monad {
(s: Rand.Bitstream) => f(s).Bind(g)
}

function BindAlternative<A,B>(f: Hurd<A>, g: A -> Hurd<B>): (h: Hurd<B>)
ensures forall s :: h(s) == Bind(f, g)(s)
{
(s: Rand.Bitstream) =>
var (a, s') :- f(s);
g(a)(s')
}

// Equation (2.42)
const Coin: Hurd<bool> := s => Result(Rand.Head(s), Rand.Tail(s))

function Composition<A,B,C>(f: A -> Hurd<B>, g: B -> Hurd<C>): A -> Hurd<C> {
(a: A) => Bind(f(a), g)
}

function CompositionAlternative<A(!new),B,C>(f: A -> Hurd<B>, g: B -> Hurd<C>): (h: A -> Hurd<C>)
ensures forall a, s :: h(a)(s) == Composition(f, g)(a)(s)
{
(a: A) =>
(s: Rand.Bitstream) =>
var (b, s') :- f(a)(s);
g(b)(s')
}

// Equation (3.3)
function Return<A>(a: A): Hurd<A> {
(s: Rand.Bitstream) => Result(a, s)
Expand Down
97 changes: 67 additions & 30 deletions src/Util/FisherYates/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ module FisherYates.Correctness {
requires i <= |xs|
requires i <= |p|
{
iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]
iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..]
}

ghost predicate CorrectnessPredicate<T(!new)>(xs: seq<T>, p: seq<T>, i: nat)
Expand Down Expand Up @@ -64,7 +64,7 @@ module FisherYates.Correctness {
{
var e := iset s | Model.Shuffle(xs)(s).Equals(p);
var i := 0;
var e' := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e' := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
assert e == e';
assert |xs| == |p| by {
Model.PermutationsPreserveCardinality(xs, p);
Expand All @@ -81,7 +81,7 @@ module FisherYates.Correctness {
requires multiset(p[i..]) == multiset(xs[i..])
ensures CorrectnessPredicate(xs, p, i)
{
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
if |xs[i..]| <= 1 {
CorrectnessFisherYatesUniqueElementsGeneralLeq1(xs, p, i);
} else {
Expand All @@ -100,16 +100,16 @@ module FisherYates.Correctness {
ensures CorrectnessPredicate(xs, p, i)
{
Model.PermutationsPreserveCardinality(p[i..], xs[i..]);
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
assert e == Measures.SampleSpace() by {
forall s
ensures s in e
{
calc {
s in e;
Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
Model.Shuffle(xs, i)(s).value[i..] == p[i..];
{ assert Model.Shuffle(xs, i)(s) == Monad.Return(xs)(s); }
Monad.Return(xs)(s).Result? && Monad.Return(xs)(s).value[i..] == p[i..];
Monad.Return(xs)(s).value[i..] == p[i..];
{ assert Monad.Return(xs)(s).value == xs; }
xs[i..] == p[i..];
if |xs[i..]| == 0 then [] == p[i..] else [xs[i]] == p[i..];
Expand Down Expand Up @@ -147,9 +147,12 @@ module FisherYates.Correctness {
ensures CorrectnessPredicate(xs, p, i)
{
Model.PermutationsPreserveCardinality(p[i..], xs[i..]);
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
assert |xs| > i + 1;
var h := Uniform.Model.IntervalSample(i, |xs|);
assert hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest) by {
Uniform.Model.IntervalSampleIsMeasurePreserving(i, |xs|);
}
assert HIsIndependent: Independence.IsIndepFunction(h) by {
Uniform.Correctness.IntervalSampleIsIndep(i, |xs|);
Independence.IsIndepImpliesIsIndepFunction(h);
Expand Down Expand Up @@ -238,33 +241,68 @@ module FisherYates.Correctness {
assert DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e') by {
DecomposeE(xs, ys, p, i, j, h, A, e, e');
}

assert CorrectnessPredicate(xs, p, i) by {
reveal DecomposeE;
reveal HIsIndependent;
reveal BitStreamsInA;
assert e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by {
assert e' in Rand.eventSpace by {
assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1);
}
calc {
Rand.prob(e');
{ assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1); }
1.0 / (NatArith.FactorialTraditional(|ys|-(i+1)) as real);
{ assert |xs| == |ys|;
assert |ys|-(i+1) == |xs|-(i+1);
assert NatArith.FactorialTraditional(|ys|-(i+1)) == NatArith.FactorialTraditional(|xs|-(i+1));
assert (NatArith.FactorialTraditional(|ys|-(i+1)) as real) == (NatArith.FactorialTraditional(|xs|-(i+1)) as real); }
1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real);
}
reveal InductionHypothesis;
reveal hIsMeasurePreserving;
CorrectnessFisherYatesUniqueElementsGeneralGreater1Helper(xs, ys, p, i, j, h, A, e, e');
}

}

lemma CorrectnessFisherYatesUniqueElementsGeneralGreater1Helper<T(!new)>(xs: seq<T>, ys: seq<T>, p: seq<T>, i: nat, j: nat, h: Monad.Hurd<int>, A: iset<int>, e: iset<Rand.Bitstream>, e': iset<Rand.Bitstream>)
decreases |xs| - i
requires i <= |xs|
requires i <= |p|
requires forall a, b | i <= a < b < |xs| :: xs[a] != xs[b]
requires |xs| == |p|
requires multiset(p[i..]) == multiset(xs[i..])
requires |xs[i..]| > 1
requires i <= j < |xs| && xs[j] == p[i]
requires |xs| == |ys|
requires ys == Model.Swap(xs, i, j)
requires e == CorrectnessConstructEvent(xs, p, i)
requires e' == CorrectnessConstructEvent(ys, p, i+1)
requires DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e')
requires HIsIndependent: Independence.IsIndepFunction(h)
requires BitStreamsInA: Monad.BitstreamsWithValueIn(h, A) == (iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j))
requires InductionHypothesis: CorrectnessPredicate(ys, p, i+1)
requires hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest)
ensures CorrectnessPredicate(xs, p, i)
{
reveal DecomposeE;
reveal HIsIndependent;
reveal BitStreamsInA;
assert e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by {
assert e' in Rand.eventSpace by {
assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1);
}
ProbabilityOfE(xs, ys, p, i, j, h, A, e, e');
calc {
Rand.prob(e');
{ assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1); }
1.0 / (NatArith.FactorialTraditional(|ys|-(i+1)) as real);
{ assert |xs| == |ys|;
assert |ys|-(i+1) == |xs|-(i+1);
assert NatArith.FactorialTraditional(|ys|-(i+1)) == NatArith.FactorialTraditional(|xs|-(i+1));
assert (NatArith.FactorialTraditional(|ys|-(i+1)) as real) == (NatArith.FactorialTraditional(|xs|-(i+1)) as real); }
1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real);
}
}
reveal hIsMeasurePreserving;
assert e in Rand.eventSpace by {
EInEventSpace(xs, p, h, A, e, e');
}

assert Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) by {
ProbabilityOfE(xs, ys, p, i, j, h, A, e, e');
}
}


lemma BitStreamsInA<T(!new)>(xs: seq<T>, p: seq<T>, i: nat, j: nat, h: Monad.Hurd<int>, A: iset<int>)
requires i <= |xs|
requires i <= |p|
Expand Down Expand Up @@ -302,8 +340,8 @@ module FisherYates.Correctness {
requires A == iset{j}
requires h == Uniform.Model.IntervalSample(i, |xs|)
requires ys == Model.Swap(xs, i, j)
requires e == iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]
requires e' == iset s | Model.Shuffle(ys, i+1)(s).Result? && Model.Shuffle(ys, i+1)(s).value[i+1..] == p[i+1..]
requires e == iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..]
requires e' == iset s | Model.Shuffle(ys, i+1)(s).value[i+1..] == p[i+1..]
ensures e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e')
{
assert forall s :: s in e <==> s in Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e') by {
Expand All @@ -313,7 +351,6 @@ module FisherYates.Correctness {
if s in e {
var zs := Model.Shuffle(xs, i)(s).value;
assert zs[i..] == p[i..];
assert h(s).Result?;
var k := Uniform.Model.IntervalSample(i, |xs|)(s).value;
Uniform.Model.IntervalSampleBound(i, |xs|, s);
var s' := Uniform.Model.IntervalSample(i, |xs|)(s).rest;
Expand All @@ -332,7 +369,6 @@ module FisherYates.Correctness {
assert k in A;
}
assert s in Monad.BitstreamsWithRestIn(h, e') by {
assert Model.Shuffle(ys, i+1)(s').Result?;
assert Model.Shuffle(ys, i+1)(s').value[i+1..] == p[i+1..];
}
assert s in Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e');
Expand Down Expand Up @@ -491,6 +527,7 @@ module FisherYates.Correctness {
requires |xs| == |ys|
requires DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e')
requires HIsIndependent: Independence.IsIndepFunction(h)
requires hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest)
requires InductionHypothesis: Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real)
requires BitStreamsInA: Monad.BitstreamsWithValueIn(h, A) == (iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j))
ensures
Expand All @@ -500,7 +537,7 @@ module FisherYates.Correctness {
Rand.prob(e);
{ reveal DecomposeE; }
Rand.prob(Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e'));
{ reveal HIsIndependent; reveal InductionHypothesis; Independence.ResultsIndependent(h, A, e'); }
{ reveal HIsIndependent; reveal InductionHypothesis; reveal hIsMeasurePreserving; Independence.ResultsIndependent(h, A, e'); }
Rand.prob(Monad.BitstreamsWithValueIn(h, A)) * Rand.prob(e');
{ assert Rand.prob(Monad.BitstreamsWithValueIn(h, A)) == Rand.prob(iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j)) by { reveal BitStreamsInA; } }
Rand.prob(iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j)) * Rand.prob(e');
Expand Down
3 changes: 3 additions & 0 deletions src/Util/FisherYates/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ module FisherYates.Implementation {
assert prevASeq == a[..]; // ghost
Swap(a, i, j);
}
} else {
assert prevASeq == a[..]; // ghost
}

}

method Swap<T>(a: array<T>, i: nat, j: nat)
Expand Down
Loading