Skip to content

Commit

Permalink
Merge pull request #285 from getamis/supportTaproot-3
Browse files Browse the repository at this point in the history
Support Taproot Schnorr Signature
  • Loading branch information
markya0616 authored Nov 22, 2023
2 parents 5308d42 + b78a9b3 commit 09a854b
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 67 deletions.
7 changes: 7 additions & 0 deletions crypto/ecpointgrouplaw/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ func NewBase(curve elliptic.Curve) *ECPoint {
}
}

func (p *ECPoint) IsEvenY() bool {
if p.IsIdentity() {
return true
}
return new(big.Int).And(p.y, big1).Cmp(big1) != 0
}

// IsIdentity checks if the point is the identity element.
func (p *ECPoint) IsIdentity() bool {
return isIdentity(p.x, p.y)
Expand Down
179 changes: 133 additions & 46 deletions crypto/tss/eddsa/frost/signer/round_1.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package signer

import (
"crypto/sha256"
"crypto/sha512"
"errors"
"math/big"
Expand All @@ -24,6 +25,7 @@ import (
"github.com/getamis/alice/crypto/birkhoffinterpolation"
"github.com/getamis/alice/crypto/commitment"
"github.com/getamis/alice/crypto/ecpointgrouplaw"
"github.com/getamis/alice/crypto/elliptic"
"github.com/getamis/alice/crypto/homo"
"github.com/getamis/alice/crypto/tss/dkg"
"github.com/getamis/alice/crypto/tss/ecdsa/cggmp"
Expand All @@ -40,9 +42,8 @@ const (
)

var (
bit254 = new(big.Int).Lsh(big.NewInt(1), 253)
big0 = big.NewInt(0)
big1 = big.NewInt(1)
big0 = big.NewInt(0)
big1 = big.NewInt(1)

//ErrExceedMaxRetry is returned if we retried over times
ErrExceedMaxRetry = errors.New("exceed max retries")
Expand All @@ -54,6 +55,12 @@ var (
ErrTrivialSignature = errors.New("obtain trivial signature")
//ErrTrivialShaResult is returned if the output of SHAPoint is trivial.
ErrTrivialShaResult = errors.New("the output of SHAPoint is trivial")
//ErrNotSupportCurve is returned if the curve is not support.
ErrNotSupportCurve = errors.New("if the curve is not support")
//ErrTrivialPoint is returned if the point is trivial.
ErrTrivialPoint = errors.New("the point is trivial")
//ErrNotCorrectMessage is returned if the message is not correct.
ErrNotCorrectMessage = errors.New("the message is not correct")
)

type pubkeyData struct {
Expand Down Expand Up @@ -245,20 +252,24 @@ func (p *round1) Finalize(logger log.Logger) (types.Handler, error) {
return nil, ErrTrivialSignature
}
p.c, err = SHAPoints(p.pubKey, R, p.message)

if err != nil {
return nil, err
}
p.r = R

// Compute own zi = di+ ei*li + c bi xi
selfNode := p.nodes[p.peerManager.SelfID()]
share := new(big.Int).Set(p.share)
p.d, p.e, share, err = computeDEShareTaproot(p.d, p.e, share, R, p.pubKey)
if err != nil {
return nil, err
}
z := new(big.Int).Mul(p.e, selfNode.ell)
temp := new(big.Int).Mul(p.c, selfNode.coBk)
temp = temp.Mul(temp, p.share)
temp = temp.Mul(temp, share)
z.Add(z, temp)
z.Add(z, p.d)
z.Mod(z, p.curveN)

// Broadcast round2 message
round2Msg := &Message{
Id: p.peerManager.SelfID(),
Expand Down Expand Up @@ -286,61 +297,136 @@ func getMessage(messsage types.Message) *Message {
return messsage.(*Message)
}

func SHAPoints(pubKey, R *ecpointgrouplaw.ECPoint, message []byte) (*big.Int, error) {
encodedR := ecpointEncoding(R)
encodedPubKey := ecpointEncoding(pubKey)
h := sha512.New()
h.Write(encodedR[:])
// Different curves for Schnorr signature has different rules.
func computeDEShareTaproot(d, e, s *big.Int, R, pubKey *ecpointgrouplaw.ECPoint) (*big.Int, *big.Int, *big.Int, error) {
curve := R.GetCurve()
switch curve {
// Taproot verification
case elliptic.Secp256k1():
if !R.IsEvenY() {
d, e = d.Sub(curve.Params().N, d), e.Sub(curve.Params().N, e)
}
if !pubKey.IsEvenY() {
s = s.Sub(curve.Params().N, s)
}
return d, e, s, nil
case elliptic.Ed25519():
return d, e, s, nil
}
return nil, nil, nil, ErrNotSupportCurve
}

h.Write(encodedPubKey[:])
h.Write(message)
digest := h.Sum(nil)
result := new(big.Int).SetBytes(utils.ReverseByte(digest))
result = result.Mod(result, R.GetCurve().Params().N)
if result.Cmp(big0) == 0 {
return nil, ErrTrivialShaResult
func SHAPoints(pubKey, R *ecpointgrouplaw.ECPoint, message []byte) (*big.Int, error) {
curveType := pubKey.GetCurve()
if R.IsIdentity() || pubKey.IsIdentity() {
return nil, ErrTrivialPoint
}
switch curveType {
case elliptic.Secp256k1():
// e = int(hashBIP0340/challenge(bytes(R) || bytes(P) || m)) mod n
if len(message) != 32 {
return nil, ErrNotCorrectMessage
}
hash := make([]byte, 0)
hash = append(hash, utils.Bytes32(R.GetX())...)
hash = append(hash, utils.Bytes32(pubKey.GetX())...)
hash = append(hash, utils.Pad(message, 32)...)

sha256Hash := sha256.Sum256([]byte("BIPSchnorr"))
sha256HashInput := sha256Hash[:]
sha256HashInput = append(sha256HashInput, sha256HashInput[:]...)
sha256HashInput = append(sha256HashInput, hash...)
digest := sha256.Sum256(sha256HashInput)
result := new(big.Int).SetBytes(digest[:])

result.Mod(result, R.GetCurve().Params().N)
if result.Cmp(big0) == 0 {
return nil, ErrTrivialShaResult
}
return result, nil

return result, nil
}
case elliptic.Ed25519():
encodedR, err := ecpointEncoding(R)
if err != nil {
return nil, err
}
encodedPubKey, err := ecpointEncoding(pubKey)
if err != nil {
return nil, err
}

h := sha512.New()
h.Write(encodedR[:])

func ecpointEncoding(pt *ecpointgrouplaw.ECPoint) *[32]byte {
var result, X, Y [32]byte
var x, y edwards25519.FieldElement
if pt.Equal(ecpointgrouplaw.NewIdentity(pt.GetCurve())) {
// TODO: We need to check this
Y[0] = 1
} else {
tempX := pt.GetX().Bytes()
tempY := pt.GetY().Bytes()

for i := 0; i < len(tempX); i++ {
index := len(tempX) - 1 - i
X[index] = tempX[i]
h.Write(encodedPubKey[:])
h.Write(message)
digest := h.Sum(nil)
result := new(big.Int).SetBytes(utils.ReverseByte(digest))
result = result.Mod(result, R.GetCurve().Params().N)
if result.Cmp(big0) == 0 {
return nil, ErrTrivialShaResult
}
for i := 0; i < len(tempY); i++ {
index := len(tempY) - 1 - i
Y[index] = tempY[i]

return result, nil
}
return nil, ErrNotSupportCurve
}

func ecpointEncoding(pt *ecpointgrouplaw.ECPoint) ([32]byte, error) {
curveType := pt.GetCurve()
nullSlice := [32]byte{}
if pt.IsIdentity() {
return nullSlice, ErrTrivialPoint
}
switch curveType {
case elliptic.Secp256k1():
return ([32]byte)(utils.Bytes32(pt.GetX())), nil
case elliptic.Ed25519():
var result, X, Y [32]byte
var x, y edwards25519.FieldElement
if pt.Equal(ecpointgrouplaw.NewIdentity(pt.GetCurve())) {
// TODO: We need to check this
Y[0] = 1
} else {
tempX := pt.GetX().Bytes()
tempY := pt.GetY().Bytes()

for i := 0; i < len(tempX); i++ {
index := len(tempX) - 1 - i
X[index] = tempX[i]
}
for i := 0; i < len(tempY); i++ {
index := len(tempY) - 1 - i
Y[index] = tempY[i]
}
}
edwards25519.FeFromBytes(&x, &X)
edwards25519.FeFromBytes(&y, &Y)
edwards25519.FeToBytes(&result, &y)
result[31] ^= edwards25519.FeIsNegative(&x) << 7
return result, nil
}
edwards25519.FeFromBytes(&x, &X)
edwards25519.FeFromBytes(&y, &Y)
edwards25519.FeToBytes(&result, &y)
result[31] ^= edwards25519.FeIsNegative(&x) << 7
return &result
return nullSlice, ErrNotSupportCurve
}

// Get xi,Di,Ei,.......
func computeB(x []byte, D, E *ecpointgrouplaw.ECPoint) ([]byte, error) {
if !D.IsSameCurve(E) {
return nil, ecpointgrouplaw.ErrDifferentCurve
}
encodingD := ecpointEncoding(D)[:]
encodingE := ecpointEncoding(E)[:]
encodingD, err := ecpointEncoding(D)
if err != nil {
return nil, err
}

encodingE, err := ecpointEncoding(E)
if err != nil {
return nil, err
}
bMsg := &BMessage{
X: x,
D: encodingD,
E: encodingE,
D: encodingD[:],
E: encodingE[:],
}
result, err := proto.Marshal(bMsg)
if err != nil {
Expand All @@ -367,8 +453,9 @@ func computeRhoElli(x []byte, E *ecpointgrouplaw.ECPoint, message []byte, B []by
if err != nil {
return nil, err
}
bitUppBd := new(big.Int).Lsh(big1, uint(E.GetCurve().Params().N.BitLen()))
for j := 0; j < maxRetry; j++ {
tempMod := new(big.Int).Mod(temp, bit254)
tempMod := new(big.Int).Mod(temp, bitUppBd)
if utils.InRange(tempMod, big1, fieldOrder) == nil {
return tempMod, nil
}
Expand Down
29 changes: 25 additions & 4 deletions crypto/tss/eddsa/frost/signer/round_2.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"math/big"

"github.com/getamis/alice/crypto/ecpointgrouplaw"
"github.com/getamis/alice/crypto/elliptic"
"github.com/getamis/alice/types"
"github.com/getamis/sirius/log"
)
Expand Down Expand Up @@ -65,19 +66,26 @@ func (p *round2) HandleMessage(logger log.Logger, message types.Message) error {
func (p *round2) Finalize(logger log.Logger) (types.Handler, error) {
z := big.NewInt(0)
G := ecpointgrouplaw.NewBase(p.pubKey.GetCurve())

isREvenY := isYEven(p.r)
isPubKeyEvenY := isYEven(p.pubKey)
for _, node := range p.nodes {
// Calculate S
msgBody := node.GetMessage(types.MessageType(Type_Round2)).(*Message).GetRound2()
node.zi = new(big.Int).SetBytes(msgBody.Zi)
z.Add(z, node.zi)

// Calculate S
ziG := G.ScalarMult(node.zi)
ri := node.ri
if !isREvenY {
ri = ri.Neg()
}
YCopy := node.Y.Copy()
if !isPubKeyEvenY {
YCopy = YCopy.Neg()
}

cbi := new(big.Int).Mul(node.coBk, p.c)
cbi.Mod(cbi, p.pubKey.GetCurve().Params().N)
comparePart, err := node.Y.ScalarMult(cbi).Add(ri)
comparePart, err := YCopy.ScalarMult(cbi).Add(ri)
if err != nil {
logger.Debug("Failed to ScalarMult", "err", err)
return nil, err
Expand All @@ -93,3 +101,16 @@ func (p *round2) Finalize(logger log.Logger) (types.Handler, error) {
}
return nil, nil
}

// Different curves for Schnorr signature has different rules.
func isYEven(R *ecpointgrouplaw.ECPoint) bool {
curve := R.GetCurve()
switch curve {
// Taproot verification
case elliptic.Secp256k1():
return R.IsEvenY()
case elliptic.Ed25519():
return true
}
return true
}
Loading

0 comments on commit 09a854b

Please sign in to comment.