Skip to content

Commit

Permalink
Merge pull request #240 from batchcorp/blinktag/nats
Browse files Browse the repository at this point in the history
NATS improvements
  • Loading branch information
blinktag authored Feb 24, 2022
2 parents 9e63d7e + fe429a4 commit bea9e14
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 212 deletions.
2 changes: 1 addition & 1 deletion backends/mqtt/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func generateTLSConfig(args *args.MQTTConn) (*tls.Config, error) {
if err != nil {
return nil, errors.Wrap(err, "unable to load ssl keypair")
}
} else {
} else if len(args.TlsOptions.TlsClientCert) > 0 {
// Server input
certpool.AppendCertsFromPEM(args.TlsOptions.TlsCaCert)

Expand Down
59 changes: 8 additions & 51 deletions backends/nats-jetstream/nats-jetstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package nats_jetstream

import (
"context"
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/url"
"strings"

Expand All @@ -16,7 +13,6 @@ import (
"github.com/batchcorp/plumber/util"
"github.com/batchcorp/plumber/validate"

"github.com/batchcorp/plumber-schemas/build/go/protos/args"
"github.com/batchcorp/plumber-schemas/build/go/protos/opts"
)

Expand Down Expand Up @@ -64,11 +60,16 @@ func New(connOpts *opts.ConnectionOptions) (*NatsJetstream, error) {
}

var client *nats.Conn
if uri.Scheme == "tls" {
if uri.Scheme == "tls" || args.TlsOptions.UseTls {
// TLS Secured connection
tlsConfig, err := generateTLSConfig(args)
tlsConfig, err := util.GenerateTLSConfig(
args.TlsOptions.TlsCaCert,
args.TlsOptions.TlsClientCert,
args.TlsOptions.TlsClientKey,
args.TlsOptions.TlsSkipVerify,
)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "Unable to generate TLS Config")
}

client, err = nats.Connect(args.Dsn, nats.Secure(tlsConfig), creds)
Expand Down Expand Up @@ -103,50 +104,6 @@ func (n *NatsJetstream) Test(_ context.Context) error {
return types.NotImplementedErr
}

func generateTLSConfig(args *args.NatsJetstreamConn) (*tls.Config, error) {
certpool := x509.NewCertPool()

var cert tls.Certificate
var err error

if util.FileExists(args.TlsOptions.TlsClientCert) {
// CLI input, read from file
pemCerts, err := ioutil.ReadFile(string(args.TlsOptions.TlsCaCert))
if err == nil {
certpool.AppendCertsFromPEM(pemCerts)
}

cert, err = tls.LoadX509KeyPair(string(args.TlsOptions.TlsClientCert), string(args.TlsOptions.TlsClientKey))
if err != nil {
return nil, errors.Wrap(err, "unable to load ssl keypair")
}

} else {
certpool.AppendCertsFromPEM(args.TlsOptions.TlsCaCert)

cert, err = tls.X509KeyPair(args.TlsOptions.TlsClientCert, args.TlsOptions.TlsClientKey)
if err != nil {
return nil, errors.Wrap(err, "unable to load ssl keypair")
}
}

// Just to print out the client certificate..
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "unable to parse certificate")
}

// Create tls.Config with desired tls properties
return &tls.Config{
RootCAs: certpool,
ClientAuth: tls.NoClientCert,
ClientCAs: nil,
InsecureSkipVerify: args.TlsOptions.TlsSkipVerify,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}, nil
}

func validateBaseConnOpts(connOpts *opts.ConnectionOptions) error {
if connOpts == nil {
return validate.ErrMissingConnOpts
Expand Down
60 changes: 9 additions & 51 deletions backends/nats-streaming/nats-streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package nats_streaming

import (
"context"
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/url"
"strings"

Expand Down Expand Up @@ -67,11 +64,16 @@ func New(connOpts *opts.ConnectionOptions) (*NatsStreaming, error) {
}

var natsClient *nats.Conn
if uri.Scheme == "tls" {
if uri.Scheme == "tls" || args.TlsOptions.UseTls {
// TLS Secured connection
tlsConfig, err := generateTLSConfig(args)
tlsConfig, err := util.GenerateTLSConfig(
args.TlsOptions.TlsCaCert,
args.TlsOptions.TlsClientCert,
args.TlsOptions.TlsClientKey,
args.TlsOptions.TlsSkipVerify,
)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "Unable to generate TLS Config")
}

natsClient, err = nats.Connect(args.Dsn, nats.Secure(tlsConfig), creds)
Expand All @@ -88,7 +90,7 @@ func New(connOpts *opts.ConnectionOptions) (*NatsStreaming, error) {

stanClient, err := stan.Connect(args.ClusterId, args.ClientId, stan.NatsOptions())
if err != nil {
return nil, errors.Wrap(err, "could not create NATS subscription")
return nil, errors.Wrap(err, "could not create STAN subscription")
}

return &NatsStreaming{
Expand Down Expand Up @@ -118,50 +120,6 @@ func (n *NatsStreaming) Test(_ context.Context) error {
return types.NotImplementedErr
}

func generateTLSConfig(args *args.NatsStreamingConn) (*tls.Config, error) {
certpool := x509.NewCertPool()

var cert tls.Certificate
var err error

if util.FileExists(args.TlsOptions.TlsClientCert) {
// CLI input, read from file
pemCerts, err := ioutil.ReadFile(string(args.TlsOptions.TlsCaCert))
if err == nil {
certpool.AppendCertsFromPEM(pemCerts)
}

cert, err = tls.LoadX509KeyPair(string(args.TlsOptions.TlsClientCert), string(args.TlsOptions.TlsClientKey))
if err != nil {
return nil, errors.Wrap(err, "unable to load ssl keypair")
}

} else {
certpool.AppendCertsFromPEM(args.TlsOptions.TlsCaCert)

cert, err = tls.X509KeyPair(args.TlsOptions.TlsClientCert, args.TlsOptions.TlsClientKey)
if err != nil {
return nil, errors.Wrap(err, "unable to load ssl keypair")
}
}

// Just to print out the client certificate..
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "unable to parse certificate")
}

// Create tls.Config with desired tls properties
return &tls.Config{
RootCAs: certpool,
ClientAuth: tls.NoClientCert,
ClientCAs: nil,
InsecureSkipVerify: args.TlsOptions.TlsSkipVerify,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}, nil
}

func validateBaseConnOpts(connOpts *opts.ConnectionOptions) error {
if connOpts == nil {
return validate.ErrMissingConnOpts
Expand Down
60 changes: 0 additions & 60 deletions backends/nats-streaming/nats-streaming_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package nats_streaming

import (
"io/ioutil"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

Expand Down Expand Up @@ -62,64 +60,6 @@ var _ = Describe("Nats Streaming Backend", func() {
})
})

Context("generateTLSConfig", func() {
It("works with files", func() {
tlsConfig, err := generateTLSConfig(connOpts.GetNatsStreaming())
Expect(err).ToNot(HaveOccurred())
Expect(len(tlsConfig.Certificates)).To(Equal(1))
})
It("returns error on incorrect cert file", func() {
args := connOpts.GetNatsStreaming()
args.TlsOptions.TlsClientCert = args.TlsOptions.TlsClientKey
_, err := generateTLSConfig(args)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("unable to load ssl keypair"))
})
It("returns error on incorrect cert string", func() {
caBytes, err := ioutil.ReadFile("../../test-assets/ssl/ca.crt")
Expect(err).ToNot(HaveOccurred())
certBytes, err := ioutil.ReadFile("../../test-assets/ssl/client.crt")
Expect(err).ToNot(HaveOccurred())
keyBytes, err := ioutil.ReadFile("../../test-assets/ssl/client.key")
Expect(err).ToNot(HaveOccurred())

args := &args.NatsStreamingConn{
TlsOptions: &args.NatsStreamingTLSOptions{
TlsCaCert: caBytes,
TlsClientCert: keyBytes,
TlsClientKey: certBytes,
TlsSkipVerify: true,
},
}

_, err = generateTLSConfig(args)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("unable to load ssl keypair"))
})

It("works with strings", func() {
caBytes, err := ioutil.ReadFile("../../test-assets/ssl/ca.crt")
Expect(err).ToNot(HaveOccurred())
certBytes, err := ioutil.ReadFile("../../test-assets/ssl/client.crt")
Expect(err).ToNot(HaveOccurred())
keyBytes, err := ioutil.ReadFile("../../test-assets/ssl/client.key")
Expect(err).ToNot(HaveOccurred())

args := &args.NatsStreamingConn{
TlsOptions: &args.NatsStreamingTLSOptions{
TlsCaCert: caBytes,
TlsClientCert: certBytes,
TlsClientKey: keyBytes,
TlsSkipVerify: true,
},
}

tlsConfig, err := generateTLSConfig(args)
Expect(err).ToNot(HaveOccurred())
Expect(len(tlsConfig.Certificates)).To(Equal(1))
})
})

Context("validateBaseConnOpts", func() {
It("validates conn presence", func() {
err := validateBaseConnOpts(nil)
Expand Down
46 changes: 9 additions & 37 deletions backends/nats/nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package nats

import (
"context"
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/url"

"github.com/nats-io/nats.go"
Expand All @@ -15,6 +12,7 @@ import (
"github.com/batchcorp/plumber-schemas/build/go/protos/opts"

"github.com/batchcorp/plumber/types"
"github.com/batchcorp/plumber/util"
)

const BackendName = "nats"
Expand Down Expand Up @@ -73,7 +71,7 @@ func newClient(opts *args.NatsConn) (*nats.Conn, error) {
creds = nats.UserCredentials(string(opts.UserCredentials))
}

if uri.Scheme != "tls" {
if uri.Scheme != "tls" && !opts.TlsOptions.UseTls {
// Insecure connection
c, err := nats.Connect(opts.Dsn, creds)
if err != nil {
Expand All @@ -83,9 +81,14 @@ func newClient(opts *args.NatsConn) (*nats.Conn, error) {
}

// TLS Secured connection
tlsConfig, err := generateTLSConfig(opts)
tlsConfig, err := util.GenerateTLSConfig(
opts.TlsOptions.TlsCaCert,
opts.TlsOptions.TlsClientCert,
opts.TlsOptions.TlsClientKey,
opts.TlsOptions.TlsSkipVerify,
)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "Unable to generate TLS Config")
}

c, err := nats.Connect(opts.Dsn, nats.Secure(tlsConfig), creds)
Expand All @@ -95,34 +98,3 @@ func newClient(opts *args.NatsConn) (*nats.Conn, error) {

return c, nil
}

func generateTLSConfig(opts *args.NatsConn) (*tls.Config, error) {
certpool := x509.NewCertPool()

pemCerts, err := ioutil.ReadFile(string(opts.TlsOptions.TlsCaCert))
if err == nil {
certpool.AppendCertsFromPEM(pemCerts)
}

// Import client certificate/key pair
cert, err := tls.LoadX509KeyPair(string(opts.TlsOptions.TlsClientCert), string(opts.TlsOptions.TlsClientKey))
if err != nil {
return nil, errors.Wrap(err, "unable to load ssl keypair")
}

// Just to print out the client certificate..
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "unable to parse certificate")
}

// Create tls.Config with desired tls properties
return &tls.Config{
RootCAs: certpool,
ClientAuth: tls.NoClientCert,
ClientCAs: nil,
InsecureSkipVerify: opts.TlsOptions.TlsSkipVerify,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}, nil
}
Loading

0 comments on commit bea9e14

Please sign in to comment.