diff --git a/controllers/complianceeventsapi/migrations/000001_compliance_history_initial_tables.up.sql b/controllers/complianceeventsapi/migrations/000001_compliance_history_initial_tables.up.sql index cc713ca4..08de69a8 100644 --- a/controllers/complianceeventsapi/migrations/000001_compliance_history_initial_tables.up.sql +++ b/controllers/complianceeventsapi/migrations/000001_compliance_history_initial_tables.up.sql @@ -21,17 +21,17 @@ CREATE TABLE IF NOT EXISTS parent_policies( -- This is required until we only support Postgres 15+ to utilize NULLS NOT DISTINCT. -- Partial indexes with 1 nullable unique field provided (e.g. A, B, C) -CREATE UNIQUE INDEX parent_policies_null1 ON parent_policies (name, namespace, controls, standards) WHERE categories IS NULL; -CREATE UNIQUE INDEX parent_policies_null2 ON parent_policies (name, namespace, categories, standards) WHERE controls IS NULL; -CREATE UNIQUE INDEX parent_policies_null3 ON parent_policies (name, namespace, categories, controls) WHERE standards IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS parent_policies_null1 ON parent_policies (name, namespace, controls, standards) WHERE categories IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS parent_policies_null2 ON parent_policies (name, namespace, categories, standards) WHERE controls IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS parent_policies_null3 ON parent_policies (name, namespace, categories, controls) WHERE standards IS NULL; -- Partial indexes with 2 nullable unique field provided (e.g. AB AC BC) -CREATE UNIQUE INDEX parent_policies_null4 ON parent_policies (name, namespace, standards) WHERE categories IS NULL AND controls IS NULL; -CREATE UNIQUE INDEX parent_policies_null5 ON parent_policies (name, namespace, controls) WHERE categories IS NULL AND standards IS NULL; -CREATE UNIQUE INDEX parent_policies_null6 ON parent_policies (name, namespace, categories) WHERE controls IS NULL AND standards IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS parent_policies_null4 ON parent_policies (name, namespace, standards) WHERE categories IS NULL AND controls IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS parent_policies_null5 ON parent_policies (name, namespace, controls) WHERE categories IS NULL AND standards IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS parent_policies_null6 ON parent_policies (name, namespace, categories) WHERE controls IS NULL AND standards IS NULL; -- Partial index with no nullable unique fields provided (e.g. ABC) -CREATE UNIQUE INDEX parent_policies_null7 ON parent_policies (name, namespace) WHERE categories IS NULL AND controls IS NULL AND standards IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS parent_policies_null7 ON parent_policies (name, namespace) WHERE categories IS NULL AND controls IS NULL AND standards IS NULL; CREATE TABLE IF NOT EXISTS policies( id serial PRIMARY KEY, @@ -39,22 +39,20 @@ CREATE TABLE IF NOT EXISTS policies( api_group TEXT NOT NULL, name TEXT NOT NULL, namespace TEXT, - spec TEXT NOT NULL, - -- SHA1 hash - spec_hash CHAR(40) NOT NULL, + spec JSONB NOT NULL, severity TEXT, - UNIQUE (kind, api_group, name, namespace, spec_hash, severity) + UNIQUE (kind, api_group, name, namespace, spec, severity) ); -- This is required until we only support Postgres 15+ to utilize NULLS NOT DISTINCT. -- Partial indexes with 1 nullable unique field provided (e.g. A, B) -CREATE UNIQUE INDEX policies_null1 ON policies (kind, api_group, name, spec_hash, severity) WHERE namespace IS NULL; -CREATE UNIQUE INDEX policies_null2 ON policies (kind, api_group, name, namespace, spec_hash) WHERE severity IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS policies_null1 ON policies (kind, api_group, name, spec, severity) WHERE namespace IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS policies_null2 ON policies (kind, api_group, name, namespace, spec) WHERE severity IS NULL; -- Partial index with no nullable unique fields provided (e.g. AB) -CREATE UNIQUE INDEX policies_null3 ON policies (kind, api_group, name, spec_hash) WHERE namespace IS NULL AND severity IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS policies_null3 ON policies (kind, api_group, name, spec) WHERE namespace IS NULL AND severity IS NULL; -CREATE INDEX IF NOT EXISTS idx_policies_spec_hash ON policies (spec_hash); +CREATE INDEX IF NOT EXISTS idx_policies_spec ON policies (spec); CREATE TABLE IF NOT EXISTS compliance_events( id serial PRIMARY KEY, diff --git a/controllers/complianceeventsapi/server.go b/controllers/complianceeventsapi/server.go index b08b3887..15e5674d 100644 --- a/controllers/complianceeventsapi/server.go +++ b/controllers/complianceeventsapi/server.go @@ -1,14 +1,10 @@ package complianceeventsapi import ( - "bytes" "context" - "crypto/sha1" // #nosec G505 -- for convenience, not cryptography "database/sql" - "encoding/hex" "encoding/json" "errors" - "fmt" "io" "net/http" "sync" @@ -113,7 +109,7 @@ func postComplianceEvent(db *sql.DB, w http.ResponseWriter, r *http.Request) { return } - if err := reqEvent.Validate(); err != nil { + if err := reqEvent.Validate(r.Context(), db); err != nil { writeErrMsgJSON(w, err.Error(), http.StatusBadRequest) return @@ -165,8 +161,8 @@ func postComplianceEvent(db *sql.DB, w http.ResponseWriter, r *http.Request) { return } - // remove the spec to only respond with the specHash - reqEvent.Policy.Spec = "" + // remove the spec so it's not returned in the JSON. + reqEvent.Policy.Spec = nil resp, err := json.Marshal(reqEvent) if err != nil { @@ -201,6 +197,10 @@ func getClusterForeignKey(ctx context.Context, db *sql.DB, cluster Cluster) (int } func getParentPolicyForeignKey(ctx context.Context, db *sql.DB, parent ParentPolicy) (int32, error) { + if parent.KeyID != 0 { + return parent.KeyID, nil + } + // Check cache parKey := parent.key() @@ -220,15 +220,8 @@ func getParentPolicyForeignKey(ctx context.Context, db *sql.DB, parent ParentPol } func getPolicyForeignKey(ctx context.Context, db *sql.DB, pol Policy) (int32, error) { - // Fill in missing fields that can be inferred from other fields - if pol.SpecHash == "" { - var buf bytes.Buffer - if err := json.Compact(&buf, []byte(pol.Spec)); err != nil { - return 0, err // This kind of error would have been found during validation - } - - sum := sha1.Sum(buf.Bytes()) // #nosec G401 -- for convenience, not cryptography - pol.SpecHash = hex.EncodeToString(sum[:]) + if pol.KeyID != 0 { + return pol.KeyID, nil } // Check cache @@ -239,29 +232,6 @@ func getPolicyForeignKey(ctx context.Context, db *sql.DB, pol Policy) (int32, er return key.(int32), nil } - if pol.Spec == "" { - row := db.QueryRowContext( - ctx, "SELECT spec FROM policies WHERE spec_hash=$1 LIMIT 1", pol.SpecHash, - ) - if row.Err() != nil { - return 0, fmt.Errorf("could not determine the spec from the provided spec hash: %w", row.Err()) - } - - err := row.Scan(&pol.Spec) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return 0, fmt.Errorf( - "%w: could not determine the spec from the provided spec hash; the spec is required in the request", - errRequiredFieldNotProvided, - ) - } - - return 0, fmt.Errorf( - "the database returned an unexpected spec value for the provided spec hash: %w", err, - ) - } - } - err := pol.GetOrCreate(ctx, db) if err != nil { return 0, err diff --git a/controllers/complianceeventsapi/types.go b/controllers/complianceeventsapi/types.go index 977bdd74..b46be6fb 100644 --- a/controllers/complianceeventsapi/types.go +++ b/controllers/complianceeventsapi/types.go @@ -1,12 +1,9 @@ package complianceeventsapi import ( - "bytes" "context" - "crypto/sha1" // #nosec G505 -- for convenience, not cryptography "database/sql" "database/sql/driver" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -28,13 +25,17 @@ type dbRow interface { } type ComplianceEvent struct { + EventID int32 `json:"id"` Cluster Cluster `json:"cluster"` Event EventDetails `json:"event"` - ParentPolicy *ParentPolicy `json:"parent_policy,omitempty"` //nolint:tagliatelle + ParentPolicy *ParentPolicy `json:"parent_policy"` //nolint:tagliatelle Policy Policy `json:"policy"` } -func (ce ComplianceEvent) Validate() error { +// Validate ensures that a valid POST request for a compliance event is set. This means that if the shorthand approach +// of providing parent_policy.id and/or policy.id is used, the other fields for ParentPolicy and Policy will not be +// present. +func (ce ComplianceEvent) Validate(ctx context.Context, db *sql.DB) error { errs := make([]error, 0) if err := ce.Cluster.Validate(); err != nil { @@ -42,7 +43,33 @@ func (ce ComplianceEvent) Validate() error { } if ce.ParentPolicy != nil { - if err := ce.ParentPolicy.Validate(); err != nil { + if ce.ParentPolicy.KeyID != 0 { + row := db.QueryRowContext( + ctx, `SELECT EXISTS(SELECT * FROM parent_policies WHERE id=$1);`, ce.ParentPolicy.KeyID, + ) + + if row.Err() != nil { + log.Error(row.Err(), "Failed to query for the existence of the parent policy ID") + + return errors.New("failed to determine if parent_policy.id is valid") + } + + var exists bool + + err := row.Scan(&exists) + if err != nil { + log.Error(row.Err(), "Failed to scan for the existence of the parent policy ID") + + return errors.New("failed to determine if parent_policy.id is valid") + } + + if exists { + // If the user provided extra data, ignore it since it won't be validated that it matches the database + ce.ParentPolicy = &ParentPolicy{KeyID: ce.ParentPolicy.KeyID} + } else { + errs = append(errs, fmt.Errorf("%w: parent_policy.id not found", errInvalidInput)) + } + } else if err := ce.ParentPolicy.Validate(); err != nil { errs = append(errs, err) } } @@ -51,7 +78,30 @@ func (ce ComplianceEvent) Validate() error { errs = append(errs, err) } - if err := ce.Policy.Validate(); err != nil { + if ce.Policy.KeyID != 0 { + row := db.QueryRowContext(ctx, `SELECT EXISTS(SELECT * FROM policies WHERE id=$1);`, ce.Policy.KeyID) + if row.Err() != nil { + log.Error(row.Err(), "Failed to query for the existence of the policy ID") + + return errors.New("failed to determine if policy.id is valid") + } + + var exists bool + + err := row.Scan(&exists) + if err != nil { + log.Error(row.Err(), "Failed to scan for the existence of the policy ID") + + return errors.New("failed to determine if policy.id is valid") + } + + if exists { + // If the user provided extra data, ignore it since it won't be validated that it matches the database + ce.Policy = Policy{KeyID: ce.Policy.KeyID} + } else { + errs = append(errs, fmt.Errorf("%w: policy.id not found", errInvalidInput)) + } + } else if err := ce.Policy.Validate(); err != nil { errs = append(errs, err) } @@ -139,8 +189,8 @@ type EventDetails struct { Compliance string `db:"compliance" json:"compliance"` Message string `db:"message" json:"message"` Timestamp time.Time `db:"timestamp" json:"timestamp"` - Metadata JSONMap `db:"metadata" json:"metadata,omitempty"` - ReportedBy *string `db:"reported_by" json:"reported_by,omitempty"` //nolint:tagliatelle + Metadata JSONMap `db:"metadata" json:"metadata"` + ReportedBy *string `db:"reported_by" json:"reported_by"` //nolint:tagliatelle } func (e EventDetails) Validate() error { @@ -180,12 +230,12 @@ func (e *EventDetails) InsertQuery() (string, []any) { } type ParentPolicy struct { - KeyID int32 `db:"id" json:"-"` + KeyID int32 `db:"id" json:"id"` Name string `db:"name" json:"name"` Namespace string `db:"namespace" json:"namespace"` - Categories pq.StringArray `db:"categories" json:"categories,omitempty"` - Controls pq.StringArray `db:"controls" json:"controls,omitempty"` - Standards pq.StringArray `db:"standards" json:"standards,omitempty"` + Categories pq.StringArray `db:"categories" json:"categories"` + Controls pq.StringArray `db:"controls" json:"controls"` + Standards pq.StringArray `db:"standards" json:"standards"` } func (p ParentPolicy) Validate() error { @@ -217,11 +267,36 @@ func (p *ParentPolicy) SelectQuery(returnedColumns ...string) (string, []any) { } sql := fmt.Sprintf( - `SELECT %s FROM parent_policies `+ - `WHERE categories=$1 AND controls=$2 AND name=$3 AND namespace=$4 AND standards=$5`, + `SELECT %s FROM parent_policies WHERE name=$1 AND namespace=$2`, strings.Join(returnedColumns, ", "), ) - values := []any{p.Categories, p.Controls, p.Name, p.Namespace, p.Standards} + values := []any{p.Name, p.Namespace} + + columnCount := 2 + + if p.Categories == nil { + sql += " AND categories IS NULL" + } else { + columnCount++ + sql += fmt.Sprintf(" AND categories=$%d", columnCount) + values = append(values, p.Categories) + } + + if p.Controls == nil { + sql += " AND controls IS NULL" + } else { + columnCount++ + sql += fmt.Sprintf(" AND controls=$%d", columnCount) + values = append(values, p.Controls) + } + + if p.Standards == nil { + sql += " AND standards IS NULL" + } else { + columnCount++ + sql += fmt.Sprintf(" AND standards=$%d", columnCount) + values = append(values, p.Standards) + } return sql, values } @@ -235,14 +310,13 @@ func (p ParentPolicy) key() string { } type Policy struct { - KeyID int32 `db:"id" json:"-"` + KeyID int32 `db:"id" json:"id"` Kind string `db:"kind" json:"kind"` APIGroup string `db:"api_group" json:"apiGroup"` Name string `db:"name" json:"name"` - Namespace *string `db:"namespace" json:"namespace,omitempty"` - Spec string `db:"spec" json:"spec,omitempty"` - SpecHash string `db:"spec_hash" json:"specHash"` - Severity *string `db:"severity" json:"severity,omitempty"` + Namespace *string `db:"namespace" json:"namespace"` + Spec JSONMap `db:"spec" json:"spec,omitempty"` + Severity *string `db:"severity" json:"severity"` } func (p *Policy) Validate() error { @@ -260,24 +334,8 @@ func (p *Policy) Validate() error { errs = append(errs, fmt.Errorf("%w: policy.name", errRequiredFieldNotProvided)) } - if p.Spec == "" && p.SpecHash == "" { - errs = append(errs, fmt.Errorf("%w: policy.spec or policy.specHash", errRequiredFieldNotProvided)) - } - - if p.Spec != "" { - var buf bytes.Buffer - if err := json.Compact(&buf, []byte(p.Spec)); err != nil { - errs = append(errs, fmt.Errorf("%w: policy.spec is not valid JSON: %w", errInvalidInput, err)) - } else if buf.String() != p.Spec { - errs = append(errs, fmt.Errorf("%w: policy.spec is not compact JSON", errInvalidInput)) - } else if p.SpecHash != "" { - sum := sha1.Sum(buf.Bytes()) // #nosec G401 -- for convenience, not cryptography - - if p.SpecHash != hex.EncodeToString(sum[:]) { - errs = append(errs, fmt.Errorf("%w: policy.specHash does not match the compact policy.Spec", - errInvalidInput)) - } - } + if p.Spec == nil { + errs = append(errs, fmt.Errorf("%w: policy.spec", errRequiredFieldNotProvided)) } return errors.Join(errs...) @@ -285,9 +343,9 @@ func (p *Policy) Validate() error { func (p *Policy) InsertQuery() (string, []any) { sql := `INSERT INTO policies` + - `(api_group, kind, name, namespace, severity, spec, spec_hash)` + - `VALUES($1, $2, $3, $4, $5, $6, $7)` - values := []any{p.APIGroup, p.Kind, p.Name, p.Namespace, p.Severity, p.Spec, p.SpecHash} + `(api_group, kind, name, namespace, severity, spec)` + + `VALUES($1, $2, $3, $4, $5, $6)` + values := []any{p.APIGroup, p.Kind, p.Name, p.Namespace, p.Severity, p.Spec} return sql, values } @@ -299,10 +357,29 @@ func (p *Policy) SelectQuery(returnedColumns ...string) (string, []any) { sql := fmt.Sprintf( `SELECT %s FROM policies `+ - `WHERE api_group=$1 AND kind=$2 AND name=$3 AND namespace=$4 AND severity=$5 AND spec_hash=$6`, + `WHERE api_group=$1 AND kind=$2 AND name=$3 AND spec=$4`, strings.Join(returnedColumns, ", "), ) - values := []any{p.APIGroup, p.Kind, p.Name, p.Namespace, p.Severity, p.SpecHash} + + values := []any{p.APIGroup, p.Kind, p.Name, p.Spec} + + columnCount := 4 + + if p.Namespace == nil { + sql += " AND namespace is NULL" + } else { + columnCount++ + sql += fmt.Sprintf(" AND namespace=$%d", columnCount) + values = append(values, p.Namespace) + } + + if p.Severity == nil { + sql += " AND severity is NULL" + } else { + columnCount++ + sql += fmt.Sprintf(" AND severity=$%d", columnCount) + values = append(values, p.Severity) + } return sql, values } @@ -311,36 +388,22 @@ func (p *Policy) GetOrCreate(ctx context.Context, db *sql.DB) error { return getOrCreate(ctx, db, p) } -func (p *Policy) key() policyKey { - key := policyKey{ - Kind: p.Kind, - APIGroup: p.APIGroup, - Name: p.Name, - } +func (p *Policy) key() string { + var namespace string if p.Namespace != nil { - key.Namespace = *p.Namespace + namespace = *p.Namespace } - if p.SpecHash != "" { - key.SpecHash = p.SpecHash - } + var severity string if p.Severity != nil { - key.Severity = *p.Severity + severity = *p.Severity } - return key -} - -type policyKey struct { - Kind string - APIGroup string - Name string - Namespace string - ParentID string - SpecHash string - Severity string + // Note that as of Go 1.20, it sorts the keys in the underlying map of p.Spec, which is why this is deterministic. + // https://github.com/golang/go/blob/97c8ff8d53759e7a82b1862403df1694f2b6e073/src/fmt/print.go#L816-L828 + return fmt.Sprintf("%s;%s;%s;%v;%v;%v", p.APIGroup, p.Kind, p.Name, namespace, severity, p.Spec) } type JSONMap map[string]interface{} diff --git a/controllers/complianceeventsapi/types_test.go b/controllers/complianceeventsapi/types_test.go index 5610fc60..baf96bb5 100644 --- a/controllers/complianceeventsapi/types_test.go +++ b/controllers/complianceeventsapi/types_test.go @@ -96,11 +96,7 @@ func TestParentPolicyValidation(t *testing.T) { } func TestPolicyValidation(t *testing.T) { - basespec := `{"test":"one","severity":"low"}` - basehash := "cb84fe29e44202e3aeb46d39ba46993f60cdc6af" - badhash := "foobarbaz" - badspec := `{foo: bar: baz` - noncompactspec := `{"foo" : "bar" }` + var basespec JSONMap = map[string]interface{}{"test": "one", "severity": "low"} tests := map[string]struct { obj Policy @@ -111,16 +107,14 @@ func TestPolicyValidation(t *testing.T) { Kind: "policy", APIGroup: "v1", Spec: basespec, - SpecHash: basehash, }, "field not provided: policy.name", }, "no API group": { Policy{ - Kind: "policy", - Name: "foobar", - Spec: basespec, - SpecHash: basehash, + Kind: "policy", + Name: "foobar", + Spec: basespec, }, "field not provided: policy.apiGroup", }, @@ -129,7 +123,6 @@ func TestPolicyValidation(t *testing.T) { APIGroup: "v1", Name: "foobar", Spec: basespec, - SpecHash: basehash, }, "field not provided: policy.kind", }, @@ -139,35 +132,7 @@ func TestPolicyValidation(t *testing.T) { APIGroup: "v1", Name: "foobar", }, - "field not provided: policy.spec or policy.specHash", - }, - "not valid json": { - Policy{ - Kind: "policy", - APIGroup: "v1", - Name: "foobar", - Spec: badspec, - }, - "policy.spec is not valid JSON", - }, - "not compact json": { - Policy{ - Kind: "policy", - APIGroup: "v1", - Name: "foobar", - Spec: noncompactspec, - }, - "policy.spec is not compact JSON", - }, - "not matching hash": { - Policy{ - Kind: "policy", - APIGroup: "v1", - Name: "foobar", - Spec: basespec, - SpecHash: badhash, - }, - "policy.specHash does not match the compact policy.Spec", + "field not provided: policy.spec", }, } diff --git a/test/e2e/case18_compliance_api_test.go b/test/e2e/case18_compliance_api_test.go index 7342ce84..0d108b46 100644 --- a/test/e2e/case18_compliance_api_test.go +++ b/test/e2e/case18_compliance_api_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net/http" + "time" "github.com/lib/pq" . "github.com/onsi/ginkgo/v2" @@ -20,6 +21,10 @@ import ( "open-cluster-management.io/governance-policy-propagator/controllers/complianceeventsapi" ) +const eventsEndpoint = "http://localhost:5480/api/v1/compliance-events" + +var httpClient = http.Client{Timeout: 30 * time.Second} + func getTableNames(db *sql.DB) ([]string, error) { tableNameRows, err := db.Query("SELECT tablename FROM pg_tables WHERE schemaname = current_schema()") if err != nil { @@ -47,7 +52,7 @@ func getTableNames(db *sql.DB) ([]string, error) { } // Note: These tests require a running Postgres server running in the Kind cluster from the "postgres" Make target. -var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, func() { +var _ = Describe("Test the compliance events API", Label("compliance-events-api"), Ordered, func() { var k8sConfig *rest.Config var db *sql.DB @@ -133,7 +138,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "kind": "ConfigurationPolicy", "name": "etcd-encryption1", "namespace": "local-cluster", - "spec": "{\"test\":\"one\",\"severity\":\"low\"}", + "spec": {"test": "one", "severity": "low"}, "severity": "low" }, "event": { @@ -214,12 +219,11 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, apiGroup string name string ns *string - spec *string - specHash *string + spec complianceeventsapi.JSONMap severity *string ) - err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &specHash, &severity) + err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &severity) Expect(err).ToNot(HaveOccurred()) Expect(id).NotTo(Equal(0)) @@ -228,9 +232,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, Expect(ns).ToNot(BeNil()) Expect(*ns).To(Equal("local-cluster")) Expect(spec).ToNot(BeNil()) - Expect(*spec).To(Equal(`{"test":"one","severity":"low"}`)) - Expect(specHash).ToNot(BeNil()) - Expect(*specHash).To(Equal("cb84fe29e44202e3aeb46d39ba46993f60cdc6af")) + Expect(spec).To(BeEquivalentTo(map[string]any{"test": "one", "severity": "low"})) Expect(severity).ToNot(BeNil()) Expect(*severity).To(Equal("low")) @@ -263,11 +265,11 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, &metadata, &reportedBy) Expect(err).ToNot(HaveOccurred()) - Expect(id).NotTo(Equal(0)) - Expect(clusterID).NotTo(Equal(0)) - Expect(policyID).NotTo(Equal(0)) + Expect(id).To(Equal(1)) + Expect(clusterID).To(Equal(1)) + Expect(policyID).To(Equal(1)) Expect(parentPolicyID).NotTo(BeNil()) - Expect(*parentPolicyID).NotTo(Equal(0)) + Expect(*parentPolicyID).To(Equal(1)) Expect(compliance).To(Equal("NonCompliant")) Expect(message).To(Equal("configmaps [etcd] not found in namespace default")) Expect(timestamp).To(Equal("2023-01-01T01:01:01.111Z")) @@ -291,7 +293,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "etcd-encryption2", - "spec": "{\"test\":\"two\"}" + "spec": {"test": "two"} }, "event": { "compliance": "NonCompliant", @@ -309,7 +311,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "etcd-encryption2", - "spec": "{\"different-spec-test\":\"two-and-a-half\"}" + "spec": {"different-spec-test": "two-and-a-half"} }, "event": { "compliance": "Compliant", @@ -339,7 +341,6 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, err := rows.Scan(&id, &name, &clusterID) Expect(err).ToNot(HaveOccurred()) - Expect(id).NotTo(Equal(0)) clusternames = append(clusternames, name) } @@ -350,7 +351,8 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, rows, err := db.Query("SELECT * FROM policies WHERE name = $1", "etcd-encryption2") Expect(err).ToNot(HaveOccurred()) - hashes := make([]string, 0) + rowCount := 0 + for rows.Next() { var ( id int @@ -358,23 +360,18 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, apiGroup string name string ns *string - spec *string - specHash *string + spec complianceeventsapi.JSONMap severity *string ) - err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &specHash, &severity) + err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &severity) Expect(err).ToNot(HaveOccurred()) - Expect(id).NotTo(Equal(0)) - Expect(specHash).ToNot(BeNil()) - hashes = append(hashes, *specHash) + rowCount++ + Expect(id).To(Equal(1 + rowCount)) } - Expect(hashes).To(ConsistOf( - "8cfd1ee0a4b10aadaa4e4f3b2b9ec15e6616c1e5", - "2c6c7170351bfaa98eb45453b93766c18d24fa04", - )) + Expect(rowCount).To(Equal(2)) }) It("Should have created both events in a table", func() { @@ -400,12 +397,12 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, &metadata, &reportedBy) Expect(err).ToNot(HaveOccurred()) + messages = append(messages, message) + Expect(id).NotTo(Equal(0)) Expect(clusterID).NotTo(Equal(0)) - Expect(policyID).NotTo(Equal(0)) + Expect(policyID).To(Equal(1 + len(messages))) Expect(parentPolicyID).To(BeNil()) - - messages = append(messages, message) } Expect(messages).To(ConsistOf( @@ -433,7 +430,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "common", - "spec": "{\"test\":\"three\",\"severity\":\"low\"}", + "spec": {"test": "three", "severity": "low"}, "severity": "low" }, "event": { @@ -443,25 +440,17 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, } }`) - // payload2 just uses the specHash for the policy. + // payload2 just uses the ids for the policy and parent_policy. payload2 := []byte(`{ "cluster": { "name": "cluster4", "cluster_id": "test3-cluster4-fake-uuid-4" }, "parent_policy": { - "name": "common-parent", - "namespace": "policies", - "categories": ["cat-3", "cat-4"], - "controls": ["ctrl-2"], - "standards": ["stand-2"] + "id": 2 }, "policy": { - "apiGroup": "policy.open-cluster-management.io", - "kind": "ConfigurationPolicy", - "name": "common", - "specHash": "5382228c69c6017d4efbd6e42717930cb2020da0", - "severity": "low" + "id": 4 }, "event": { "compliance": "NonCompliant", @@ -528,7 +517,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, rows, err := db.Query("SELECT * FROM policies WHERE name = $1", "common") Expect(err).ToNot(HaveOccurred()) - hashes := make([]string, 0) + specs := make([]complianceeventsapi.JSONMap, 0, 1) for rows.Next() { var ( id int @@ -536,22 +525,19 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, apiGroup string name string ns *string - spec *string - specHash *string + spec complianceeventsapi.JSONMap severity *string ) - err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &specHash, &severity) + err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &severity) Expect(err).ToNot(HaveOccurred()) Expect(id).NotTo(Equal(0)) - Expect(specHash).ToNot(BeNil()) - hashes = append(hashes, *specHash) + specs = append(specs, spec) } - Expect(hashes).To(ConsistOf( - "5382228c69c6017d4efbd6e42717930cb2020da0", - )) + Expect(specs).To(HaveLen(1)) + Expect(specs[0]).To(BeEquivalentTo(map[string]any{"test": "three", "severity": "low"})) }) It("Should have created both events in a table", func() { @@ -579,9 +565,9 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, Expect(id).NotTo(Equal(0)) Expect(clusterID).NotTo(Equal(0)) - Expect(policyID).NotTo(Equal(0)) + Expect(policyID).To(Equal(4)) Expect(parentPolicyID).NotTo(BeNil()) - Expect(*parentPolicyID).NotTo(Equal(0)) + Expect(*parentPolicyID).To(Equal(2)) timestamps = append(timestamps, timestamp) } @@ -609,7 +595,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "common-a", - "spec": "{\"test\":\"four\",\"severity\":\"low\"}", + "spec": {"test": "four", "severity": "low"}, "severity": "low" }, "event": { @@ -634,7 +620,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "common-a", - "spec": "{\"test\":\"four\",\"severity\":\"low\"}", + "spec": {"test": "four", "severity": "low"}, "severity": "low" }, "event": { @@ -660,7 +646,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "common-a", - "spec": "{\"test\":\"four\",\"severity\":\"low\"}", + "spec": {"test": "four", "severity": "low"}, "severity": "low" }, "event": { @@ -719,16 +705,14 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, apiGroup string name string ns *string - spec *string - specHash string + spec complianceeventsapi.JSONMap severity *string ) - err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &specHash, &severity) + err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &severity) Expect(err).ToNot(HaveOccurred()) Expect(id).NotTo(Equal(0)) - Expect(specHash).ToNot(BeNil()) ids = append(ids, id) } @@ -751,7 +735,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "common-b", - "spec": "{\"test\":\"four\",\"severity\":\"low\"}", + "spec": {"test": "four", "severity": "low"}, "severity": "low", "namespace": "default" }, @@ -776,7 +760,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "common-b", - "spec": "{\"test\":\"four\",\"severity\":\"low\"}", + "spec": {"test": "four", "severity": "low"}, "severity": "low" }, "event": { @@ -825,7 +809,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, ids := make([]int, 0) names := make([]string, 0) namespaces := make([]string, 0) - hashes := make([]string, 0) + specs := make([]complianceeventsapi.JSONMap, 0, 2) for rows.Next() { var ( id int @@ -833,19 +817,17 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, apiGroup string name string ns *string - spec *string - specHash string + spec complianceeventsapi.JSONMap severity *string ) - err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &specHash, &severity) + err := rows.Scan(&id, &kind, &apiGroup, &name, &ns, &spec, &severity) Expect(err).ToNot(HaveOccurred()) Expect(id).NotTo(Equal(0)) - Expect(specHash).ToNot(BeNil()) ids = append(ids, id) names = append(names, name) - hashes = append(hashes, specHash) + specs = append(specs, spec) if ns != nil { namespaces = append(namespaces, *ns) @@ -856,7 +838,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, Expect(ids[0]).ToNot(Equal(ids[1])) Expect(names[0]).To(Equal(names[1])) Expect(namespaces).To(ConsistOf("default")) - Expect(hashes[0]).To(Equal(hashes[1])) + Expect(specs[0]).To(Equal(specs[1])) }) }) @@ -871,7 +853,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "validity", - "spec": "{\"test\":\"validity\",\"severity\":\"low\"}" + "spec": {"test":"validity", "severity": "low"} }, "event": { "compliance": "Compliant", @@ -894,7 +876,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "validity", - "spec": "{\"test\":\"validity\",\"severity\":\"low\"}", + "spec": {"test":"validity", "severity": "low"}, "severity": "low" }, "event": { @@ -919,7 +901,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "validity", - "spec": "{\"test\":\"validity\",\"severity\":\"low\"}", + "spec": {"test": "validity", "severity": "low"}, "severity": "low" }, "event": { @@ -929,32 +911,6 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, }`)), "5s", "1s").Should(MatchError(ContainSubstring("Got non-201 status code 400"))) }) - It("should require the policy spec and hash to match", func(ctx context.Context) { - Eventually(postEvent(ctx, []byte(`{ - "cluster": { - "name": "validity-test", - "cluster_id": "test-validity-fake-uuid" - }, - "parent_policy": { - "name": "validity-parent", - "namespace": "policies" - }, - "policy": { - "apiGroup": "policy.open-cluster-management.io", - "kind": "ConfigurationPolicy", - "name": "validity", - "spec": "{\"test\":\"validity\",\"severity\":\"low\"}", - "severity": "low", - "specHash": "foobar" - }, - "event": { - "compliance": "Compliant", - "message": "configmaps [valid] valid in namespace valid", - "timestamp": "2023-09-09T09:09:09.999Z" - } - }`)), "5s", "1s").Should(MatchError(ContainSubstring("Got non-201 status code 400"))) - }) - It("should require the input to be valid JSON", func(ctx context.Context) { Eventually(postEvent(ctx, []byte(`{ foo: bar: baz @@ -970,7 +926,7 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "apiGroup": "policy.open-cluster-management.io", "kind": "ConfigurationPolicy", "name": "validity", - "spec": "{\"test\":\"validity\",\"severity\":\"low\"}", + "spec": {"test": "validity", "severity": "low"}, "severity": "low", "specHash": "foobar" }, @@ -989,14 +945,10 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "cluster_id": "test-validity-fake-uuid" }, "parent_policy": { - "name": "validity-parent", - "namespace": "policies" + "id": 1231234 }, "policy": { - "apiGroup": "policy.open-cluster-management.io", - "kind": "ConfigurationPolicy", - "name": "validity", - "specHash": "0123456789abcdefzzzzzzzzzzzzzzzzzzzzzzzz" + "id": 123123 }, "event": { "compliance": "Compliant", @@ -1004,16 +956,91 @@ var _ = Describe("Test policy webhook", Label("compliance-events-api"), Ordered, "timestamp": "2023-09-09T09:09:09.999Z" } }`)), "5s", "1s").Should(MatchError(ContainSubstring( - "could not determine the spec from the provided spec hash; the spec is required in the request", + `invalid input: parent_policy.id not found\\ninvalid input: policy.id not found`, ))) }) }) }) }) +var _ = Describe("Test query generation", Label("compliance-events-api"), func() { + It("Tests the select query for a cluster", func() { + cluster := complianceeventsapi.Cluster{ + ClusterID: "my-cluster-id", + Name: "my-cluster", + } + sql, vals := cluster.SelectQuery("id", "spec") + Expect(sql).To(Equal("SELECT id, spec FROM clusters WHERE cluster_id=$1 AND name=$2")) + Expect(vals).To(HaveLen(2)) + }) + + It("Tests the select query for a minimum parent policy", func() { + parent := complianceeventsapi.ParentPolicy{ + Name: "parent-a", + Namespace: "policies", + } + sql, vals := parent.SelectQuery("id", "spec") + Expect(sql).To(Equal( + "SELECT id, spec FROM parent_policies WHERE name=$1 AND namespace=$2 AND categories IS NULL AND " + + "controls IS NULL AND standards IS NULL", + )) + Expect(vals).To(HaveLen(2)) + }) + + It("Tests the select query for a parent policy with all options", func() { + parent := complianceeventsapi.ParentPolicy{ + Name: "parent-a", + Namespace: "policies", + Categories: pq.StringArray{"cat-1"}, + Controls: pq.StringArray{"control-1", "control-2"}, + Standards: pq.StringArray{"standard-1"}, + } + sql, vals := parent.SelectQuery("id") + Expect(sql).To(Equal( + "SELECT id FROM parent_policies WHERE name=$1 AND namespace=$2 AND categories=$3 AND controls=$4 " + + "AND standards=$5", + )) + Expect(vals).To(HaveLen(5)) + }) + + It("Tests the select query for a minimum policy", func() { + policy := complianceeventsapi.Policy{ + Name: "parent-a", + Kind: "ConfigurationPolicy", + APIGroup: "policy.open-cluster-management.io", + Spec: complianceeventsapi.JSONMap{"spec": "this-out"}, + } + sql, vals := policy.SelectQuery("id") + Expect(sql).To(Equal( + "SELECT id FROM policies WHERE api_group=$1 AND kind=$2 AND name=$3 AND spec=$4 AND namespace is NULL " + + "AND severity is NULL", + )) + Expect(vals).To(HaveLen(4)) + }) + + It("Tests the select query for a policy with all options", func() { + ns := "policies" + severity := "critical" + + policy := complianceeventsapi.Policy{ + Name: "parent-a", + Namespace: &ns, + Kind: "ConfigurationPolicy", + APIGroup: "policy.open-cluster-management.io", + Spec: complianceeventsapi.JSONMap{"spec": "this-out"}, + Severity: &severity, + } + sql, vals := policy.SelectQuery("id") + Expect(sql).To(Equal( + "SELECT id FROM policies WHERE api_group=$1 AND kind=$2 AND name=$3 AND spec=$4 AND namespace=$5 " + + "AND severity=$6", + )) + Expect(vals).To(HaveLen(6)) + }) +}) + func postEvent(ctx context.Context, payload []byte) error { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, - "http://localhost:5480/api/v1/compliance-events", bytes.NewBuffer(payload)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, eventsEndpoint, bytes.NewBuffer(payload)) if err != nil { return err } @@ -1022,9 +1049,7 @@ func postEvent(ctx context.Context, payload []byte) error { errs := make([]error, 0) - client := &http.Client{} - - resp, err := client.Do(req) + resp, err := httpClient.Do(req) if err != nil { errs = append(errs, err) }