diff --git a/states/etcd_connect.go b/states/etcd_connect.go index 2f50a69..8cff09d 100644 --- a/states/etcd_connect.go +++ b/states/etcd_connect.go @@ -2,12 +2,15 @@ package states import ( "context" - "errors" + "crypto/tls" + "crypto/x509" "fmt" + "os" "path" "strings" "time" + "github.com/cockroachdb/errors" "github.com/milvus-io/birdwatcher/configs" "github.com/milvus-io/birdwatcher/framework" clientv3 "go.etcd.io/etcd/client/v3" @@ -37,15 +40,21 @@ func pingEtcd(ctx context.Context, cli clientv3.KV, rootPath string, metaPath st } func (app *ApplicationState) ConnectCommand(ctx context.Context, cp *ConnectParams) error { - etcdCli, err := clientv3.New(clientv3.Config{ + tls, err := app.getTLSConfig(cp) + if err != nil { + return err + } + + cfg := clientv3.Config{ Endpoints: []string{cp.EtcdAddr}, DialTimeout: time.Second * 10, + TLS: tls, // disable grpc logging Logger: zap.NewNop(), - }) + } + etcdCli, err := clientv3.New(cfg) if err != nil { - fmt.Println(err.Error()) return err } @@ -87,6 +96,63 @@ type ConnectParams struct { MetaPath string `name:"metaPath" default:"meta" desc:"meta path prefix"` Force bool `name:"force" default:"false" desc:"force connect ignoring ping Etcd & rootPath check"` Dry bool `name:"dry" default:"false" desc:"dry connect without specifying milvus instance"` + UseSSL bool `name:"use_ssl" default:"false" desc:"enable to use SSL"` + EnableTLS bool `name:"enableTLS" default:"false" desc:"use TLS"` + RootCA string `name:"rootCAPem" default:"" desc:"root CA pem file path"` + ETCDPem string `name:"etcdCert" default:"" desc:"etcd tls cert file path"` + ETCDKey string `name:"etcdKey" default:"" desc:"etcd tls key file path"` + TLSMinVersion string `name:"min_version" default:"1.2" desc:"TLS min version"` +} + +func (app *ApplicationState) getTLSConfig(cp *ConnectParams) (*tls.Config, error) { + if !cp.EnableTLS { + return nil, nil + } + + var tlsMinVersion uint16 + switch cp.TLSMinVersion { + case "1.0": + tlsMinVersion = tls.VersionTLS10 + case "1.1": + tlsMinVersion = tls.VersionTLS11 + case "1.2": + tlsMinVersion = tls.VersionTLS12 + case "1.3": + tlsMinVersion = tls.VersionTLS13 + default: + return nil, errors.New("invalid min tls version, only 1.0, 1.1, 1.2 and 1.3 is supported") + } + + rootCertPool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + if cp.RootCA != "" { + bs, err := os.ReadFile(cp.RootCA) + if err != nil { + return nil, err + } + + ok := rootCertPool.AppendCertsFromPEM(bs) + if !ok { + return nil, errors.New("Root CA PEM cannot be parsed") + } + } + + cert, err := tls.LoadX509KeyPair(cp.ETCDPem, cp.ETCDKey) + if err != nil { + return nil, errors.Wrap(err, "failed to load etcd cert/key pair") + } + + // #nosec G402 + return &tls.Config{ + RootCAs: rootCertPool, + Certificates: []tls.Certificate{ + cert, + }, + MinVersion: tlsMinVersion, + }, nil } type etcdConnectedState struct {