Skip to content

Commit

Permalink
enforce certificate correctness in TBSCertificate.SignWith (#1266)
Browse files Browse the repository at this point in the history
* enforce certificate correctness in TBSCertificate.SignWith

* check length, not nil

* Address review comments

* github hates me

---------

Co-authored-by: Nate Brown <[email protected]>
Co-authored-by: Jack Doan <[email protected]>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent 3f31517 commit 8704047
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 21 deletions.
1 change: 1 addition & 0 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, cu
var err error

switch v {
// Implementations must ensure the result is a valid cert!
case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey)
case Version2:
Expand Down
57 changes: 57 additions & 0 deletions cert/cert_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,58 @@ func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
issuer: t.issuer,
}

return c.validate()
}

func (c *certificateV1) validate() error {
// Empty names are allowed

if len(c.details.publicKey) == 0 {
return ErrInvalidPublicKey
}

// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
// Continue to allow this behavior
if !c.details.isCA && len(c.details.networks) == 0 {
return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
}

for _, network := range c.details.networks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid network: %s", network)
}

if network.Addr().Is6() {
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
}

if network.Addr().IsUnspecified() {
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
}

if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
}
}

for _, network := range c.details.unsafeNetworks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
}

if network.Addr().Is6() {
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
}

if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
}
}

// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
// unsafe networks would result in a different signature.

return nil
}

Expand Down Expand Up @@ -404,6 +456,11 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error)
}
}

err = nc.validate()
if err != nil {
return nil, err
}

return &nc, nil
}

Expand Down
89 changes: 82 additions & 7 deletions cert/cert_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ type certificateV2 struct {

type detailsV2 struct {
name string
networks []netip.Prefix
unsafeNetworks []netip.Prefix
networks []netip.Prefix // MUST BE SORTED
unsafeNetworks []netip.Prefix // MUST BE SORTED
groups []string
isCA bool
notBefore time.Time
Expand Down Expand Up @@ -376,6 +376,77 @@ func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
}
c.curve = t.Curve
c.publicKey = t.PublicKey
return c.validate()
}

func (c *certificateV2) validate() error {
// Empty names are allowed

if len(c.publicKey) == 0 {
return ErrInvalidPublicKey
}

if !c.details.isCA && len(c.details.networks) == 0 {
return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
}

hasV4Networks := false
hasV6Networks := false
for _, network := range c.details.networks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid network: %s", network)
}

if network.Addr().IsUnspecified() {
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
}

if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
}

if network.Addr().Is4In6() {
return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
}

hasV4Networks = hasV4Networks || network.Addr().Is4()
hasV6Networks = hasV6Networks || network.Addr().Is6()
}

slices.SortFunc(c.details.networks, comparePrefix)
err := findDuplicatePrefix(c.details.networks)
if err != nil {
return err
}

for _, network := range c.details.unsafeNetworks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
}

if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
}

if !c.details.isCA {
if network.Addr().Is6() {
if !hasV6Networks {
return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
}
} else if network.Addr().Is4() {
if !hasV4Networks {
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
}
}
}
}

slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
err = findDuplicatePrefix(c.details.unsafeNetworks)
if err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -536,13 +607,20 @@ func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certifica
return nil, err
}

return &certificateV2{
c := &certificateV2{
details: details,
rawDetails: rawDetails,
curve: curve,
publicKey: rawPublicKey,
signature: rawSignature,
}, nil
}

err = c.validate()
if err != nil {
return nil, err
}

return c, nil
}

func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
Expand Down Expand Up @@ -639,9 +717,6 @@ func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
return detailsV2{}, ErrBadFormat
}

slices.SortFunc(networks, comparePrefix)
slices.SortFunc(unsafeNetworks, comparePrefix)

return detailsV2{
name: string(name),
networks: networks,
Expand Down
16 changes: 14 additions & 2 deletions cert/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cert

import (
"errors"
"fmt"
)

var (
Expand All @@ -17,10 +18,9 @@ var (
ErrInvalidPrivateKey = errors.New("invalid private key")
ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")

ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")

ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner")
Expand All @@ -35,3 +35,15 @@ var (
ErrEmptySignature = errors.New("empty signature")
ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
)

type ErrInvalidCertificateProperties struct {
str string
}

func NewErrInvalidCertificateProperties(format string, a ...any) error {
return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
}

func (e *ErrInvalidCertificateProperties) Error() string {
return e.str
}
4 changes: 4 additions & 0 deletions cert/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string
after = time.Now().Add(time.Second * 60).Round(time.Second)
}

if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}

var pub, priv []byte
switch curve {
case Curve_CURVE25519:
Expand Down
1 change: 1 addition & 0 deletions cert/pem.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
var err error

switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
Expand Down
21 changes: 14 additions & 7 deletions cert/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"fmt"
"math/big"
"net/netip"
"slices"
"time"
)

Expand All @@ -31,6 +30,7 @@ type TBSCertificate struct {

type beingSignedCertificate interface {
// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
// Implementations must validate the resulting certificate contains valid information
fromTBSCertificate(*TBSCertificate) error

// marshalForSigning returns the bytes that should be signed
Expand Down Expand Up @@ -83,9 +83,6 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
return nil, fmt.Errorf("curve in cert and private key supplied don't match")
}

//TODO: make sure we have all minimum properties to sign, like a public key
//TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs

if signer != nil {
if t.IsCA {
return nil, fmt.Errorf("can not sign a CA certificate with another")
Expand All @@ -107,9 +104,6 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
}
}

slices.SortFunc(t.Networks, comparePrefix)
slices.SortFunc(t.UnsafeNetworks, comparePrefix)

var c beingSignedCertificate
switch t.Version {
case Version1:
Expand Down Expand Up @@ -158,3 +152,16 @@ func comparePrefix(a, b netip.Prefix) int {
}
return addr
}

// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
if len(sortedPrefixes) < 2 {
return nil
}
for i := 1; i < len(sortedPrefixes); i++ {
if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
}
}
return nil
}
20 changes: 15 additions & 5 deletions cmd/nebula-cert/print_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func Test_printCert(t *testing.T) {
tf.Truncate(0)
tf.Seek(0, 0)
ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"})
c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})

p, _ := c.MarshalPEM()
tf.Write(p)
Expand All @@ -97,7 +97,9 @@ func Test_printCert(t *testing.T) {
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
Expand All @@ -116,7 +118,9 @@ func Test_printCert(t *testing.T) {
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
Expand All @@ -135,7 +139,9 @@ func Test_printCert(t *testing.T) {
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
Expand Down Expand Up @@ -166,7 +172,7 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err)
assert.Equal(
t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
Expand Down Expand Up @@ -212,6 +218,10 @@ func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, aft
after = ca.NotAfter()
}

if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}

pub, rawPriv := x25519Keypair()
nc := &cert.TBSCertificate{
Version: cert.Version1,
Expand Down

0 comments on commit 8704047

Please sign in to comment.