diff --git a/api/api.go b/api/api.go index 3c8798a..9c0aab5 100644 --- a/api/api.go +++ b/api/api.go @@ -28,15 +28,20 @@ package api import ( "context" + "errors" "fmt" + "io/fs" + "mime" "net/http" "strconv" + "time" connect "github.com/bufbuild/connect-go" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/telekom/canary-bot/data" h "github.com/telekom/canary-bot/helper" + "github.com/telekom/canary-bot/proto/api/third_party" apiv1 "github.com/telekom/canary-bot/proto/api/v1" "github.com/telekom/canary-bot/proto/api/v1/apiv1connect" "go.uber.org/zap" @@ -117,12 +122,21 @@ func StartApi(data data.Database, config *Configuration, log *zap.SugaredLogger) interceptors := connect.WithInterceptors(a.NewAuthInterceptor()) mux := http.NewServeMux() - mux.Handle("/", getOpenAPIHandler()) + + // Open API Handler + Endpoint + openApiHandler, err := getOpenAPIHandler() + if err != nil { + log.Warn("Could not start the OpenAPI Endpoint ", err) + } else { + mux.Handle("/", openApiHandler) + } + mux.Handle(apiv1connect.NewApiServiceHandler(a, interceptors)) mux.Handle("/api/v1/", gwmux) server := &http.Server{ - Addr: addr, - Handler: h2c.NewHandler(mux, &http2.Server{}), + Addr: addr, + Handler: h2c.NewHandler(mux, &http2.Server{}), + ReadHeaderTimeout: time.Minute, } log.Info("Serving Connect, gRPC-Gateway and OpenAPI Documentation on ", addr) @@ -134,3 +148,16 @@ func StartApi(data data.Database, config *Configuration, log *zap.SugaredLogger) return server.ListenAndServe() } + +func getOpenAPIHandler() (http.Handler, error) { + err := mime.AddExtensionType(".svg", "image/svg+xml") + if err != nil { + return nil, errors.New("Couldn't add extension type: " + err.Error()) + } + // Use subdirectory in embedded files + subFS, err := fs.Sub(third_party.OpenAPI, "OpenAPI") + if err != nil { + return nil, errors.New("Couldn't create sub filesystem: " + err.Error()) + } + return http.FileServer(http.FS(subFS)), nil +} diff --git a/api/server.go b/api/server.go index 5343e47..da615a2 100644 --- a/api/server.go +++ b/api/server.go @@ -23,14 +23,10 @@ package api import ( "context" - "io/fs" - "mime" - "net/http" "time" "github.com/telekom/canary-bot/data" - third_party "github.com/telekom/canary-bot/proto/api/third_party" apiv1 "github.com/telekom/canary-bot/proto/api/v1" connect "github.com/bufbuild/connect-go" @@ -89,13 +85,3 @@ func (b *Api) ListNodes(ctx context.Context, req *connect.Request[apiv1.ListNode Nodes: nodes, }), nil } - -func getOpenAPIHandler() http.Handler { - mime.AddExtensionType(".svg", "image/svg+xml") - // Use subdirectory in embedded files - subFS, err := fs.Sub(third_party.OpenAPI, "OpenAPI") - if err != nil { - panic("couldn't create sub filesystem: " + err.Error()) - } - return http.FileServer(http.FS(subFS)) -} diff --git a/data/data.go b/data/data.go index 53ac1e0..238ef24 100644 --- a/data/data.go +++ b/data/data.go @@ -22,6 +22,7 @@ package data import ( + l "log" "strconv" "time" @@ -183,8 +184,13 @@ func (n *Node) Convert() *meshv1.Node { // Convert a given mesh node to a database node // with a given state of the node func Convert(n *meshv1.Node, state int) *Node { + id, err := h.Hash(n.Target) + if err != nil { + l.Printf("Could not get the hash value of the ID, please check the hash function") + } + return &Node{ - Id: h.Hash(n.Target), + Id: id, Name: n.Name, Target: n.Target, State: state, @@ -195,11 +201,20 @@ func Convert(n *meshv1.Node, state int) *Node { // Get the id of a database node. // The id is a hash integer func GetId(n *Node) uint32 { - return h.Hash(n.Target) + id, err := h.Hash(n.Target) + if err != nil { + l.Printf("Could not get the hash value of the ID, please check the hash function") + } + + return id } // Get the id of a given sample. // The id is a hash integer func GetSampleId(p *Sample) uint32 { - return h.Hash(p.From + p.To + strconv.FormatInt(p.Key, 10)) + id, err := h.Hash(p.From + p.To + strconv.FormatInt(p.Key, 10)) + if err != nil { + l.Printf("Could not get the hash value of the sample, please check the hash function") + } + return id } diff --git a/data/data_test.go b/data/data_test.go index 943cef4..db084a0 100644 --- a/data/data_test.go +++ b/data/data_test.go @@ -70,6 +70,10 @@ func Test_NewMemDb(t *testing.T) { } } +// helper function +func value(n uint32, _ error) uint32 { + return n +} func Test_Convert(t *testing.T) { tests := []struct { name string @@ -88,7 +92,7 @@ func Test_Convert(t *testing.T) { }, state: 1, expectedNode: &Node{ - Id: h.Hash("tegraT"), + Id: value(h.Hash("tegraT")), Name: "test", State: 1, Target: "tegraT", @@ -98,7 +102,7 @@ func Test_Convert(t *testing.T) { { name: "Node to MeshNode", inputNode: &Node{ - Id: h.Hash("tegraT"), + Id: value(h.Hash("tegraT")), Name: "test", State: 12, Target: "tegraT", @@ -142,12 +146,12 @@ func Test_GetId(t *testing.T) { { name: "Node with target", node: &Node{Target: "tegraT"}, - expectedId: h.Hash("tegraT"), + expectedId: value(h.Hash("tegraT")), }, { name: "Node without target", node: &Node{}, - expectedId: h.Hash(""), + expectedId: value(h.Hash("")), }, } @@ -174,12 +178,12 @@ func Test_GetSampleId(t *testing.T) { To: "Gose", Key: 1, }, - expectedId: h.Hash("EagleGose1"), + expectedId: value(h.Hash("EagleGose1")), }, { name: "Empty samples", sample: &Sample{}, - expectedId: h.Hash("0"), + expectedId: value(h.Hash("0")), }, } diff --git a/helper/helper.go b/helper/helper.go index cb52066..845b502 100644 --- a/helper/helper.go +++ b/helper/helper.go @@ -22,6 +22,7 @@ package helper import ( + "crypto/rand" "crypto/tls" "crypto/x509" "encoding/base64" @@ -30,10 +31,9 @@ import ( "hash/fnv" "io/ioutil" "log" - "math/rand" + "math/big" "net" "regexp" - "time" "google.golang.org/grpc/credentials" ) @@ -103,10 +103,13 @@ func ValidateAddress(domain string) bool { return RegExp.MatchString(domain) } -func Hash(s string) uint32 { +func Hash(s string) (uint32, error) { h := fnv.New32a() - h.Write([]byte(s)) - return h.Sum32() + _, err := h.Write([]byte(s)) + if err != nil { + return 0, errors.New("Generating a hash value failed: " + err.Error()) + } + return h.Sum32(), nil } // ------------------ @@ -114,19 +117,25 @@ const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "0123456789" -var seededRand *rand.Rand = rand.New( - rand.NewSource(time.Now().UnixNano())) - -func stringWithCharset(length int, charset string) string { - b := make([]byte, length) - for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] +func stringWithCharset(n int64, chars string) (string, error) { + ret := make([]byte, n) + for i := int64(0); i < n; i++ { + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars)))) + if err != nil { + return "", err + } + ret[i] = chars[num.Int64()] } - return string(b) + + return string(ret), nil } -func GenerateRandomToken(length int) string { - return stringWithCharset(length, charset) +func GenerateRandomToken(length int64) string { + token, err := stringWithCharset(length, charset) + if err != nil { + panic("Could not generate a random token, please check func GenerateRandomToken") + } + return token } //------------------ @@ -150,8 +159,12 @@ func LoadClientTLSCredentials(caCert_Paths []string, caCert_b64 []byte) (credent if len(caCert_Paths) > 0 { for _, path := range caCert_Paths { + /* #nosec G304*/ pemServerCA, err := ioutil.ReadFile(path) - if err != nil || !certPool.AppendCertsFromPEM(pemServerCA) { + if err != nil { + panic("Failed to add server ca certificate, path not found (security issue): " + path) + } + if !certPool.AppendCertsFromPEM(pemServerCA) { return nil, fmt.Errorf("Failed to add server ca certificate") } } @@ -167,7 +180,8 @@ func LoadClientTLSCredentials(caCert_Paths []string, caCert_b64 []byte) (credent // Create the credentials and return it config := &tls.Config{ - RootCAs: certPool, + RootCAs: certPool, + MinVersion: tls.VersionTLS12, } return credentials.NewTLS(config), nil @@ -200,6 +214,7 @@ func LoadServerTLSCredentials(serverCert_path string, serverKey_path string, ser config := &tls.Config{ Certificates: []tls.Certificate{serverCert}, ClientAuth: tls.NoClientCert, + MinVersion: tls.VersionTLS12, } return config, nil diff --git a/helper/helper_test.go b/helper/helper_test.go new file mode 100644 index 0000000..43f505d --- /dev/null +++ b/helper/helper_test.go @@ -0,0 +1,43 @@ +package helper + +import ( + "fmt" + "testing" +) + +func Test_stringWithCharset(t *testing.T) { + tests := []struct { + name string + length int64 + charset string + }{ + { + name: "normal string", + length: 32, + charset: charset, + }, + { + name: "0 string", + length: 0, + charset: charset, + }, + { + name: "0 charset", + length: 32, + charset: "1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + str, err := stringWithCharset(tt.length, tt.charset) + fmt.Printf("%v\n", str) + if err != nil { + t.Error("stringWithCharset with errors") + } + if len(str) != int(tt.length) { + t.Errorf("Length of string is not as eypexted: %v != %v", len(str), tt.length) + } + }) + } +} diff --git a/main.go b/main.go index b60a8a4..c5a080c 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,7 @@ package main import ( - _ "net/http/pprof" + "log" "github.com/telekom/canary-bot/mesh" @@ -163,13 +163,19 @@ func bindEnvToFlags(cmd *cobra.Command, v *viper.Viper) { // Mapping Flag with "-" to uppercase env with "_" --listen-port to _LISTEN_PORT if strings.Contains(f.Name, "-") { envVarSuffix := strings.ToUpper(strings.ReplaceAll(f.Name, "-", "_")) - v.BindEnv(f.Name, envPrefix+"_"+envVarSuffix) + err := v.BindEnv(f.Name, envPrefix+"_"+envVarSuffix) + if err != nil { + log.Printf("Could not bind env varibale %v", envPrefix+"_"+envVarSuffix) + } } // Apply the viper config value to the flag when the flag is not set and viper has a value if !f.Changed && v.IsSet(f.Name) { val := v.GetString(f.Name) - cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) + err := cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) + if err != nil { + log.Printf("Could not apply viper config to flag: %v", v.GetString(f.Name)) + } } }) } diff --git a/mesh/client.go b/mesh/client.go index 43b3ce2..8753abf 100644 --- a/mesh/client.go +++ b/mesh/client.go @@ -231,12 +231,16 @@ func (m *Mesh) timeoutInterceptor( return err } -func (m *Mesh) closeClient(to *meshv1.Node) { +func (m *Mesh) closeClient(to *meshv1.Node) error { m.mu.Lock() - m.clients[GetId(to)].conn.Close() + err := m.clients[GetId(to)].conn.Close() + if err != nil { + return err + } // remove client delete(m.clients, GetId(to)) m.mu.Unlock() + return nil } func (m *Mesh) Rtt() { diff --git a/mesh/mesh.go b/mesh/mesh.go index 4f56cec..b6392e6 100644 --- a/mesh/mesh.go +++ b/mesh/mesh.go @@ -23,7 +23,6 @@ package mesh import ( "log" - "net/http" "strconv" "sync" "time" @@ -382,23 +381,31 @@ func (m *Mesh) retryPushSample(node *meshv1.Node) { // Get the ID of a node // Hash integer value of the target field (name of node) func GetId(n *meshv1.Node) uint32 { - return h.Hash(n.Target) + id, err := h.Hash(n.Target) + if err != nil { + log.Printf("Could not get the hash value of the sample, please check the hash function") + } + return id } // Get the ID of a sample // Hash integer value of the concatenated From, To and Key field func GetSampleId(p *meshv1.Sample) uint32 { - return h.Hash(p.From + p.To + strconv.FormatInt(p.Key, 10)) + id, err := h.Hash(p.From + p.To + strconv.FormatInt(p.Key, 10)) + if err != nil { + log.Printf("Could not get the hash value of the sample, please check the hash function") + } + return id } // Setup the Logger func getLogger(debug bool, pprofAddress string) *zap.SugaredLogger { if debug { // starting pprof for memory and cpu analysis - go func() { - log.Println("Starting go debugging profiler pprof on port 6060") - http.ListenAndServe(pprofAddress+":6060", nil) - }() + // go func() { + // log.Println("Starting go debugging profiler pprof on port 6060") + // http.ListenAndServe(pprofAddress+":6060", nil) + // }() // using debug logger logger, err := zap.NewDevelopment() diff --git a/mesh/server.go b/mesh/server.go index 6b7960d..6d48fd3 100644 --- a/mesh/server.go +++ b/mesh/server.go @@ -151,6 +151,9 @@ func (m *Mesh) StartServer() error { grpcServer := grpc.NewServer(opts...) meshv1.RegisterMeshServiceServer(grpcServer, meshServer) reflection.Register(grpcServer) - grpcServer.Serve(lis) + err = grpcServer.Serve(lis) + if err != nil { + return err + } return nil }