Skip to content

Commit

Permalink
Separate modules for equivalence proofs (#108)
Browse files Browse the repository at this point in the history
Moving all the code that is referred to in `Implementation` to establish
the equivalence between the imperative and functional model into a
separate module `Equivalence`. The code in the new module should not be
referred to by any other module than `Implementation`.

By submitting this pull request, I confirm that my contribution is made
under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).
  • Loading branch information
stefan-aws authored Oct 30, 2023
1 parent 1d7294d commit ecdc5e6
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 200 deletions.
2 changes: 1 addition & 1 deletion audit.log
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ src/Distributions/BernoulliExpNeg/Correctness.dfy(13,17): Correctness: Declarati
src/Distributions/BernoulliExpNeg/Correctness.dfy(17,17): SampleIsIndep: Declaration has explicit `{:axiom}` attribute.
src/Distributions/BernoulliExpNeg/Model.dfy(104,17): GammaLe1LoopTerminatesAlmostSurely: Declaration has explicit `{:axiom}` attribute.
src/Distributions/Coin/Interface.dfy(21,6): CoinSample: Definition has `assume {:axiom}` statement in body.
src/Distributions/Uniform/Implementation.dfy(46,6): UniformSample: Definition has `assume {:axiom}` statement in body.
src/Distributions/Uniform/Implementation.dfy(47,6): UniformSample: Definition has `assume {:axiom}` statement in body.
src/Math/Analysis/Reals.dfy(35,17): LeastUpperBoundProperty: Declaration has explicit `{:axiom}` attribute.
src/Math/Exponential.dfy(11,17): EvalOne: Declaration has explicit `{:axiom}` attribute.
src/Math/Exponential.dfy(2,26): Exp: Declaration has explicit `{:axiom}` attribute.
Expand Down
85 changes: 85 additions & 0 deletions src/Distributions/BernoulliExpNeg/Equivalence.dfy
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*******************************************************************************
* Copyright by the contributors to the Dafny Project
* SPDX-License-Identifier: MIT
*******************************************************************************/

module BernoulliExpNeg.Equivalence {
import Rationals
import Rand
import Monad
import Model
import Loops
import Bernoulli

/************
Definitions
************/

opaque ghost predicate CaseLe1LoopInvariant(gamma: Rationals.Rational, oldS: Rand.Bitstream, a: bool, k: nat, s: Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
{
Model.GammaLe1Loop(gamma)((true, 0))(oldS) == Model.GammaLe1Loop(gamma)((a, k))(s)
}

/*******
Lemmas
*******/

lemma GammaLe1LoopUnroll(gamma: Rationals.Rational, ak: (bool, nat), s: Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
requires ak.0
ensures Model.GammaLe1Loop(gamma)(ak)(s) == Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s)
{
Model.GammaLe1LoopTerminatesAlmostSurely(gamma);
calc {
Model.GammaLe1Loop(gamma)(ak)(s);
{ reveal Model.GammaLe1Loop(); }
Loops.While(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma), ak)(s);
{ reveal Model.GammaLe1Loop();
Loops.WhileUnroll(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma), ak, s); }
Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s);
}
}

lemma EnsureCaseLe1LoopInvariantOnEntry(gamma: Rationals.Rational, s: Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
ensures CaseLe1LoopInvariant(gamma, s, true, 0, s)
{
reveal CaseLe1LoopInvariant();
}

lemma EnsureCaseLe1LoopInvariantMaintained(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, a': bool, k': nat, s': Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
requires k' == k + 1
requires inv: CaseLe1LoopInvariant(gamma, oldS, true, k, s)
requires bernoulli: Monad.Result(a' , s') == Bernoulli.Model.Sample(gamma.numer, k' * gamma.denom)(s)
ensures CaseLe1LoopInvariant(gamma, oldS, a', k', s')
{
assert iter: Monad.Result((a', k'), s') == Model.GammaLe1LoopIter(gamma)((true, k))(s) by {
reveal bernoulli;
}
calc {
Model.GammaLe1Loop(gamma)((true, 0))(oldS);
{ reveal CaseLe1LoopInvariant(); reveal inv; }
Model.GammaLe1Loop(gamma)((true, k))(s);
{ reveal iter; GammaLe1LoopUnroll(gamma, (true, k), s); }
Model.GammaLe1Loop(gamma)((a', k'))(s');
}
reveal CaseLe1LoopInvariant();
}

lemma EnsureCaseLe1PostCondition(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, c: bool)
requires 0 <= gamma.numer <= gamma.denom
requires CaseLe1LoopInvariant(gamma, oldS, false, k, s)
requires c <==> (k % 2 == 1)
ensures Monad.Result(c, s) == Model.SampleGammaLe1(gamma)(oldS)
{
calc {
Model.GammaLe1Loop(gamma)((true, 0))(oldS);
{ reveal CaseLe1LoopInvariant(); }
Model.GammaLe1Loop(gamma)((false, k))(s);
{ reveal Model.GammaLe1Loop(); }
Monad.Result((false, k), s);
}
}
}
56 changes: 5 additions & 51 deletions src/Distributions/BernoulliExpNeg/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module BernoulliExpNeg.Implementation {
import Interface
import Model
import Bernoulli
import Equivalence

trait {:termination false} Trait extends Interface.Trait {

Expand Down Expand Up @@ -47,68 +48,21 @@ module BernoulliExpNeg.Implementation {
{
var k: nat := 0;
var a := true;
EnsureCaseLe1LoopInvariantOnEntry(gamma, s);
Equivalence.EnsureCaseLe1LoopInvariantOnEntry(gamma, s);
while a
decreases *
invariant CaseLe1LoopInvariant(gamma, old(s), a, k, s)
invariant Equivalence.CaseLe1LoopInvariant(gamma, old(s), a, k, s)
{
ghost var prevK: nat := k;
ghost var prevS := s;
k := k + 1;
Helper.MulMonotonic(1, gamma.denom, k, gamma.denom);
a := BernoulliSample(Rationals.Rational(gamma.numer, k * gamma.denom));
EnsureCaseLe1LoopInvariantMaintained(gamma, old(s), prevK, prevS, a, k, s);
Equivalence.EnsureCaseLe1LoopInvariantMaintained(gamma, old(s), prevK, prevS, a, k, s);
}
c := k % 2 == 1;
EnsureCaseLe1PostCondition(gamma, old(s), k, s, c);
Equivalence.EnsureCaseLe1PostCondition(gamma, old(s), k, s, c);
}
}

opaque ghost predicate CaseLe1LoopInvariant(gamma: Rationals.Rational, oldS: Rand.Bitstream, a: bool, k: nat, s: Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
{
Model.GammaLe1Loop(gamma)((true, 0))(oldS) == Model.GammaLe1Loop(gamma)((a, k))(s)
}

lemma EnsureCaseLe1LoopInvariantOnEntry(gamma: Rationals.Rational, s: Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
ensures CaseLe1LoopInvariant(gamma, s, true, 0, s)
{
reveal CaseLe1LoopInvariant();
}

lemma EnsureCaseLe1LoopInvariantMaintained(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, a': bool, k': nat, s': Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
requires k' == k + 1
requires inv: CaseLe1LoopInvariant(gamma, oldS, true, k, s)
requires bernoulli: Monad.Result(a' , s') == Bernoulli.Model.Sample(gamma.numer, k' * gamma.denom)(s)
ensures CaseLe1LoopInvariant(gamma, oldS, a', k', s')
{
assert iter: Monad.Result((a', k'), s') == Model.GammaLe1LoopIter(gamma)((true, k))(s) by {
reveal bernoulli;
}
calc {
Model.GammaLe1Loop(gamma)((true, 0))(oldS);
{ reveal CaseLe1LoopInvariant(); reveal inv; }
Model.GammaLe1Loop(gamma)((true, k))(s);
{ reveal iter; Model.GammaLe1LoopUnroll(gamma, (true, k), s); }
Model.GammaLe1Loop(gamma)((a', k'))(s');
}
reveal CaseLe1LoopInvariant();
}

lemma EnsureCaseLe1PostCondition(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, c: bool)
requires 0 <= gamma.numer <= gamma.denom
requires CaseLe1LoopInvariant(gamma, oldS, false, k, s)
requires c <==> (k % 2 == 1)
ensures Monad.Result(c, s) == Model.SampleGammaLe1(gamma)(oldS)
{
calc {
Model.GammaLe1Loop(gamma)((true, 0))(oldS);
{ reveal CaseLe1LoopInvariant(); }
Model.GammaLe1Loop(gamma)((false, k))(s);
{ reveal Model.GammaLe1Loop(); }
Monad.Result((false, k), s);
}
}
}
16 changes: 1 addition & 15 deletions src/Distributions/BernoulliExpNeg/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,5 @@ module BernoulliExpNeg.Model {
requires 0 <= gamma.numer <= gamma.denom
ensures Loops.WhileTerminatesAlmostSurely(GammaLe1LoopCondition, GammaLe1LoopIter(gamma))

lemma GammaLe1LoopUnroll(gamma: Rationals.Rational, ak: (bool, nat), s: Rand.Bitstream)
requires 0 <= gamma.numer <= gamma.denom
requires ak.0
ensures GammaLe1Loop(gamma)(ak)(s) == Monad.Bind(GammaLe1LoopIter(gamma)(ak), GammaLe1Loop(gamma))(s)
{
GammaLe1LoopTerminatesAlmostSurely(gamma);
calc {
GammaLe1Loop(gamma)(ak)(s);
{ reveal GammaLe1Loop(); }
Loops.While(GammaLe1LoopCondition, GammaLe1LoopIter(gamma), ak)(s);
{ reveal GammaLe1Loop();
Loops.WhileUnroll(GammaLe1LoopCondition, GammaLe1LoopIter(gamma), ak, s); }
Monad.Bind(GammaLe1LoopIter(gamma)(ak), GammaLe1Loop(gamma))(s);
}
}

}
20 changes: 20 additions & 0 deletions src/Distributions/Uniform/Equivalence.dfy
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*******************************************************************************
* Copyright by the contributors to the Dafny Project
* SPDX-License-Identifier: MIT
*******************************************************************************/

module Uniform.Equivalence {
import Rand
import Model
import Monad
import Loops

lemma SampleUnroll(n: nat, s: Rand.Bitstream)
requires n > 0
ensures Model.Sample(n)(s) == Monad.Bind(Model.Proposal(n), (x: nat) => if Model.Accept(n)(x) then Monad.Return(x) else Model.Sample(n))(s)
{
Model.SampleTerminates(n);
reveal Model.Sample();
Loops.UntilUnroll(Model.Proposal(n), Model.Accept(n), s);
}
}
3 changes: 2 additions & 1 deletion src/Distributions/Uniform/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module Uniform.Implementation {
import UniformPowerOfTwo
import Model
import Interface
import Equivalence

trait {:termination false} TraitFoundational extends Interface.Trait {
method UniformSample(n: nat) returns (u: nat)
Expand All @@ -24,7 +25,7 @@ module Uniform.Implementation {
invariant Model.Sample(n)(old(s)) == Model.Sample(n)(prevS)
invariant Monad.Result(u, s) == Model.Proposal(n)(prevS)
{
Model.SampleUnroll(n, prevS);
Equivalence.SampleUnroll(n, prevS);
prevS := s;
u := UniformPowerOfTwoSample(2 * n);
}
Expand Down
9 changes: 0 additions & 9 deletions src/Distributions/Uniform/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,4 @@ module Uniform.Model {
Loops.EnsureUntilTerminates(Proposal(n), Accept(n));
}
}

lemma SampleUnroll(n: nat, s: Rand.Bitstream)
requires n > 0
ensures Sample(n)(s) == Monad.Bind(Proposal(n), (x: nat) => if Accept(n)(x) then Monad.Return(x) else Sample(n))(s)
{
SampleTerminates(n);
reveal Sample();
Loops.UntilUnroll(Proposal(n), Accept(n), s);
}
}
137 changes: 137 additions & 0 deletions src/Distributions/UniformPowerOfTwo/Equivalence.dfy
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*******************************************************************************
* Copyright by the contributors to the Dafny Project
* SPDX-License-Identifier: MIT
*******************************************************************************/

module UniformPowerOfTwo.Equivalence {
import Rand
import Monad
import Helper
import Model

/************
Definitions
************/

// A tail recursive version of Sample, closer to the imperative implementation
function SampleTailRecursive(n: nat, u: nat := 0): Monad.Hurd<nat>
requires n >= 1
{
(s: Rand.Bitstream) =>
if n == 1 then
Monad.Result(u, s)
else
SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s))
}

/*******
Lemmas
*******/

// Equivalence of Sample and its tail-recursive version
lemma SampleCorrespondence(n: nat, s: Rand.Bitstream)
requires n >= 1
ensures SampleTailRecursive(n)(s) == Model.Sample(n)(s)
{
if n == 1 {
reveal Model.Sample();
assert SampleTailRecursive(n)(s) == Model.Sample(n)(s);
} else {
var k := Helper.Log2Floor(n);
assert Helper.Power(2, k) <= n < Helper.Power(2, k + 1) by { Helper.Power2OfLog2Floor(n); }
calc {
SampleTailRecursive(n)(s);
{ SampleTailRecursiveEqualIfSameLog2Floor(n, Helper.Power(2, k), k, 0, s); }
SampleTailRecursive(Helper.Power(2, k))(s);
Monad.Bind(Model.Sample(Helper.Power(2, 0)), (u: nat) => SampleTailRecursive(Helper.Power(2, k), u))(s);
{ RelateWithTailRecursive(k, 0, s); }
Model.Sample(Helper.Power(2, k))(s);
{ SampleEqualIfSameLog2Floor(n, Helper.Power(2, k), k, s); }
Model.Sample(n)(s);
}
}
}

// All numbers between consecutive powers of 2 behave the same as arguments to SampleTailRecursive
lemma SampleTailRecursiveEqualIfSameLog2Floor(m: nat, n: nat, k: nat, u: nat, s: Rand.Bitstream)
requires m >= 1
requires n >= 1
requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1)
requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1)
ensures SampleTailRecursive(m, u)(s) == SampleTailRecursive(n, u)(s)
{
if k == 0 {
assert m == n;
} else {
assert 1 <= m;
assert 1 <= n;
calc {
SampleTailRecursive(m, u)(s);
SampleTailRecursive(m / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s));
{ SampleTailRecursiveEqualIfSameLog2Floor(m / 2, n / 2, k - 1, if Rand.Head(s) then 2*u + 1 else 2*u, Rand.Tail(s)); }
SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s));
SampleTailRecursive(n, u)(s);
}
}
}

// All numbers between consecutive powers of 2 behave the same as arguments to Sample
lemma SampleEqualIfSameLog2Floor(m: nat, n: nat, k: nat, s: Rand.Bitstream)
requires m >= 1
requires n >= 1
requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1)
requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1)
ensures Model.Sample(m)(s) == Model.Sample(n)(s)
{
if k == 0 {
assert m == n;
} else {
assert 1 <= m;
assert 1 <= n;
calc {
Model.Sample(m)(s);
{ reveal Model.Sample(); }
Monad.Bind(Model.Sample(m / 2), Model.UnifStep)(s);
{ SampleEqualIfSameLog2Floor(m / 2, n / 2, k - 1, s); }
Monad.Bind(Model.Sample(n / 2), Model.UnifStep)(s);
{ reveal Model.Sample(); }
Model.Sample(n)(s);
}
}
}

// The induction invariant for the equivalence proof (generalized version of SampleCorrespondence)
lemma RelateWithTailRecursive(l: nat, m: nat, s: Rand.Bitstream)
decreases l
ensures Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s) == Model.Sample(Helper.Power(2, m + l))(s)
{
if l == 0 {
calc {
Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s);
(var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); SampleTailRecursive(1, u)(s'));
Model.Sample(Helper.Power(2, m + l))(s);
}
} else {
assert LGreaterZero: Helper.Power(2, l) >= 1 by { Helper.PowerGreater0(2, l); }
assert MGreaterZero: Helper.Power(2, m) >= 1 by { Helper.PowerGreater0(2, m); }
assert L1GreaterZero: Helper.Power(2, l - 1) >= 1 by { Helper.PowerGreater0(2, l - 1); }
calc {
Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s);
(var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); SampleTailRecursive(Helper.Power(2, l), u)(s'));
{ reveal LGreaterZero; }
(var Result(u, s') := Model.Sample(Helper.Power(2, m))(s);
SampleTailRecursive(Helper.Power(2, l) / 2, if Rand.Head(s') then 2 * u + 1 else 2 * u)(Rand.Tail(s')));
{ assert Helper.Power(2, l) / 2 == Helper.Power(2, l - 1); reveal L1GreaterZero; }
(var Result(u', s') := Monad.Bind(Model.Sample(Helper.Power(2, m)), Model.UnifStep)(s);
SampleTailRecursive(Helper.Power(2, l - 1), u')(s'));
{ assert Helper.Power(2, m + 1) / 2 == Helper.Power(2, m); reveal Model.Sample(); }
(var Result(u', s') := Model.Sample(Helper.Power(2, m + 1))(s);
SampleTailRecursive(Helper.Power(2, l - 1), u')(s'));
Monad.Bind(Model.Sample(Helper.Power(2, m + 1)), (u: nat) => SampleTailRecursive(Helper.Power(2, l - 1), u))(s);
{ RelateWithTailRecursive(l - 1, m + 1, s); }
Model.Sample(Helper.Power(2, m + l))(s);
}
}
}

}
Loading

0 comments on commit ecdc5e6

Please sign in to comment.