diff --git a/cert/cert.go b/cert/cert.go index d11025843..73c88cd01 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -143,6 +143,7 @@ func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, cu var err error switch v { + // Implementations must ensure the result is a valid cert! case VersionPre1, Version1: c, err = unmarshalCertificateV1(b, publicKey) case Version2: diff --git a/cert/cert_v1.go b/cert/cert_v1.go index b807f8d21..6bb146fea 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -317,6 +317,58 @@ func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { issuer: t.issuer, } + return c.validate() +} + +func (c *certificateV1) validate() error { + // Empty names are allowed + + if len(c.details.publicKey) == 0 { + return ErrInvalidPublicKey + } + + // Original v1 rules allowed multiple networks to be present but ignored all but the first one. + // Continue to allow this behavior + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network") + } + + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + } + + // v1 doesn't bother with sort order or uniqueness of networks or unsafe networks. + // We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered + // unsafe networks would result in a different signature. + return nil } @@ -404,6 +456,11 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) } } + err = nc.validate() + if err != nil { + return nil, err + } + return &nc, nil } diff --git a/cert/cert_v2.go b/cert/cert_v2.go index dce929684..322463e99 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -65,8 +65,8 @@ type certificateV2 struct { type detailsV2 struct { name string - networks []netip.Prefix - unsafeNetworks []netip.Prefix + networks []netip.Prefix // MUST BE SORTED + unsafeNetworks []netip.Prefix // MUST BE SORTED groups []string isCA bool notBefore time.Time @@ -376,6 +376,77 @@ func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error { } c.curve = t.Curve c.publicKey = t.PublicKey + return c.validate() +} + +func (c *certificateV2) validate() error { + // Empty names are allowed + + if len(c.publicKey) == 0 { + return ErrInvalidPublicKey + } + + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network") + } + + hasV4Networks := false + hasV6Networks := false + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + + if network.Addr().Is4In6() { + return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network) + } + + hasV4Networks = hasV4Networks || network.Addr().Is4() + hasV6Networks = hasV6Networks || network.Addr().Is6() + } + + slices.SortFunc(c.details.networks, comparePrefix) + err := findDuplicatePrefix(c.details.networks) + if err != nil { + return err + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + + if !c.details.isCA { + if network.Addr().Is6() { + if !hasV6Networks { + return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network) + } + } else if network.Addr().Is4() { + if !hasV4Networks { + return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) + } + } + } + } + + slices.SortFunc(c.details.unsafeNetworks, comparePrefix) + err = findDuplicatePrefix(c.details.unsafeNetworks) + if err != nil { + return err + } + return nil } @@ -536,13 +607,20 @@ func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certifica return nil, err } - return &certificateV2{ + c := &certificateV2{ details: details, rawDetails: rawDetails, curve: curve, publicKey: rawPublicKey, signature: rawSignature, - }, nil + } + + err = c.validate() + if err != nil { + return nil, err + } + + return c, nil } func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { @@ -639,9 +717,6 @@ func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { return detailsV2{}, ErrBadFormat } - slices.SortFunc(networks, comparePrefix) - slices.SortFunc(unsafeNetworks, comparePrefix) - return detailsV2{ name: string(name), networks: networks, diff --git a/cert/errors.go b/cert/errors.go index 60273a99d..4bbc023ad 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -2,6 +2,7 @@ package cert import ( "errors" + "fmt" ) var ( @@ -17,10 +18,9 @@ var ( ErrInvalidPrivateKey = errors.New("invalid private key") ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") + ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") ErrCaNotFound = errors.New("could not find ca for the certificate") - ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") - ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner") @@ -35,3 +35,15 @@ var ( ErrEmptySignature = errors.New("empty signature") ErrEmptyRawDetails = errors.New("empty rawDetails not allowed") ) + +type ErrInvalidCertificateProperties struct { + str string +} + +func NewErrInvalidCertificateProperties(format string, a ...any) error { + return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)} +} + +func (e *ErrInvalidCertificateProperties) Error() string { + return e.str +} diff --git a/cert/helper_test.go b/cert/helper_test.go index 05142dd54..1b72a0ffd 100644 --- a/cert/helper_test.go +++ b/cert/helper_test.go @@ -77,6 +77,10 @@ func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string after = time.Now().Add(time.Second * 60).Round(time.Second) } + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + var pub, priv []byte switch curve { case Curve_CURVE25519: diff --git a/cert/pem.go b/cert/pem.go index 249b63917..7ad28d129 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -34,6 +34,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { var err error switch p.Type { + // Implementations must validate the resulting certificate contains valid information case CertificateBanner: c, err = unmarshalCertificateV1(p.Bytes, nil) case CertificateV2Banner: diff --git a/cert/sign.go b/cert/sign.go index a1e09cd2b..12d4ee459 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -9,7 +9,6 @@ import ( "fmt" "math/big" "net/netip" - "slices" "time" ) @@ -31,6 +30,7 @@ type TBSCertificate struct { type beingSignedCertificate interface { // fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation + // Implementations must validate the resulting certificate contains valid information fromTBSCertificate(*TBSCertificate) error // marshalForSigning returns the bytes that should be signed @@ -83,9 +83,6 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb return nil, fmt.Errorf("curve in cert and private key supplied don't match") } - //TODO: make sure we have all minimum properties to sign, like a public key - //TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs - if signer != nil { if t.IsCA { return nil, fmt.Errorf("can not sign a CA certificate with another") @@ -107,9 +104,6 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb } } - slices.SortFunc(t.Networks, comparePrefix) - slices.SortFunc(t.UnsafeNetworks, comparePrefix) - var c beingSignedCertificate switch t.Version { case Version1: @@ -158,3 +152,16 @@ func comparePrefix(a, b netip.Prefix) int { } return addr } + +// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes +func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error { + if len(sortedPrefixes) < 2 { + return nil + } + for i := 1; i < len(sortedPrefixes); i++ { + if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 { + return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i]) + } + } + return nil +} diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index a6737543b..86795e43d 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -73,7 +73,7 @@ func Test_printCert(t *testing.T) { tf.Truncate(0) tf.Seek(0, 0) ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil) - c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"}) + c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"}) p, _ := c.MarshalPEM() tf.Write(p) @@ -97,7 +97,9 @@ func Test_printCert(t *testing.T) { "isCa": false, "issuer": "`+c.Issuer()+`", "name": "test", - "networks": [], + "networks": [ + "10.0.0.123/8" + ], "notAfter": "0001-01-01T00:00:00Z", "notBefore": "0001-01-01T00:00:00Z", "publicKey": "`+pk+`", @@ -116,7 +118,9 @@ func Test_printCert(t *testing.T) { "isCa": false, "issuer": "`+c.Issuer()+`", "name": "test", - "networks": [], + "networks": [ + "10.0.0.123/8" + ], "notAfter": "0001-01-01T00:00:00Z", "notBefore": "0001-01-01T00:00:00Z", "publicKey": "`+pk+`", @@ -135,7 +139,9 @@ func Test_printCert(t *testing.T) { "isCa": false, "issuer": "`+c.Issuer()+`", "name": "test", - "networks": [], + "networks": [ + "10.0.0.123/8" + ], "notAfter": "0001-01-01T00:00:00Z", "notBefore": "0001-01-01T00:00:00Z", "publicKey": "`+pk+`", @@ -166,7 +172,7 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] + `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] `, ob.String(), ) @@ -212,6 +218,10 @@ func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, aft after = ca.NotAfter() } + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + pub, rawPriv := x25519Keypair() nc := &cert.TBSCertificate{ Version: cert.Version1,