Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to Byte sampling #63

Merged
merged 7 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
439 changes: 439 additions & 0 deletions SampCert/Foundations/UniformByte.lean

Large diffs are not rendered by default.

129 changes: 23 additions & 106 deletions SampCert/Foundations/UniformP2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Mathlib.Data.Nat.Log
import SampCert.Util.Util
import SampCert.Foundations.Monad
import SampCert.Foundations.Auto
import SampCert.Foundations.UniformByte

/-!
# ``probUniformP2`` Properties
Expand All @@ -17,105 +18,30 @@ This file contains lemmas about ``probUniformP2``, a ``SLang`` sampler for the
uniform distribution on spaces whose size is a power of two.
-/


open Classical Nat PMF

namespace SLang

@[simp]
lemma sum_indicator_finrange_gen (n : Nat) (x : Nat) :
(x < n → (∑' (i : Fin n), @ite ENNReal (x = ↑i) (propDecidable (x = ↑i)) 1 0) = (1 : ENNReal))
∧ (x >= n → (∑' (i : Fin n), @ite ENNReal (x = ↑i) (propDecidable (x = ↑i)) 1 0) = (0 : ENNReal)) := by
revert x
induction n
. intro x
simp
. rename_i n IH
intro x
constructor
. intro cond
have OR : x = n ∨ x < n := by exact Order.lt_succ_iff_eq_or_lt.mp cond
cases OR
. rename_i cond'
have IH' := IH x
cases IH'
rename_i left right
have cond'' : x ≥ n := by exact Nat.le_of_eq (id cond'.symm)
have right' := right cond''
rw [tsum_fintype] at *
rw [Fin.sum_univ_castSucc]
simp [right']
simp [cond']
. rename_i cond'
have IH' := IH x
cases IH'
rename_i left right
have left' := left cond'
rw [tsum_fintype] at *
rw [Fin.sum_univ_castSucc]
simp [left']
have neq : x ≠ n := by exact Nat.ne_of_lt cond'
simp [neq]
. intro cond
have succ_gt : x ≥ n := by exact lt_succ.mp (le.step cond)
have IH' := IH x
cases IH'
rename_i left right
have right' := right succ_gt
rw [tsum_fintype]
rw [Fin.sum_univ_castSucc]
simp
constructor
. simp at right'
intro x'
apply right' x'
. have neq : x ≠ n := by exact Nat.ne_of_gt cond
simp [neq]


/--
Computes the sum of an indicator variable (indicating inside the support of ``Fin n``) over the space ``Fin n``.
-/
theorem sum_indicator_finrange (n : Nat) (x : Nat) (h : x < n) :
(∑' (i : Fin n), @ite ENNReal (x = ↑i) (propDecidable (x = ↑i)) 1 0) = (1 : ENNReal) := by
have H := sum_indicator_finrange_gen n x
cases H
rename_i left right
apply left
trivial

/--
Evaluates the ``probUniformP2`` distribution at a point inside of its support.
-/
@[simp]
theorem probUniformP2_apply (n : PNat) (x : Nat) (h : x < 2 ^ (log 2 n)) :
theorem UniformPowerOfTwoSample_apply (n : PNat) (x : Nat) (h : x < 2 ^ (log 2 n)) :
(UniformPowerOfTwoSample n) x = 1 / (2 ^ (log 2 n)) := by
simp only [UniformPowerOfTwoSample, Lean.Internal.coeM, Bind.bind, Pure.pure, CoeT.coe,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSLang_apply, PMF.bind_apply,
uniformOfFintype_apply, Fintype.card_fin, cast_pow, cast_ofNat, PMF.pure_apply, one_div]
rw [ENNReal.tsum_mul_left]
rw [sum_indicator_finrange (2 ^ (log 2 n)) x]
. simp
. trivial
simp [UniformPowerOfTwoSample]
rw [probUniformP2_eval_support]
· simp
trivial

/--
Evaluates the ``probUniformP2`` distribution at a point outside of its support
-/
@[simp]
theorem probUniformP2_apply' (n : PNat) (x : Nat) (h : x ≥ 2 ^ (log 2 n)) :
theorem UniformPowerOfTwoSample_apply' (n : PNat) (x : Nat) (h : x ≥ 2 ^ (log 2 n)) :
UniformPowerOfTwoSample n x = 0 := by
simp [UniformPowerOfTwoSample]
intro i
cases i
rename_i i P
simp only
have A : i < 2 ^ log 2 ↑n ↔ ¬ i ≥ 2 ^ log 2 ↑n := by exact lt_iff_not_le
rw [A] at P
simp at P
by_contra CONTRA
subst CONTRA
replace A := A.1 P
contradiction
rw [probUniformP2_eval_zero]
trivial

lemma if_simpl_up2 (n : PNat) (x x_1: Fin (2 ^ log 2 ↑n)) :
(@ite ENNReal (x_1 = x) (propDecidable (x_1 = x)) 0 (@ite ENNReal ((@Fin.val (2 ^ log 2 ↑n) x) = (@Fin.val (2 ^ log 2 ↑n) x_1)) (propDecidable ((@Fin.val (2 ^ log 2 ↑n) x) = (@Fin.val (2 ^ log 2 ↑n) x_1))) 1 0)) = 0 := by
Expand All @@ -135,32 +61,23 @@ lemma if_simpl_up2 (n : PNat) (x x_1: Fin (2 ^ log 2 ↑n)) :
/--
The ``SLang`` term ``uniformPowerOfTwo`` is a proper distribution on ``ℕ``.
-/
theorem probUniformP2_normalizes (n : PNat) :
theorem UniformPowerOfTwoSample_normalizes (n : PNat) :
∑' i : ℕ, UniformPowerOfTwoSample n i = 1 := by
rw [UniformPowerOfTwoSample]
rw [← @sum_add_tsum_nat_add' _ _ _ _ _ _ (2 ^ (log 2 n))]
. simp only [ge_iff_le, le_add_iff_nonneg_left, _root_.zero_le, probUniformP2_apply',
tsum_zero, add_zero]
simp only [UniformPowerOfTwoSample, Lean.Internal.coeM, Bind.bind, Pure.pure, CoeT.coe,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSLang_apply, PMF.bind_apply,
uniformOfFintype_apply, Fintype.card_fin, cast_pow, cast_ofNat, PMF.pure_apply]
rw [Finset.sum_range]
· rw [Finset.sum_range]
conv =>
left
right
intro x
rw [ENNReal.tsum_mul_left]
rw [ENNReal.tsum_eq_add_tsum_ite x]
right
right
right
intro x_1
rw [if_simpl_up2]
enter [1]
congr
· enter [2, a]
skip
rw [probUniformP2_eval_support (by exact a.isLt)]
· enter [1, a]
rw [probUniformP2_eval_zero (by exact Nat.le_add_left (2 ^ log 2 ↑n) a)]
simp
rw [ENNReal.inv_pow]
rw [← mul_pow]
rw [two_mul]
rw [ENNReal.inv_two_add_inv_two]
rw [one_pow]
. exact ENNReal.summable
apply ENNReal.mul_inv_cancel
· simp
· simp
exact ENNReal.summable

end SLang
22 changes: 20 additions & 2 deletions SampCert/SLang.lean
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,31 @@ Uniform distribution on a byte
@[extern "prob_UniformByte"]
def probUniformByte : SLang UInt8 := (fun _ => 1 / UInt8.size)

/--
Upper i bits from a unifomly sampled byte
-/
def probUniformByteUpperBits (i : ℕ) : SLang ℕ := do
let w <- probUniformByte
return w.toNat.shiftRight (8 - i)

/--
Uniform distribution on the set [0, 2^i) ⊆ ℕ
-/
def probUniformP2 (i : ℕ) : SLang ℕ :=
if (i < 8)
then probUniformByteUpperBits i
else do
let v <- probUniformByte
let w <- probUniformP2 (i - 8)
return UInt8.size * w + v.toNat

/--
``SLang`` value for the uniform distribution over ``m`` elements, where
the number``m`` is the largest power of two that is at most ``n``.
-/
@[extern "prob_UniformP2"]
def UniformPowerOfTwoSample (n : ℕ+) : SLang ℕ :=
toSLang (PMF.uniformOfFintype (Fin (2 ^ (log 2 n))))
probUniformP2 (log 2 n)


/--
``SLang`` functional which executes ``body`` only when ``cond`` is ``false``.
Expand Down
66 changes: 1 addition & 65 deletions SampCert/Samplers/Laplace/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan
-/
import SampCert.Util.Util
import SampCert.Foundations.Basic
import SampCert.Samplers.Uniform.Basic
import SampCert.Samplers.Bernoulli.Basic
Expand Down Expand Up @@ -1075,24 +1076,6 @@ lemma partial_geometric_series {p : ENNReal} (HP2 : p < 1) (B : ℕ) :
rfl


lemma nat_div_eq_le_lt_iff {a b c : ℕ} (Hc : 0 < c) : a = b / c <-> (a * c ≤ b ∧ b < (a + 1) * c) := by
apply Iff.intro
· intro H
apply And.intro
· apply (Nat.le_div_iff_mul_le Hc).mp
exact Nat.le_of_eq H
· apply (Nat.div_lt_iff_lt_mul Hc).mp
apply Nat.lt_succ_iff.mpr
exact Nat.le_of_eq (id (Eq.symm H))
· intro ⟨ H1, H2 ⟩
apply LE.le.antisymm
· apply (Nat.le_div_iff_mul_le Hc).mpr
apply H1
· apply Nat.lt_succ_iff.mp
simp
apply (Nat.div_lt_iff_lt_mul Hc).mpr
apply H2

/--
Integer division of a geometric distribution is a geometric distribution
-/
Expand Down Expand Up @@ -1287,53 +1270,6 @@ lemma geo_div_geo (k n : ℕ) (p : ENNReal) (Hp : p < 1) (Hn : 0 < n) :
exact succ_mul k n


/--
Specialize Euclidean division from ℤ to ℕ
-/
lemma euclidean_division (n : ℕ) {D : ℕ} (HD : 0 < D) :
∃ q r : ℕ, (r < D) ∧ n = r + D * q := by
exists (n / D)
exists (n % D)
apply And.intro
· exact mod_lt n HD
· apply ((@Nat.cast_inj ℤ).mp)
simp
conv =>
lhs
rw [<- EuclideanDomain.mod_add_div (n : ℤ) (D : ℤ)]

/--
Euclidiean division is unique
-/
lemma euclidean_division_uniquness (r1 r2 q1 q2 : ℕ) {D : ℕ} (HD : 0 < D) (Hr1 : r1 < D) (Hr2 : r2 < D) :
r1 + D * q1 = r2 + D * q2 <-> (r1 = r2 ∧ q1 = q2) := by
apply Iff.intro
· intro H
cases (Classical.em (r1 = r2))
· aesop
cases (Classical.em (q1 = q2))
· aesop
rename_i Hne1 Hne2
exfalso

have Contra1 (W X Y Z : ℕ) (HY : Y < D) (HK : W < X) : (Y + D * W < Z + D * X) := by
suffices (D * W < D * X) by
have A : (1 + W ≤ X) := by exact one_add_le_iff.mpr HK
have _ : (D * (1 + W) ≤ D * X) := by exact Nat.mul_le_mul_left D A
have _ : (D + D * W ≤ D * X) := by linarith
have _ : (Y + D * W < D * X) := by linarith
have _ : (Y + D * W < Z + D * X) := by linarith
assumption
exact Nat.mul_lt_mul_of_pos_left HK HD

rcases (lt_trichotomy q1 q2) with HK' | ⟨ HK' | HK' ⟩
· exact (LT.lt.ne (Contra1 q1 q2 r1 r2 Hr1 HK') H)
· exact Hne2 HK'
· apply (LT.lt.ne (Contra1 q2 q1 r2 r1 Hr2 HK') (Eq.symm H))

· intro ⟨ _, _ ⟩
simp_all

/--
Equivalence between sampling loops
-/
Expand Down
6 changes: 3 additions & 3 deletions SampCert/Samplers/Uniform/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ lemma rw_ite (n : PNat) (x : Nat) :
(if x < n then (UniformPowerOfTwoSample (2 * n)) x else 0)
= if x < n then 1 / 2 ^ log 2 ((2 : PNat) * n) else 0 := by
split
rw [probUniformP2_apply]
rw [UniformPowerOfTwoSample_apply]
simp only [PNat.mul_coe, one_div]
apply double_large_enough
trivial
Expand All @@ -100,7 +100,7 @@ lemma uniformPowerOfTwoSample_autopilot (n : PNat) :
= ∑' (i : ℕ), if i < ↑n then UniformPowerOfTwoSample (2 * n) i else 0 := by
have X : (∑' (i : ℕ), if decide (↑n ≤ i) = true then UniformPowerOfTwoSample (2 * n) i else 0) +
(∑' (i : ℕ), if decide (↑n ≤ i) = false then UniformPowerOfTwoSample (2 * n) i else 0) = 1 := by
have A := probUniformP2_normalizes (2 * n)
have A := UniformPowerOfTwoSample_normalizes (2 * n)
have B := @tsum_add_tsum_compl ENNReal ℕ _ _ (fun i => UniformPowerOfTwoSample (2 * n) i) _ _ { i : ℕ | decide (↑n ≤ i) = true} ENNReal.summable ENNReal.summable
rw [A] at B
clear A
Expand All @@ -114,7 +114,7 @@ lemma uniformPowerOfTwoSample_autopilot (n : PNat) :
trivial
apply ENNReal.sub_eq_of_eq_add_rev
. have Y := tsum_split_less (fun i => ↑n ≤ i) (fun i => UniformPowerOfTwoSample (2 * n) i)
rw [probUniformP2_normalizes (2 * n)] at Y
rw [UniformPowerOfTwoSample_normalizes (2 * n)] at Y
simp at Y
clear X
by_contra
Expand Down
66 changes: 66 additions & 0 deletions SampCert/Util/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,69 @@ theorem tsum_shift'_2 (f : ℕ → ENNReal) :
right
rw [sum_range_succ]
rw [← IH]

/--
Specialize Euclidean division from ℤ to ℕ
-/
lemma euclidean_division (n : ℕ) {D : ℕ} (HD : 0 < D) :
∃ q r : ℕ, (r < D) ∧ n = r + D * q := by
exists (n / D)
exists (n % D)
apply And.intro
· exact mod_lt n HD
· apply ((@Nat.cast_inj ℤ).mp)
simp
conv =>
lhs
rw [<- EuclideanDomain.mod_add_div (n : ℤ) (D : ℤ)]

/--
Euclidiean division is unique
-/
lemma euclidean_division_uniquness (r1 r2 q1 q2 : ℕ) {D : ℕ} (HD : 0 < D) (Hr1 : r1 < D) (Hr2 : r2 < D) :
r1 + D * q1 = r2 + D * q2 <-> (r1 = r2 ∧ q1 = q2) := by
apply Iff.intro
· intro H
cases (Classical.em (r1 = r2))
· aesop
cases (Classical.em (q1 = q2))
· aesop
rename_i Hne1 Hne2
exfalso

have Contra1 (W X Y Z : ℕ) (HY : Y < D) (HK : W < X) : (Y + D * W < Z + D * X) := by
suffices (D * W < D * X) by
have A : (1 + W ≤ X) := by exact one_add_le_iff.mpr HK
have _ : (D * (1 + W) ≤ D * X) := by exact Nat.mul_le_mul_left D A
have _ : (D + D * W ≤ D * X) := by linarith
have _ : (Y + D * W < D * X) := by linarith
have _ : (Y + D * W < Z + D * X) := by linarith
assumption
exact Nat.mul_lt_mul_of_pos_left HK HD

rcases (lt_trichotomy q1 q2) with HK' | ⟨ HK' | HK' ⟩
· exact (LT.lt.ne (Contra1 q1 q2 r1 r2 Hr1 HK') H)
· exact Hne2 HK'
· apply (LT.lt.ne (Contra1 q2 q1 r2 r1 Hr2 HK') (Eq.symm H))

· intro ⟨ _, _ ⟩
simp_all


lemma nat_div_eq_le_lt_iff {a b c : ℕ} (Hc : 0 < c) : a = b / c <-> (a * c ≤ b ∧ b < (a + 1) * c) := by
apply Iff.intro
· intro H
apply And.intro
· apply (Nat.le_div_iff_mul_le Hc).mp
exact Nat.le_of_eq H
· apply (Nat.div_lt_iff_lt_mul Hc).mp
apply Nat.lt_succ_iff.mpr
exact Nat.le_of_eq (id (Eq.symm H))
· intro ⟨ H1, H2 ⟩
apply LE.le.antisymm
· apply (Nat.le_div_iff_mul_le Hc).mpr
apply H1
· apply Nat.lt_succ_iff.mp
simp
apply (Nat.div_lt_iff_lt_mul Hc).mpr
apply H2
Loading