From ecdc5e62b903bdd10ff3205250a60ff9e62f6ce6 Mon Sep 17 00:00:00 2001 From: Stefan Zetzsche <120379523+stefan-aws@users.noreply.github.com> Date: Mon, 30 Oct 2023 14:38:14 +0000 Subject: [PATCH] Separate modules for equivalence proofs (#108) Moving all the code that is referred to in `Implementation` to establish the equivalence between the imperative and functional model into a separate module `Equivalence`. The code in the new module should not be referred to by any other module than `Implementation`. By submitting this pull request, I confirm that my contribution is made under the terms of the [MIT license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt). --- audit.log | 2 +- .../BernoulliExpNeg/Equivalence.dfy | 85 +++++++++++ .../BernoulliExpNeg/Implementation.dfy | 56 +------ src/Distributions/BernoulliExpNeg/Model.dfy | 16 +- src/Distributions/Uniform/Equivalence.dfy | 20 +++ src/Distributions/Uniform/Implementation.dfy | 3 +- src/Distributions/Uniform/Model.dfy | 9 -- .../UniformPowerOfTwo/Equivalence.dfy | 137 ++++++++++++++++++ .../UniformPowerOfTwo/Implementation.dfy | 5 +- src/Distributions/UniformPowerOfTwo/Model.dfy | 125 +--------------- 10 files changed, 258 insertions(+), 200 deletions(-) create mode 100644 src/Distributions/BernoulliExpNeg/Equivalence.dfy create mode 100644 src/Distributions/Uniform/Equivalence.dfy create mode 100644 src/Distributions/UniformPowerOfTwo/Equivalence.dfy diff --git a/audit.log b/audit.log index 25dc5f83..ae474d2e 100644 --- a/audit.log +++ b/audit.log @@ -2,7 +2,7 @@ src/Distributions/BernoulliExpNeg/Correctness.dfy(13,17): Correctness: Declarati src/Distributions/BernoulliExpNeg/Correctness.dfy(17,17): SampleIsIndep: Declaration has explicit `{:axiom}` attribute. src/Distributions/BernoulliExpNeg/Model.dfy(104,17): GammaLe1LoopTerminatesAlmostSurely: Declaration has explicit `{:axiom}` attribute. src/Distributions/Coin/Interface.dfy(21,6): CoinSample: Definition has `assume {:axiom}` statement in body. -src/Distributions/Uniform/Implementation.dfy(46,6): UniformSample: Definition has `assume {:axiom}` statement in body. +src/Distributions/Uniform/Implementation.dfy(47,6): UniformSample: Definition has `assume {:axiom}` statement in body. src/Math/Analysis/Reals.dfy(35,17): LeastUpperBoundProperty: Declaration has explicit `{:axiom}` attribute. src/Math/Exponential.dfy(11,17): EvalOne: Declaration has explicit `{:axiom}` attribute. src/Math/Exponential.dfy(2,26): Exp: Declaration has explicit `{:axiom}` attribute. diff --git a/src/Distributions/BernoulliExpNeg/Equivalence.dfy b/src/Distributions/BernoulliExpNeg/Equivalence.dfy new file mode 100644 index 00000000..05e51d44 --- /dev/null +++ b/src/Distributions/BernoulliExpNeg/Equivalence.dfy @@ -0,0 +1,85 @@ +/******************************************************************************* + * Copyright by the contributors to the Dafny Project + * SPDX-License-Identifier: MIT + *******************************************************************************/ + +module BernoulliExpNeg.Equivalence { + import Rationals + import Rand + import Monad + import Model + import Loops + import Bernoulli + + /************ + Definitions + ************/ + + opaque ghost predicate CaseLe1LoopInvariant(gamma: Rationals.Rational, oldS: Rand.Bitstream, a: bool, k: nat, s: Rand.Bitstream) + requires 0 <= gamma.numer <= gamma.denom + { + Model.GammaLe1Loop(gamma)((true, 0))(oldS) == Model.GammaLe1Loop(gamma)((a, k))(s) + } + + /******* + Lemmas + *******/ + + lemma GammaLe1LoopUnroll(gamma: Rationals.Rational, ak: (bool, nat), s: Rand.Bitstream) + requires 0 <= gamma.numer <= gamma.denom + requires ak.0 + ensures Model.GammaLe1Loop(gamma)(ak)(s) == Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s) + { + Model.GammaLe1LoopTerminatesAlmostSurely(gamma); + calc { + Model.GammaLe1Loop(gamma)(ak)(s); + { reveal Model.GammaLe1Loop(); } + Loops.While(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma), ak)(s); + { reveal Model.GammaLe1Loop(); + Loops.WhileUnroll(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma), ak, s); } + Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s); + } + } + + lemma EnsureCaseLe1LoopInvariantOnEntry(gamma: Rationals.Rational, s: Rand.Bitstream) + requires 0 <= gamma.numer <= gamma.denom + ensures CaseLe1LoopInvariant(gamma, s, true, 0, s) + { + reveal CaseLe1LoopInvariant(); + } + + lemma EnsureCaseLe1LoopInvariantMaintained(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, a': bool, k': nat, s': Rand.Bitstream) + requires 0 <= gamma.numer <= gamma.denom + requires k' == k + 1 + requires inv: CaseLe1LoopInvariant(gamma, oldS, true, k, s) + requires bernoulli: Monad.Result(a' , s') == Bernoulli.Model.Sample(gamma.numer, k' * gamma.denom)(s) + ensures CaseLe1LoopInvariant(gamma, oldS, a', k', s') + { + assert iter: Monad.Result((a', k'), s') == Model.GammaLe1LoopIter(gamma)((true, k))(s) by { + reveal bernoulli; + } + calc { + Model.GammaLe1Loop(gamma)((true, 0))(oldS); + { reveal CaseLe1LoopInvariant(); reveal inv; } + Model.GammaLe1Loop(gamma)((true, k))(s); + { reveal iter; GammaLe1LoopUnroll(gamma, (true, k), s); } + Model.GammaLe1Loop(gamma)((a', k'))(s'); + } + reveal CaseLe1LoopInvariant(); + } + + lemma EnsureCaseLe1PostCondition(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, c: bool) + requires 0 <= gamma.numer <= gamma.denom + requires CaseLe1LoopInvariant(gamma, oldS, false, k, s) + requires c <==> (k % 2 == 1) + ensures Monad.Result(c, s) == Model.SampleGammaLe1(gamma)(oldS) + { + calc { + Model.GammaLe1Loop(gamma)((true, 0))(oldS); + { reveal CaseLe1LoopInvariant(); } + Model.GammaLe1Loop(gamma)((false, k))(s); + { reveal Model.GammaLe1Loop(); } + Monad.Result((false, k), s); + } + } +} diff --git a/src/Distributions/BernoulliExpNeg/Implementation.dfy b/src/Distributions/BernoulliExpNeg/Implementation.dfy index d750f139..75bd0796 100644 --- a/src/Distributions/BernoulliExpNeg/Implementation.dfy +++ b/src/Distributions/BernoulliExpNeg/Implementation.dfy @@ -11,6 +11,7 @@ module BernoulliExpNeg.Implementation { import Interface import Model import Bernoulli + import Equivalence trait {:termination false} Trait extends Interface.Trait { @@ -47,68 +48,21 @@ module BernoulliExpNeg.Implementation { { var k: nat := 0; var a := true; - EnsureCaseLe1LoopInvariantOnEntry(gamma, s); + Equivalence.EnsureCaseLe1LoopInvariantOnEntry(gamma, s); while a decreases * - invariant CaseLe1LoopInvariant(gamma, old(s), a, k, s) + invariant Equivalence.CaseLe1LoopInvariant(gamma, old(s), a, k, s) { ghost var prevK: nat := k; ghost var prevS := s; k := k + 1; Helper.MulMonotonic(1, gamma.denom, k, gamma.denom); a := BernoulliSample(Rationals.Rational(gamma.numer, k * gamma.denom)); - EnsureCaseLe1LoopInvariantMaintained(gamma, old(s), prevK, prevS, a, k, s); + Equivalence.EnsureCaseLe1LoopInvariantMaintained(gamma, old(s), prevK, prevS, a, k, s); } c := k % 2 == 1; - EnsureCaseLe1PostCondition(gamma, old(s), k, s, c); + Equivalence.EnsureCaseLe1PostCondition(gamma, old(s), k, s, c); } } - opaque ghost predicate CaseLe1LoopInvariant(gamma: Rationals.Rational, oldS: Rand.Bitstream, a: bool, k: nat, s: Rand.Bitstream) - requires 0 <= gamma.numer <= gamma.denom - { - Model.GammaLe1Loop(gamma)((true, 0))(oldS) == Model.GammaLe1Loop(gamma)((a, k))(s) - } - - lemma EnsureCaseLe1LoopInvariantOnEntry(gamma: Rationals.Rational, s: Rand.Bitstream) - requires 0 <= gamma.numer <= gamma.denom - ensures CaseLe1LoopInvariant(gamma, s, true, 0, s) - { - reveal CaseLe1LoopInvariant(); - } - - lemma EnsureCaseLe1LoopInvariantMaintained(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, a': bool, k': nat, s': Rand.Bitstream) - requires 0 <= gamma.numer <= gamma.denom - requires k' == k + 1 - requires inv: CaseLe1LoopInvariant(gamma, oldS, true, k, s) - requires bernoulli: Monad.Result(a' , s') == Bernoulli.Model.Sample(gamma.numer, k' * gamma.denom)(s) - ensures CaseLe1LoopInvariant(gamma, oldS, a', k', s') - { - assert iter: Monad.Result((a', k'), s') == Model.GammaLe1LoopIter(gamma)((true, k))(s) by { - reveal bernoulli; - } - calc { - Model.GammaLe1Loop(gamma)((true, 0))(oldS); - { reveal CaseLe1LoopInvariant(); reveal inv; } - Model.GammaLe1Loop(gamma)((true, k))(s); - { reveal iter; Model.GammaLe1LoopUnroll(gamma, (true, k), s); } - Model.GammaLe1Loop(gamma)((a', k'))(s'); - } - reveal CaseLe1LoopInvariant(); - } - - lemma EnsureCaseLe1PostCondition(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, c: bool) - requires 0 <= gamma.numer <= gamma.denom - requires CaseLe1LoopInvariant(gamma, oldS, false, k, s) - requires c <==> (k % 2 == 1) - ensures Monad.Result(c, s) == Model.SampleGammaLe1(gamma)(oldS) - { - calc { - Model.GammaLe1Loop(gamma)((true, 0))(oldS); - { reveal CaseLe1LoopInvariant(); } - Model.GammaLe1Loop(gamma)((false, k))(s); - { reveal Model.GammaLe1Loop(); } - Monad.Result((false, k), s); - } - } } diff --git a/src/Distributions/BernoulliExpNeg/Model.dfy b/src/Distributions/BernoulliExpNeg/Model.dfy index 866bd5d4..5186033e 100644 --- a/src/Distributions/BernoulliExpNeg/Model.dfy +++ b/src/Distributions/BernoulliExpNeg/Model.dfy @@ -105,19 +105,5 @@ module BernoulliExpNeg.Model { requires 0 <= gamma.numer <= gamma.denom ensures Loops.WhileTerminatesAlmostSurely(GammaLe1LoopCondition, GammaLe1LoopIter(gamma)) - lemma GammaLe1LoopUnroll(gamma: Rationals.Rational, ak: (bool, nat), s: Rand.Bitstream) - requires 0 <= gamma.numer <= gamma.denom - requires ak.0 - ensures GammaLe1Loop(gamma)(ak)(s) == Monad.Bind(GammaLe1LoopIter(gamma)(ak), GammaLe1Loop(gamma))(s) - { - GammaLe1LoopTerminatesAlmostSurely(gamma); - calc { - GammaLe1Loop(gamma)(ak)(s); - { reveal GammaLe1Loop(); } - Loops.While(GammaLe1LoopCondition, GammaLe1LoopIter(gamma), ak)(s); - { reveal GammaLe1Loop(); - Loops.WhileUnroll(GammaLe1LoopCondition, GammaLe1LoopIter(gamma), ak, s); } - Monad.Bind(GammaLe1LoopIter(gamma)(ak), GammaLe1Loop(gamma))(s); - } - } + } diff --git a/src/Distributions/Uniform/Equivalence.dfy b/src/Distributions/Uniform/Equivalence.dfy new file mode 100644 index 00000000..e3e75c0c --- /dev/null +++ b/src/Distributions/Uniform/Equivalence.dfy @@ -0,0 +1,20 @@ +/******************************************************************************* + * Copyright by the contributors to the Dafny Project + * SPDX-License-Identifier: MIT + *******************************************************************************/ + +module Uniform.Equivalence { + import Rand + import Model + import Monad + import Loops + + lemma SampleUnroll(n: nat, s: Rand.Bitstream) + requires n > 0 + ensures Model.Sample(n)(s) == Monad.Bind(Model.Proposal(n), (x: nat) => if Model.Accept(n)(x) then Monad.Return(x) else Model.Sample(n))(s) + { + Model.SampleTerminates(n); + reveal Model.Sample(); + Loops.UntilUnroll(Model.Proposal(n), Model.Accept(n), s); + } +} diff --git a/src/Distributions/Uniform/Implementation.dfy b/src/Distributions/Uniform/Implementation.dfy index 3057643b..b19bd45f 100644 --- a/src/Distributions/Uniform/Implementation.dfy +++ b/src/Distributions/Uniform/Implementation.dfy @@ -8,6 +8,7 @@ module Uniform.Implementation { import UniformPowerOfTwo import Model import Interface + import Equivalence trait {:termination false} TraitFoundational extends Interface.Trait { method UniformSample(n: nat) returns (u: nat) @@ -24,7 +25,7 @@ module Uniform.Implementation { invariant Model.Sample(n)(old(s)) == Model.Sample(n)(prevS) invariant Monad.Result(u, s) == Model.Proposal(n)(prevS) { - Model.SampleUnroll(n, prevS); + Equivalence.SampleUnroll(n, prevS); prevS := s; u := UniformPowerOfTwoSample(2 * n); } diff --git a/src/Distributions/Uniform/Model.dfy b/src/Distributions/Uniform/Model.dfy index 711262a4..cf2598d6 100644 --- a/src/Distributions/Uniform/Model.dfy +++ b/src/Distributions/Uniform/Model.dfy @@ -85,13 +85,4 @@ module Uniform.Model { Loops.EnsureUntilTerminates(Proposal(n), Accept(n)); } } - - lemma SampleUnroll(n: nat, s: Rand.Bitstream) - requires n > 0 - ensures Sample(n)(s) == Monad.Bind(Proposal(n), (x: nat) => if Accept(n)(x) then Monad.Return(x) else Sample(n))(s) - { - SampleTerminates(n); - reveal Sample(); - Loops.UntilUnroll(Proposal(n), Accept(n), s); - } } diff --git a/src/Distributions/UniformPowerOfTwo/Equivalence.dfy b/src/Distributions/UniformPowerOfTwo/Equivalence.dfy new file mode 100644 index 00000000..c0d5f240 --- /dev/null +++ b/src/Distributions/UniformPowerOfTwo/Equivalence.dfy @@ -0,0 +1,137 @@ +/******************************************************************************* + * Copyright by the contributors to the Dafny Project + * SPDX-License-Identifier: MIT + *******************************************************************************/ + +module UniformPowerOfTwo.Equivalence { + import Rand + import Monad + import Helper + import Model + + /************ + Definitions + ************/ + + // A tail recursive version of Sample, closer to the imperative implementation + function SampleTailRecursive(n: nat, u: nat := 0): Monad.Hurd + requires n >= 1 + { + (s: Rand.Bitstream) => + if n == 1 then + Monad.Result(u, s) + else + SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)) + } + + /******* + Lemmas + *******/ + + // Equivalence of Sample and its tail-recursive version + lemma SampleCorrespondence(n: nat, s: Rand.Bitstream) + requires n >= 1 + ensures SampleTailRecursive(n)(s) == Model.Sample(n)(s) + { + if n == 1 { + reveal Model.Sample(); + assert SampleTailRecursive(n)(s) == Model.Sample(n)(s); + } else { + var k := Helper.Log2Floor(n); + assert Helper.Power(2, k) <= n < Helper.Power(2, k + 1) by { Helper.Power2OfLog2Floor(n); } + calc { + SampleTailRecursive(n)(s); + { SampleTailRecursiveEqualIfSameLog2Floor(n, Helper.Power(2, k), k, 0, s); } + SampleTailRecursive(Helper.Power(2, k))(s); + Monad.Bind(Model.Sample(Helper.Power(2, 0)), (u: nat) => SampleTailRecursive(Helper.Power(2, k), u))(s); + { RelateWithTailRecursive(k, 0, s); } + Model.Sample(Helper.Power(2, k))(s); + { SampleEqualIfSameLog2Floor(n, Helper.Power(2, k), k, s); } + Model.Sample(n)(s); + } + } + } + + // All numbers between consecutive powers of 2 behave the same as arguments to SampleTailRecursive + lemma SampleTailRecursiveEqualIfSameLog2Floor(m: nat, n: nat, k: nat, u: nat, s: Rand.Bitstream) + requires m >= 1 + requires n >= 1 + requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1) + requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1) + ensures SampleTailRecursive(m, u)(s) == SampleTailRecursive(n, u)(s) + { + if k == 0 { + assert m == n; + } else { + assert 1 <= m; + assert 1 <= n; + calc { + SampleTailRecursive(m, u)(s); + SampleTailRecursive(m / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)); + { SampleTailRecursiveEqualIfSameLog2Floor(m / 2, n / 2, k - 1, if Rand.Head(s) then 2*u + 1 else 2*u, Rand.Tail(s)); } + SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)); + SampleTailRecursive(n, u)(s); + } + } + } + + // All numbers between consecutive powers of 2 behave the same as arguments to Sample + lemma SampleEqualIfSameLog2Floor(m: nat, n: nat, k: nat, s: Rand.Bitstream) + requires m >= 1 + requires n >= 1 + requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1) + requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1) + ensures Model.Sample(m)(s) == Model.Sample(n)(s) + { + if k == 0 { + assert m == n; + } else { + assert 1 <= m; + assert 1 <= n; + calc { + Model.Sample(m)(s); + { reveal Model.Sample(); } + Monad.Bind(Model.Sample(m / 2), Model.UnifStep)(s); + { SampleEqualIfSameLog2Floor(m / 2, n / 2, k - 1, s); } + Monad.Bind(Model.Sample(n / 2), Model.UnifStep)(s); + { reveal Model.Sample(); } + Model.Sample(n)(s); + } + } + } + + // The induction invariant for the equivalence proof (generalized version of SampleCorrespondence) + lemma RelateWithTailRecursive(l: nat, m: nat, s: Rand.Bitstream) + decreases l + ensures Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s) == Model.Sample(Helper.Power(2, m + l))(s) + { + if l == 0 { + calc { + Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s); + (var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); SampleTailRecursive(1, u)(s')); + Model.Sample(Helper.Power(2, m + l))(s); + } + } else { + assert LGreaterZero: Helper.Power(2, l) >= 1 by { Helper.PowerGreater0(2, l); } + assert MGreaterZero: Helper.Power(2, m) >= 1 by { Helper.PowerGreater0(2, m); } + assert L1GreaterZero: Helper.Power(2, l - 1) >= 1 by { Helper.PowerGreater0(2, l - 1); } + calc { + Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s); + (var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); SampleTailRecursive(Helper.Power(2, l), u)(s')); + { reveal LGreaterZero; } + (var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); + SampleTailRecursive(Helper.Power(2, l) / 2, if Rand.Head(s') then 2 * u + 1 else 2 * u)(Rand.Tail(s'))); + { assert Helper.Power(2, l) / 2 == Helper.Power(2, l - 1); reveal L1GreaterZero; } + (var Result(u', s') := Monad.Bind(Model.Sample(Helper.Power(2, m)), Model.UnifStep)(s); + SampleTailRecursive(Helper.Power(2, l - 1), u')(s')); + { assert Helper.Power(2, m + 1) / 2 == Helper.Power(2, m); reveal Model.Sample(); } + (var Result(u', s') := Model.Sample(Helper.Power(2, m + 1))(s); + SampleTailRecursive(Helper.Power(2, l - 1), u')(s')); + Monad.Bind(Model.Sample(Helper.Power(2, m + 1)), (u: nat) => SampleTailRecursive(Helper.Power(2, l - 1), u))(s); + { RelateWithTailRecursive(l - 1, m + 1, s); } + Model.Sample(Helper.Power(2, m + l))(s); + } + } + } + +} diff --git a/src/Distributions/UniformPowerOfTwo/Implementation.dfy b/src/Distributions/UniformPowerOfTwo/Implementation.dfy index b88595af..3c117e49 100644 --- a/src/Distributions/UniformPowerOfTwo/Implementation.dfy +++ b/src/Distributions/UniformPowerOfTwo/Implementation.dfy @@ -6,6 +6,7 @@ module UniformPowerOfTwo.Implementation { import Monad import Model + import Equivalence import Interface trait {:termination false} Trait extends Interface.Trait { @@ -19,13 +20,13 @@ module UniformPowerOfTwo.Implementation { while n' > 1 invariant n' >= 1 - invariant Model.SampleTailRecursive(n)(old(s)) == Model.SampleTailRecursive(n', u)(s) + invariant Equivalence.SampleTailRecursive(n)(old(s)) == Equivalence.SampleTailRecursive(n', u)(s) { var b := CoinSample(); u := if b then 2*u + 1 else 2*u; n' := n' / 2; } - Model.SampleCorrespondence(n, old(s)); + Equivalence.SampleCorrespondence(n, old(s)); } } } diff --git a/src/Distributions/UniformPowerOfTwo/Model.dfy b/src/Distributions/UniformPowerOfTwo/Model.dfy index f68679e6..06b48d97 100644 --- a/src/Distributions/UniformPowerOfTwo/Model.dfy +++ b/src/Distributions/UniformPowerOfTwo/Model.dfy @@ -11,14 +11,6 @@ module UniformPowerOfTwo.Model { import Independence import Loops - function UnifStepHelper(m: nat): bool -> Monad.Hurd { - (b: bool) => Monad.Return(if b then 2*m + 1 else 2*m) - } - - function UnifStep(m: nat): Monad.Hurd { - Monad.Bind(Monad.Coin, UnifStepHelper(m)) - } - // Adapted from Definition 48 (see issue #79 for the reason of the modification) // The return value u is uniformly distributed between 0 <= u < 2^k where 2^k <= n < 2^(k + 1). opaque function Sample(n: nat): (h: Monad.Hurd) @@ -31,120 +23,11 @@ module UniformPowerOfTwo.Model { Monad.Bind(Sample(n/2), UnifStep) } - // A tail recursive version of Sample, closer to the imperative implementation - function SampleTailRecursive(n: nat, u: nat := 0): Monad.Hurd - requires n >= 1 - { - (s: Rand.Bitstream) => - if n == 1 then - Monad.Result(u, s) - else - SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)) - } - - // Equivalence of Sample and its tail-recursive version - lemma SampleCorrespondence(n: nat, s: Rand.Bitstream) - requires n >= 1 - ensures SampleTailRecursive(n)(s) == Sample(n)(s) - { - if n == 1 { - reveal Sample(); - assert SampleTailRecursive(n)(s) == Sample(n)(s); - } else { - var k := Helper.Log2Floor(n); - assert Helper.Power(2, k) <= n < Helper.Power(2, k + 1) by { Helper.Power2OfLog2Floor(n); } - calc { - SampleTailRecursive(n)(s); - { SampleTailRecursiveEqualIfSameLog2Floor(n, Helper.Power(2, k), k, 0, s); } - SampleTailRecursive(Helper.Power(2, k))(s); - Monad.Bind(Sample(Helper.Power(2, 0)), (u: nat) => SampleTailRecursive(Helper.Power(2, k), u))(s); - { RelateWithTailRecursive(k, 0, s); } - Sample(Helper.Power(2, k))(s); - { SampleEqualIfSameLog2Floor(n, Helper.Power(2, k), k, s); } - Sample(n)(s); - } - } - } - - // All numbers between consecutive powers of 2 behave the same as arguments to SampleTailRecursive - lemma SampleTailRecursiveEqualIfSameLog2Floor(m: nat, n: nat, k: nat, u: nat, s: Rand.Bitstream) - requires m >= 1 - requires n >= 1 - requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1) - requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1) - ensures SampleTailRecursive(m, u)(s) == SampleTailRecursive(n, u)(s) - { - if k == 0 { - assert m == n; - } else { - assert 1 <= m; - assert 1 <= n; - calc { - SampleTailRecursive(m, u)(s); - SampleTailRecursive(m / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)); - { SampleTailRecursiveEqualIfSameLog2Floor(m / 2, n / 2, k - 1, if Rand.Head(s) then 2*u + 1 else 2*u, Rand.Tail(s)); } - SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)); - SampleTailRecursive(n, u)(s); - } - } - } - - // All numbers between consecutive powers of 2 behave the same as arguments to Sample - lemma SampleEqualIfSameLog2Floor(m: nat, n: nat, k: nat, s: Rand.Bitstream) - requires m >= 1 - requires n >= 1 - requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1) - requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1) - ensures Sample(m)(s) == Sample(n)(s) - { - if k == 0 { - assert m == n; - } else { - assert 1 <= m; - assert 1 <= n; - calc { - Sample(m)(s); - { reveal Sample(); } - Monad.Bind(Sample(m / 2), UnifStep)(s); - { SampleEqualIfSameLog2Floor(m / 2, n / 2, k - 1, s); } - Monad.Bind(Sample(n / 2), UnifStep)(s); - { reveal Sample(); } - Sample(n)(s); - } - } + function UnifStepHelper(m: nat): bool -> Monad.Hurd { + (b: bool) => Monad.Return(if b then 2*m + 1 else 2*m) } - // The induction invariant for the equivalence proof (generalized version of SampleCorrespondence) - lemma RelateWithTailRecursive(l: nat, m: nat, s: Rand.Bitstream) - decreases l - ensures Monad.Bind(Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s) == Sample(Helper.Power(2, m + l))(s) - { - if l == 0 { - calc { - Monad.Bind(Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s); - (var Result(u, s') := Sample(Helper.Power(2, m))(s); SampleTailRecursive(1, u)(s')); - Sample(Helper.Power(2, m + l))(s); - } - } else { - assert LGreaterZero: Helper.Power(2, l) >= 1 by { Helper.PowerGreater0(2, l); } - assert MGreaterZero: Helper.Power(2, m) >= 1 by { Helper.PowerGreater0(2, m); } - assert L1GreaterZero: Helper.Power(2, l - 1) >= 1 by { Helper.PowerGreater0(2, l - 1); } - calc { - Monad.Bind(Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s); - (var Result(u, s') := Sample(Helper.Power(2, m))(s); SampleTailRecursive(Helper.Power(2, l), u)(s')); - { reveal LGreaterZero; } - (var Result(u, s') := Sample(Helper.Power(2, m))(s); - SampleTailRecursive(Helper.Power(2, l) / 2, if Rand.Head(s') then 2 * u + 1 else 2 * u)(Rand.Tail(s'))); - { assert Helper.Power(2, l) / 2 == Helper.Power(2, l - 1); reveal L1GreaterZero; } - (var Result(u', s') := Monad.Bind(Sample(Helper.Power(2, m)), UnifStep)(s); - SampleTailRecursive(Helper.Power(2, l - 1), u')(s')); - { assert Helper.Power(2, m + 1) / 2 == Helper.Power(2, m); reveal Sample(); } - (var Result(u', s') := Sample(Helper.Power(2, m + 1))(s); - SampleTailRecursive(Helper.Power(2, l - 1), u')(s')); - Monad.Bind(Sample(Helper.Power(2, m + 1)), (u: nat) => SampleTailRecursive(Helper.Power(2, l - 1), u))(s); - { RelateWithTailRecursive(l - 1, m + 1, s); } - Sample(Helper.Power(2, m + l))(s); - } - } + function UnifStep(m: nat): Monad.Hurd { + Monad.Bind(Monad.Coin, UnifStepHelper(m)) } }