-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Separate modules for equivalence proofs (#108)
Moving all the code that is referred to in `Implementation` to establish the equivalence between the imperative and functional model into a separate module `Equivalence`. The code in the new module should not be referred to by any other module than `Implementation`. By submitting this pull request, I confirm that my contribution is made under the terms of the [MIT license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).
- Loading branch information
1 parent
1d7294d
commit ecdc5e6
Showing
10 changed files
with
258 additions
and
200 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/******************************************************************************* | ||
* Copyright by the contributors to the Dafny Project | ||
* SPDX-License-Identifier: MIT | ||
*******************************************************************************/ | ||
|
||
module BernoulliExpNeg.Equivalence { | ||
import Rationals | ||
import Rand | ||
import Monad | ||
import Model | ||
import Loops | ||
import Bernoulli | ||
|
||
/************ | ||
Definitions | ||
************/ | ||
|
||
opaque ghost predicate CaseLe1LoopInvariant(gamma: Rationals.Rational, oldS: Rand.Bitstream, a: bool, k: nat, s: Rand.Bitstream) | ||
requires 0 <= gamma.numer <= gamma.denom | ||
{ | ||
Model.GammaLe1Loop(gamma)((true, 0))(oldS) == Model.GammaLe1Loop(gamma)((a, k))(s) | ||
} | ||
|
||
/******* | ||
Lemmas | ||
*******/ | ||
|
||
lemma GammaLe1LoopUnroll(gamma: Rationals.Rational, ak: (bool, nat), s: Rand.Bitstream) | ||
requires 0 <= gamma.numer <= gamma.denom | ||
requires ak.0 | ||
ensures Model.GammaLe1Loop(gamma)(ak)(s) == Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s) | ||
{ | ||
Model.GammaLe1LoopTerminatesAlmostSurely(gamma); | ||
calc { | ||
Model.GammaLe1Loop(gamma)(ak)(s); | ||
{ reveal Model.GammaLe1Loop(); } | ||
Loops.While(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma), ak)(s); | ||
{ reveal Model.GammaLe1Loop(); | ||
Loops.WhileUnroll(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma), ak, s); } | ||
Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s); | ||
} | ||
} | ||
|
||
lemma EnsureCaseLe1LoopInvariantOnEntry(gamma: Rationals.Rational, s: Rand.Bitstream) | ||
requires 0 <= gamma.numer <= gamma.denom | ||
ensures CaseLe1LoopInvariant(gamma, s, true, 0, s) | ||
{ | ||
reveal CaseLe1LoopInvariant(); | ||
} | ||
|
||
lemma EnsureCaseLe1LoopInvariantMaintained(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, a': bool, k': nat, s': Rand.Bitstream) | ||
requires 0 <= gamma.numer <= gamma.denom | ||
requires k' == k + 1 | ||
requires inv: CaseLe1LoopInvariant(gamma, oldS, true, k, s) | ||
requires bernoulli: Monad.Result(a' , s') == Bernoulli.Model.Sample(gamma.numer, k' * gamma.denom)(s) | ||
ensures CaseLe1LoopInvariant(gamma, oldS, a', k', s') | ||
{ | ||
assert iter: Monad.Result((a', k'), s') == Model.GammaLe1LoopIter(gamma)((true, k))(s) by { | ||
reveal bernoulli; | ||
} | ||
calc { | ||
Model.GammaLe1Loop(gamma)((true, 0))(oldS); | ||
{ reveal CaseLe1LoopInvariant(); reveal inv; } | ||
Model.GammaLe1Loop(gamma)((true, k))(s); | ||
{ reveal iter; GammaLe1LoopUnroll(gamma, (true, k), s); } | ||
Model.GammaLe1Loop(gamma)((a', k'))(s'); | ||
} | ||
reveal CaseLe1LoopInvariant(); | ||
} | ||
|
||
lemma EnsureCaseLe1PostCondition(gamma: Rationals.Rational, oldS: Rand.Bitstream, k: nat, s: Rand.Bitstream, c: bool) | ||
requires 0 <= gamma.numer <= gamma.denom | ||
requires CaseLe1LoopInvariant(gamma, oldS, false, k, s) | ||
requires c <==> (k % 2 == 1) | ||
ensures Monad.Result(c, s) == Model.SampleGammaLe1(gamma)(oldS) | ||
{ | ||
calc { | ||
Model.GammaLe1Loop(gamma)((true, 0))(oldS); | ||
{ reveal CaseLe1LoopInvariant(); } | ||
Model.GammaLe1Loop(gamma)((false, k))(s); | ||
{ reveal Model.GammaLe1Loop(); } | ||
Monad.Result((false, k), s); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/******************************************************************************* | ||
* Copyright by the contributors to the Dafny Project | ||
* SPDX-License-Identifier: MIT | ||
*******************************************************************************/ | ||
|
||
module Uniform.Equivalence { | ||
import Rand | ||
import Model | ||
import Monad | ||
import Loops | ||
|
||
lemma SampleUnroll(n: nat, s: Rand.Bitstream) | ||
requires n > 0 | ||
ensures Model.Sample(n)(s) == Monad.Bind(Model.Proposal(n), (x: nat) => if Model.Accept(n)(x) then Monad.Return(x) else Model.Sample(n))(s) | ||
{ | ||
Model.SampleTerminates(n); | ||
reveal Model.Sample(); | ||
Loops.UntilUnroll(Model.Proposal(n), Model.Accept(n), s); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
/******************************************************************************* | ||
* Copyright by the contributors to the Dafny Project | ||
* SPDX-License-Identifier: MIT | ||
*******************************************************************************/ | ||
|
||
module UniformPowerOfTwo.Equivalence { | ||
import Rand | ||
import Monad | ||
import Helper | ||
import Model | ||
|
||
/************ | ||
Definitions | ||
************/ | ||
|
||
// A tail recursive version of Sample, closer to the imperative implementation | ||
function SampleTailRecursive(n: nat, u: nat := 0): Monad.Hurd<nat> | ||
requires n >= 1 | ||
{ | ||
(s: Rand.Bitstream) => | ||
if n == 1 then | ||
Monad.Result(u, s) | ||
else | ||
SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)) | ||
} | ||
|
||
/******* | ||
Lemmas | ||
*******/ | ||
|
||
// Equivalence of Sample and its tail-recursive version | ||
lemma SampleCorrespondence(n: nat, s: Rand.Bitstream) | ||
requires n >= 1 | ||
ensures SampleTailRecursive(n)(s) == Model.Sample(n)(s) | ||
{ | ||
if n == 1 { | ||
reveal Model.Sample(); | ||
assert SampleTailRecursive(n)(s) == Model.Sample(n)(s); | ||
} else { | ||
var k := Helper.Log2Floor(n); | ||
assert Helper.Power(2, k) <= n < Helper.Power(2, k + 1) by { Helper.Power2OfLog2Floor(n); } | ||
calc { | ||
SampleTailRecursive(n)(s); | ||
{ SampleTailRecursiveEqualIfSameLog2Floor(n, Helper.Power(2, k), k, 0, s); } | ||
SampleTailRecursive(Helper.Power(2, k))(s); | ||
Monad.Bind(Model.Sample(Helper.Power(2, 0)), (u: nat) => SampleTailRecursive(Helper.Power(2, k), u))(s); | ||
{ RelateWithTailRecursive(k, 0, s); } | ||
Model.Sample(Helper.Power(2, k))(s); | ||
{ SampleEqualIfSameLog2Floor(n, Helper.Power(2, k), k, s); } | ||
Model.Sample(n)(s); | ||
} | ||
} | ||
} | ||
|
||
// All numbers between consecutive powers of 2 behave the same as arguments to SampleTailRecursive | ||
lemma SampleTailRecursiveEqualIfSameLog2Floor(m: nat, n: nat, k: nat, u: nat, s: Rand.Bitstream) | ||
requires m >= 1 | ||
requires n >= 1 | ||
requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1) | ||
requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1) | ||
ensures SampleTailRecursive(m, u)(s) == SampleTailRecursive(n, u)(s) | ||
{ | ||
if k == 0 { | ||
assert m == n; | ||
} else { | ||
assert 1 <= m; | ||
assert 1 <= n; | ||
calc { | ||
SampleTailRecursive(m, u)(s); | ||
SampleTailRecursive(m / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)); | ||
{ SampleTailRecursiveEqualIfSameLog2Floor(m / 2, n / 2, k - 1, if Rand.Head(s) then 2*u + 1 else 2*u, Rand.Tail(s)); } | ||
SampleTailRecursive(n / 2, if Rand.Head(s) then 2*u + 1 else 2*u)(Rand.Tail(s)); | ||
SampleTailRecursive(n, u)(s); | ||
} | ||
} | ||
} | ||
|
||
// All numbers between consecutive powers of 2 behave the same as arguments to Sample | ||
lemma SampleEqualIfSameLog2Floor(m: nat, n: nat, k: nat, s: Rand.Bitstream) | ||
requires m >= 1 | ||
requires n >= 1 | ||
requires Helper.Power(2, k) <= m < Helper.Power(2, k + 1) | ||
requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1) | ||
ensures Model.Sample(m)(s) == Model.Sample(n)(s) | ||
{ | ||
if k == 0 { | ||
assert m == n; | ||
} else { | ||
assert 1 <= m; | ||
assert 1 <= n; | ||
calc { | ||
Model.Sample(m)(s); | ||
{ reveal Model.Sample(); } | ||
Monad.Bind(Model.Sample(m / 2), Model.UnifStep)(s); | ||
{ SampleEqualIfSameLog2Floor(m / 2, n / 2, k - 1, s); } | ||
Monad.Bind(Model.Sample(n / 2), Model.UnifStep)(s); | ||
{ reveal Model.Sample(); } | ||
Model.Sample(n)(s); | ||
} | ||
} | ||
} | ||
|
||
// The induction invariant for the equivalence proof (generalized version of SampleCorrespondence) | ||
lemma RelateWithTailRecursive(l: nat, m: nat, s: Rand.Bitstream) | ||
decreases l | ||
ensures Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s) == Model.Sample(Helper.Power(2, m + l))(s) | ||
{ | ||
if l == 0 { | ||
calc { | ||
Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s); | ||
(var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); SampleTailRecursive(1, u)(s')); | ||
Model.Sample(Helper.Power(2, m + l))(s); | ||
} | ||
} else { | ||
assert LGreaterZero: Helper.Power(2, l) >= 1 by { Helper.PowerGreater0(2, l); } | ||
assert MGreaterZero: Helper.Power(2, m) >= 1 by { Helper.PowerGreater0(2, m); } | ||
assert L1GreaterZero: Helper.Power(2, l - 1) >= 1 by { Helper.PowerGreater0(2, l - 1); } | ||
calc { | ||
Monad.Bind(Model.Sample(Helper.Power(2, m)), (u: nat) => SampleTailRecursive(Helper.Power(2, l), u))(s); | ||
(var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); SampleTailRecursive(Helper.Power(2, l), u)(s')); | ||
{ reveal LGreaterZero; } | ||
(var Result(u, s') := Model.Sample(Helper.Power(2, m))(s); | ||
SampleTailRecursive(Helper.Power(2, l) / 2, if Rand.Head(s') then 2 * u + 1 else 2 * u)(Rand.Tail(s'))); | ||
{ assert Helper.Power(2, l) / 2 == Helper.Power(2, l - 1); reveal L1GreaterZero; } | ||
(var Result(u', s') := Monad.Bind(Model.Sample(Helper.Power(2, m)), Model.UnifStep)(s); | ||
SampleTailRecursive(Helper.Power(2, l - 1), u')(s')); | ||
{ assert Helper.Power(2, m + 1) / 2 == Helper.Power(2, m); reveal Model.Sample(); } | ||
(var Result(u', s') := Model.Sample(Helper.Power(2, m + 1))(s); | ||
SampleTailRecursive(Helper.Power(2, l - 1), u')(s')); | ||
Monad.Bind(Model.Sample(Helper.Power(2, m + 1)), (u: nat) => SampleTailRecursive(Helper.Power(2, l - 1), u))(s); | ||
{ RelateWithTailRecursive(l - 1, m + 1, s); } | ||
Model.Sample(Helper.Power(2, m + l))(s); | ||
} | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.