Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-aws committed Oct 31, 2023
1 parent 4e62653 commit 0cf989d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
34 changes: 32 additions & 2 deletions src/Distributions/UniformPowerOfTwo/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,40 @@ module UniformPowerOfTwo.Correctness {
ghost predicate UnifIsCorrect(n: nat, k: nat, m: nat)
requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1)
{
SampleReturnsResult(n);
Rand.prob(iset s | Model.Sample(n)(s).value == m) == if m < Helper.Power(2, k) then 1.0 / (Helper.Power(2, k) as real) else 0.0
}

function SampleRest(n: nat): Rand.Bitstream -> Rand.Bitstream
requires n >= 1
{
SampleReturnsResult(n);
(s: Rand.Bitstream) => Model.Sample(n)(s).rest
}

/*******
Lemmas
*******/

lemma SampleReturnsResult(n: nat)
requires n >= 1
ensures forall s :: Model.Sample(n)(s).Result? // always terminates, not just almost surely
{
reveal Model.Sample();
}

// Correctness Theorem for Model.Sample.
// In contrast to UnifCorrectness, this lemma does not follow
// the thesis, but models PROB_BERN_UNIF of the HOL implementation.
lemma UnifCorrectness2(n: nat, m: nat)
requires n >= 1
ensures forall s :: Model.Sample(n)(s).Result?
ensures
var e := iset s | Model.Sample(n)(s).value == m;
&& e in Rand.eventSpace
&& Rand.prob(e) == if m < Helper.Power(2, Helper.Log2Floor(n)) then 1.0 / (Helper.Power(2, Helper.Log2Floor(n)) as real) else 0.0
{
SampleReturnsResult(n);
var e := iset s | Model.Sample(n)(s).value == m;
var k := Helper.Log2Floor(n);

Expand All @@ -65,11 +76,13 @@ module UniformPowerOfTwo.Correctness {
lemma UnifCorrectness2Inequality(n: nat, m: nat)
requires n >= 1
requires m <= Helper.Power(2, Helper.Log2Floor(n))
ensures forall s :: Model.Sample(n)(s).Result?
ensures
var e := iset s | Model.Sample(n)(s).value < m;
&& e in Rand.eventSpace
&& Rand.prob(e) == (m as real) / (Helper.Power(2, Helper.Log2Floor(n)) as real)
{
SampleReturnsResult(n);
var e := iset s | Model.Sample(n)(s).value < m;

if m == 0 {
Expand Down Expand Up @@ -112,6 +125,7 @@ module UniformPowerOfTwo.Correctness {
requires Helper.Power(2, k) <= n < Helper.Power(2, k + 1)
ensures forall m: nat :: UnifIsCorrect(n, k, m)
{
SampleReturnsResult(n);
forall m: nat ensures UnifIsCorrect(n, k, m) {
assert n >= 1 by { Helper.PowerGreater0(2, k); }
if k == 0 {
Expand All @@ -128,6 +142,7 @@ module UniformPowerOfTwo.Correctness {
assert RecursiveCorrect: UnifIsCorrect(n / 2, k - 1, m / 2) by {
UnifCorrectness(n / 2, k - 1);
}
SampleReturnsResult(n / 2);
if m < Helper.Power(2, k) {
calc {
Rand.prob(iset s | Model.Sample(n)(s).value == m);
Expand Down Expand Up @@ -197,7 +212,9 @@ module UniformPowerOfTwo.Correctness {
SampleIsIndep(n);
Monad.IsIndepImpliesMeasurableNat(Model.Sample(n));
}
assert Measures.PreImage(f, e) == preimage';
assert Measures.PreImage(f, e) == preimage' by {
assume {:axiom} false; // TODO
}
}
}
if n == 1 {
Expand All @@ -210,6 +227,8 @@ module UniformPowerOfTwo.Correctness {
}
assert Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, f);
} else {
SampleReturnsResult(n);
SampleReturnsResult(n / 2);
var g := SampleRest(n / 2);
forall e | e in Rand.eventSpace ensures Rand.prob(Measures.PreImage(f, e)) == Rand.prob(e) {
var e' := (iset s | Rand.Tail(s) in e);
Expand Down Expand Up @@ -256,8 +275,12 @@ module UniformPowerOfTwo.Correctness {

lemma SampleTailDecompose(n: nat, s: Rand.Bitstream)
requires n >= 2
ensures forall s :: Model.Sample(n)(s).Result?
ensures forall s :: Model.Sample(n / 2)(s).Result?
ensures Model.Sample(n)(s).rest == Rand.Tail(Model.Sample(n / 2)(s).rest)
{
SampleReturnsResult(n);
SampleReturnsResult(n / 2);
var Result(a, s') := Model.Sample(n / 2)(s);
var Result(b, s'') := Monad.Coin()(s');
calc {
Expand All @@ -281,11 +304,14 @@ module UniformPowerOfTwo.Correctness {

lemma SampleSetEquality(n: nat, m: nat)
requires n >= 2
ensures forall s :: Model.Sample(n / 2)(s).Result?
ensures
var bOf := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value;
var aOf := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value;
(iset s | Model.Sample(n)(s).value == m) == (iset s | 2*aOf(s) + Helper.boolToNat(bOf(s)) == m)
(iset s | SampleReturnsResult(n); Model.Sample(n)(s).value == m) == (iset s | 2*aOf(s) + Helper.boolToNat(bOf(s)) == m)
{
SampleReturnsResult(n);
SampleReturnsResult(n / 2);
var bOf := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value;
var aOf := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value;
forall s ensures Model.Sample(n)(s).value == m <==> (2 * aOf(s) + Helper.boolToNat(bOf(s)) == m) {
Expand All @@ -309,8 +335,12 @@ module UniformPowerOfTwo.Correctness {

lemma SampleRecursiveHalf(n: nat, m: nat)
requires n >= 2
ensures forall s :: Model.Sample(n)(s).Result?
ensures forall s :: Model.Sample(n / 2)(s).Result?
ensures Rand.prob(iset s | Model.Sample(n)(s).value == m) == Rand.prob(iset s | Model.Sample(n / 2)(s).value == m / 2) / 2.0
{
SampleReturnsResult(n);
SampleReturnsResult(n / 2);
var aOf: Rand.Bitstream -> nat := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value;
var bOf: Rand.Bitstream -> bool := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value;
var A: iset<nat> := (iset x: nat | x == m / 2);
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/UniformPowerOfTwo/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ module UniformPowerOfTwo.Model {
// The return value u is uniformly distributed between 0 <= u < 2^k where 2^k <= n < 2^(k + 1).
opaque function Sample(n: nat): (h: Monad.Hurd<nat>)
requires n >= 1
//ensures forall s :: Sample(n)(s).Result? // always terminates, not just almost surely
{
if n == 1 then
Monad.Return(0)
Expand All @@ -22,6 +21,7 @@ module UniformPowerOfTwo.Model {
}

function UnifStepHelper(m: nat): bool -> Monad.Hurd<nat> {
assume {:axiom} false; // TODO
(b: bool) => Monad.Return(if b then 2*m + 1 else 2*m)
}

Expand Down

0 comments on commit 0cf989d

Please sign in to comment.