Skip to content

Commit

Permalink
Fixes arguments in functions of the protocol.
Browse files Browse the repository at this point in the history
  • Loading branch information
armfazh committed Sep 14, 2022
1 parent 284af11 commit 95cb829
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 100 deletions.
149 changes: 77 additions & 72 deletions poc/oprf.sage
Original file line number Diff line number Diff line change
Expand Up @@ -41,56 +41,58 @@ class OPRFClientContext(Context):
def identifier(self):
return self.identifier

def blind(self, x, rng):
def blind(self, input, rng):
blind = ZZ(self.suite.group.random_scalar(rng))
input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag())
input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag())
if input_element == self.suite.group.identity():
raise Exception("InvalidInputError")
blinded_element = blind * input_element
return blind, blinded_element

def unblind(self, blind, evaluated_element, blinded_element, proof):
def unblind(self, blind, evaluated_element):
blind_inv = inverse_mod(blind, self.suite.group.order())
N = blind_inv * evaluated_element
unblinded_element = self.suite.group.serialize(N)
return unblinded_element

def finalize(self, x, blind, evaluated_element, blinded_element, proof, info):
unblinded_element = self.unblind(blind, evaluated_element, blinded_element, proof)
finalize_input = I2OSP(len(x), 2) + x \
def finalize(self, input, blind, evaluated_element):
unblinded_element = self.unblind(blind, evaluated_element)
finalize_input = I2OSP(len(input), 2) + input \
+ I2OSP(len(unblinded_element), 2) + unblinded_element \
+ _as_bytes("Finalize")

return self.suite.hash(finalize_input)

class OPRFServerContext(Context):
def __init__(self, version, mode, suite, skS, pkS):
def __init__(self, version, mode, suite, skS):
Context.__init__(self, version, mode, suite)
self.skS = skS
self.pkS = pkS

def internal_evaluate(self, blinded_element):
evaluated_element = self.skS * blinded_element
return evaluated_element

def blind_evaluate(self, blinded_element, info, rng):
def blind_evaluate(self, blinded_element, rng):
evaluated_element = self.internal_evaluate(blinded_element)
return evaluated_element, None, None
return evaluated_element

def evaluate_without_proof(self, blinded_element, info):
def evaluate_without_proof(self, blinded_element):
return self.internal_evaluate(blinded_element)

def evaluate(self, x, info):
input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag())
def evaluate(self, input, expected_output):
input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag())
if input_element == self.suite.group.identity():
raise Exception("InvalidInputError")
evaluated_element = self.internal_evaluate(input_element)
issued_element = self.suite.group.serialize(evaluated_element)
finalize_input = I2OSP(len(x), 2) + x \

finalize_input = I2OSP(len(input), 2) + input \
+ I2OSP(len(issued_element), 2) + issued_element \
+ _as_bytes("Finalize")

return self.suite.hash(finalize_input)
digest = self.suite.hash(finalize_input)

return (digest == expected_output)

class Verifiable(object):
def compute_composites_inner(self, k, B, Cs, Ds):
Expand Down Expand Up @@ -138,12 +140,12 @@ class VOPRFClientContext(OPRFClientContext,Verifiable):
self.pkS = pkS

def verify_proof(self, A, B, Cs, Ds, proof):
a = self.compute_composites(B, Cs, Ds)
[M, Z] = self.compute_composites(B, Cs, Ds)
c = proof[0]
s = proof[1]

M = a[0]
Z = a[1]
t2 = (proof[1] * A) + (proof[0] * B)
t3 = (proof[1] * M) + (proof[0] * Z)
t2 = (s * A) + (c * B)
t3 = (s * M) + (c * Z)

Bm = self.suite.group.serialize(B)
a0 = self.suite.group.serialize(M)
Expand All @@ -158,10 +160,10 @@ class VOPRFClientContext(OPRFClientContext,Verifiable):
+ I2OSP(len(a3), 2) + a3 \
+ _as_bytes("Challenge")

c = self.suite.group.hash_to_scalar(h2s_input, self.scalar_domain_separation_tag())
expectedC = self.suite.group.hash_to_scalar(h2s_input, self.scalar_domain_separation_tag())

assert(c == proof[0])
return c == proof[0]
assert(expectedC == c)
return expectedC == c

def unblind(self, blind, evaluated_element, blinded_element, proof):
G = self.suite.group.generator()
Expand Down Expand Up @@ -190,41 +192,41 @@ class VOPRFClientContext(OPRFClientContext,Verifiable):

return unblinded_elements

def finalize(self, x, blind, evaluated_element, blinded_element, proof, info):
def finalize(self, input, blind, evaluated_element, blinded_element, proof):
unblinded_element = self.unblind(blind, evaluated_element, blinded_element, proof)
finalize_input = I2OSP(len(x), 2) + x \
finalize_input = I2OSP(len(input), 2) + input \
+ I2OSP(len(unblinded_element), 2) + unblinded_element \
+ _as_bytes("Finalize")

return self.suite.hash(finalize_input)

def finalize_batch(self, xs, blinds, evaluated_elements, blinded_elements, proof, info):
def finalize_batch(self, inputs, blinds, evaluated_elements, blinded_elements, proof):
assert(len(inputs) == len(blinds))
assert(len(blinds) == len(evaluated_elements))
assert(len(evaluated_elements) == len(blinded_elements))

unblinded_elements = self.unblind_batch(blinds, evaluated_elements, blinded_elements, proof)

outputs = []
for i, unblinded_element in enumerate(unblinded_elements):
finalize_input = I2OSP(len(xs[i]), 2) + xs[i] \
finalize_input = I2OSP(len(inputs[i]), 2) + inputs[i] \
+ I2OSP(len(unblinded_element), 2) + unblinded_element \
+ _as_bytes("Finalize")

digest = self.suite.hash(finalize_input)
outputs.append(digest)
output = self.suite.hash(finalize_input)
outputs.append(output)

return outputs

class VOPRFServerContext(OPRFServerContext,Verifiable):
def __init__(self, version, mode, suite, skS, pkS):
OPRFServerContext.__init__(self, version, mode, suite, skS, pkS)
OPRFServerContext.__init__(self, version, mode, suite, skS)
self.pkS = pkS

def generate_proof(self, k, A, B, Cs, Ds, rng):
a = self.compute_composites_fast(k, B, Cs, Ds)
[M, Z] = self.compute_composites_fast(k, B, Cs, Ds)

r = ZZ(self.suite.group.random_scalar(rng))
M = a[0]
Z = a[1]
t2 = r * A
t3 = r * M

Expand All @@ -250,15 +252,15 @@ class VOPRFServerContext(OPRFServerContext,Verifiable):
evaluated_element = self.skS * blinded_element
return evaluated_element

def blind_evaluate(self, blinded_element, info, rng):
def blind_evaluate(self, blinded_element, rng):
evaluated_element = self.internal_evaluate(blinded_element)
proof, r = self.generate_proof(self.skS, self.suite.group.generator(), self.pkS, [blinded_element], [evaluated_element], rng)
return evaluated_element, proof, r

def evaluate_without_proof(self, blinded_element, info):
def evaluate_without_proof(self, blinded_element):
return self.internal_evaluate(blinded_element)

def blind_evaluate_batch(self, blinded_elements, info, rng):
def blind_evaluate_batch(self, blinded_elements, rng):
evaluated_elements = []
for blinded_element in blinded_elements:
evaluated_element = self.skS * blinded_element
Expand All @@ -272,16 +274,16 @@ class POPRFClientContext(VOPRFClientContext):
VOPRFClientContext.__init__(self, version, mode, suite, pkS)
self.pkS = pkS

def blind(self, x, info, rng):
context = _as_bytes("Info") + I2OSP(len(info), 2) + info
t = self.suite.group.hash_to_scalar(context, self.scalar_domain_separation_tag())
G = self.suite.group.generator()
tweaked_key = (G * t) + self.pkS
def blind(self, input, info, rng):
framed_info = _as_bytes("Info") + I2OSP(len(info), 2) + info
m = self.suite.group.hash_to_scalar(framed_info, self.scalar_domain_separation_tag())
T = m * self.suite.group.generator()
tweaked_key = T + self.pkS
if tweaked_key == self.suite.group.identity():
raise Exception("InvalidInputError")

blind = ZZ(self.suite.group.random_scalar(rng))
input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag())
input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag())
if input_element == self.suite.group.identity():
raise Exception("InvalidInputError")

Expand Down Expand Up @@ -315,30 +317,31 @@ class POPRFClientContext(VOPRFClientContext):

return unblinded_elements

def finalize(self, x, blind, evaluated_element, blinded_element, proof, info, tweaked_key):
def finalize(self, input, blind, evaluated_element, blinded_element, proof, info, tweaked_key):
unblinded_element = self.unblind(blind, evaluated_element, blinded_element, proof, tweaked_key)
finalize_input = I2OSP(len(x), 2) + x \
finalize_input = I2OSP(len(input), 2) + input \
+ I2OSP(len(info), 2) + info \
+ I2OSP(len(unblinded_element), 2) + unblinded_element \
+ _as_bytes("Finalize")

return self.suite.hash(finalize_input)

def finalize_batch(self, xs, blinds, evaluated_elements, blinded_elements, proof, info, tweaked_key):
def finalize_batch(self, inputs, blinds, evaluated_elements, blinded_elements, proof, info, tweaked_key):
assert(len(inputs) == len(blinds))
assert(len(blinds) == len(evaluated_elements))
assert(len(evaluated_elements) == len(blinded_elements))

unblinded_elements = self.unblind_batch(blinds, evaluated_elements, blinded_elements, proof, tweaked_key)

outputs = []
for i, unblinded_element in enumerate(unblinded_elements):
finalize_input = I2OSP(len(xs[i]), 2) + xs[i] \
finalize_input = I2OSP(len(inputs[i]), 2) + inputs[i] \
+ I2OSP(len(info), 2) + info \
+ I2OSP(len(unblinded_element), 2) + unblinded_element \
+ _as_bytes("Finalize")

digest = self.suite.hash(finalize_input)
outputs.append(digest)
output = self.suite.hash(finalize_input)
outputs.append(output)

return outputs

Expand All @@ -347,56 +350,58 @@ class POPRFServerContext(VOPRFServerContext):
VOPRFServerContext.__init__(self, version, mode, suite, skS, pkS)

def internal_evaluate(self, blinded_element, info):
context = _as_bytes("Info") + I2OSP(len(info), 2) + info
t = self.suite.group.hash_to_scalar(context, self.scalar_domain_separation_tag())
k = self.skS + t
if int(k) == 0:
framed_info = _as_bytes("Info") + I2OSP(len(info), 2) + info
m = self.suite.group.hash_to_scalar(framed_info, self.scalar_domain_separation_tag())
t = (self.skS + m) % self.suite.group.order()
if int(t) == 0:
raise Exception("InverseError")
k_inv = inverse_mod(k, self.suite.group.order())
evaluated_element = k_inv * blinded_element
t_inv = inverse_mod(t, self.suite.group.order())
evaluated_element = t_inv * blinded_element

return evaluated_element, k
return evaluated_element, t

def blind_evaluate(self, blinded_element, info, rng):
evaluated_element, k = self.internal_evaluate(blinded_element, info)
evaluated_element, t = self.internal_evaluate(blinded_element, info)
G = self.suite.group.generator()
U = k * G
proof, r = self.generate_proof(k, G, U, [evaluated_element], [blinded_element], rng)
tweaked_key = t * G
proof, r = self.generate_proof(t, G, tweaked_key, [evaluated_element], [blinded_element], rng)
return evaluated_element, proof, r

def evaluate_without_proof(self, blinded_element, info):
evaluated_element, _ = self.internal_evaluate(blinded_element, info)
return evaluated_element

def blind_evaluate_batch(self, blinded_elements, info, rng):
context = _as_bytes("Info") + I2OSP(len(info), 2) + info
t = self.suite.group.hash_to_scalar(context, self.scalar_domain_separation_tag())
framed_info = _as_bytes("Info") + I2OSP(len(info), 2) + info
m = self.suite.group.hash_to_scalar(framed_info, self.scalar_domain_separation_tag())

evaluated_elements = []
for blinded_element in blinded_elements:
k = self.skS + t
if int(k) == 0:
t = (self.skS + m) % self.suite.group.order()
if int(t) == 0:
raise Exception("InverseError")
k_inv = inverse_mod(k, self.suite.group.order())
evaluated_element = k_inv * blinded_element
t_inv = inverse_mod(t, self.suite.group.order())
evaluated_element = t_inv * blinded_element
evaluated_elements.append(evaluated_element)

G = self.suite.group.generator()
U = k * G
proof, r = self.generate_proof(k, G, U, evaluated_elements, blinded_elements, rng)
tweaked_key = t * G
proof, r = self.generate_proof(t, G, tweaked_key, evaluated_elements, blinded_elements, rng)
return evaluated_elements, proof, r

def evaluate(self, x, info):
input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag())
def evaluate(self, input, expected_output, info):
input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag())
evaluated_element = self.evaluate_without_proof(input_element, info)
issued_element = self.suite.group.serialize(evaluated_element)

finalize_input = I2OSP(len(x), 2) + x \
finalize_input = I2OSP(len(input), 2) + input \
+ I2OSP(len(info), 2) + info \
+ I2OSP(len(issued_element), 2) + issued_element \
+ _as_bytes("Finalize")

return self.suite.hash(finalize_input)
output = self.suite.hash(finalize_input)

return (output == expected_output)

MODE_OPRF = 0x00
MODE_VOPRF = 0x01
Expand All @@ -419,7 +424,7 @@ def DeriveKeyPair(mode, suite, seed, info):
return skS, pkS

def SetupOPRFServer(suite, skS):
return OPRFServerContext(VERSION, MODE_OPRF, suite, skS, None)
return OPRFServerContext(VERSION, MODE_OPRF, suite, skS)

def SetupOPRFClient(suite):
return OPRFClientContext(VERSION, MODE_OPRF, suite)
Expand All @@ -446,7 +451,7 @@ ciphersuite_p521_sha512 = 0x0005

oprf_ciphersuites = {
ciphersuite_ristretto255_sha512: Ciphersuite("OPRF(ristretto255, SHA-512)", ciphersuite_ristretto255_sha512, GroupRistretto255(), hashlib.sha512, lambda x : hashlib.sha512(x).digest()),
ciphersuite_decaf448_shake256: Ciphersuite("OPRF(decaf448, SHAKE256)", ciphersuite_decaf448_shake256, GroupDecaf448(), hashlib.shake_256, lambda x : hashlib.shake_256(x).digest(int(64))),
ciphersuite_decaf448_shake256: Ciphersuite("OPRF(decaf448, SHAKE-256)", ciphersuite_decaf448_shake256, GroupDecaf448(), hashlib.shake_256, lambda x : hashlib.shake_256(x).digest(int(64))),
ciphersuite_p256_sha256: Ciphersuite("OPRF(P-256, SHA-256)", ciphersuite_p256_sha256, GroupP256(), hashlib.sha256, lambda x : hashlib.sha256(x).digest()),
ciphersuite_p384_sha384: Ciphersuite("OPRF(P-384, SHA-384)", ciphersuite_p384_sha384, GroupP384(), hashlib.sha384, lambda x : hashlib.sha384(x).digest()),
ciphersuite_p521_sha512: Ciphersuite("OPRF(P-521, SHA-512)", ciphersuite_p521_sha512, GroupP521(), hashlib.sha512, lambda x : hashlib.sha512(x).digest()),
Expand Down
Loading

0 comments on commit 95cb829

Please sign in to comment.