Skip to content

Commit

Permalink
Add an API endpoint to generate a CSV compliance report
Browse files Browse the repository at this point in the history
Ref: https://issues.redhat.com/browse/ACM-6884
Signed-off-by: Yi Rae Kim <[email protected]>
(cherry picked from commit 6efa4c0)
  • Loading branch information
yiraeChristineKim authored and Magic Mirror committed Feb 14, 2024
1 parent 670993f commit a8f7c05
Show file tree
Hide file tree
Showing 3 changed files with 594 additions and 31 deletions.
277 changes: 246 additions & 31 deletions controllers/complianceeventsapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/csv"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -238,6 +239,19 @@ func (s *ComplianceAPIServer) Start(ctx context.Context, serverContext *Complian
getSingleComplianceEvent(serverContext.DB, w, r)
})

mux.HandleFunc("/api/v1/reports/compliance-events", func(w http.ResponseWriter, r *http.Request) {
// This header is for error writings
w.Header().Set("Content-Type", "application/json")

if r.Method != http.MethodGet {
writeErrMsgJSON(w, "Method not allowed", http.StatusMethodNotAllowed)

return
}

getComplianceEventsCSV(serverContext.DB, w, r)
})

serveErr := make(chan error)

go func() {
Expand Down Expand Up @@ -309,7 +323,7 @@ func splitQueryValue(value string) []string {

// parseQueryArgs will parse the HTTP request's query arguments and convert them to a usable format for constructing
// the SQL query. All defaults are set and any invalid query arguments result in an error being returned.
func parseQueryArgs(queryArgs url.Values) (*queryOptions, error) {
func parseQueryArgs(queryArgs url.Values, isCSV bool) (*queryOptions, error) {
parsed := &queryOptions{
Direction: "desc",
Page: 1,
Expand All @@ -320,6 +334,11 @@ func parseQueryArgs(queryArgs url.Values) (*queryOptions, error) {
NullFilters: []string{},
}

// Case return CSV file, default PerPage is 0. Unlimited
if isCSV {
parsed.PerPage = 0
}

for arg := range queryArgs {
valid := false

Expand Down Expand Up @@ -433,6 +452,17 @@ func parseQueryArgs(queryArgs url.Values) (*queryOptions, error) {
// generateGetComplianceEventsQuery will return a SELECT query with results ready to be parsed by
// scanIntoComplianceEvent. The caller is responsible for adding filters to the query.
func generateGetComplianceEventsQuery(includeSpec bool) string {
return fmt.Sprintf(`SELECT %s
FROM
compliance_events
LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id
LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id
LEFT JOIN policies ON compliance_events.policy_id = policies.id`,
strings.Join(generateSelectedArgs(includeSpec), ", "),
)
}

func generateSelectedArgs(includeSpec bool) []string {
selectArgs := []string{
"compliance_events.id",
"compliance_events.compliance",
Expand Down Expand Up @@ -460,14 +490,19 @@ func generateGetComplianceEventsQuery(includeSpec bool) string {
selectArgs = append(selectArgs, "policies.spec")
}

return fmt.Sprintf(`SELECT %s
FROM
compliance_events
LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id
LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id
LEFT JOIN policies ON compliance_events.policy_id = policies.id`,
strings.Join(selectArgs, ", "),
)
return selectArgs
}

// generate Headers for CSV. "." replace by "_"
// Example: parent_policies.namespace -> parent_policies_namespace
func getCsvHeader(includeSpec bool) []string {
localSelectArgs := generateSelectedArgs(includeSpec)

for i, arg := range localSelectArgs {
localSelectArgs[i] = strings.ReplaceAll(arg, ".", "_")
}

return localSelectArgs
}

type Scannable interface {
Expand Down Expand Up @@ -672,7 +707,7 @@ func getWhereClause(options *queryOptions) (string, []any) {

// getComplianceEvents handles the list API endpoint for compliance events.
func getComplianceEvents(db *sql.DB, w http.ResponseWriter, r *http.Request) {
queryArgs, err := parseQueryArgs(r.URL.Query())
queryArgs, err := parseQueryArgs(r.URL.Query(), false)
if err != nil {
writeErrMsgJSON(w, err.Error(), http.StatusBadRequest)

Expand All @@ -682,27 +717,7 @@ func getComplianceEvents(db *sql.DB, w http.ResponseWriter, r *http.Request) {
// Note that the where clause could be an empty string if not filters were passed in the query arguments.
whereClause, filterValues := getWhereClause(queryArgs)

// Example query:
// SELECT compliance_events.id, compliance_events.compliance, ...
// FROM compliance_events
// LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id
// LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id
// LEFT JOIN policies ON compliance_events.policy_id = policies.id
// WHERE (policies.name=$1 OR policies.name=$2) AND (policies.kind=$3)
// ORDER BY compliance_events.timestamp desc
// LIMIT 20
// OFFSET 0 ROWS;
query := fmt.Sprintf(`%s%s
ORDER BY %s %s
LIMIT %d
OFFSET %d ROWS;`,
generateGetComplianceEventsQuery(queryArgs.IncludeSpec),
whereClause,
strings.Join(queryArgs.Sort, ", "),
queryArgs.Direction,
queryArgs.PerPage,
(queryArgs.Page-1)*queryArgs.PerPage,
)
query := getComplianceEventsQuery(whereClause, queryArgs)

rows, err := db.QueryContext(r.Context(), query, filterValues...)
if err == nil {
Expand Down Expand Up @@ -887,6 +902,206 @@ func postComplianceEvent(db *sql.DB,
}
}

func getComplianceEventsQuery(whereClause string, queryArgs *queryOptions) string {
// Getting CSV without the page argument
// Query should fetch all rows (unlimited)
if queryArgs.PerPage == 0 {
return fmt.Sprintf(`%s%s
ORDER BY %s %s;`,
generateGetComplianceEventsQuery(queryArgs.IncludeSpec),
whereClause,
strings.Join(queryArgs.Sort, ", "),
queryArgs.Direction,
)
}
// Example query
// SELECT compliance_events.id, compliance_events.compliance, ...
// FROM compliance_events
// LEFT JOIN clusters ON compliance_events.cluster_id = clusters.id
// LEFT JOIN parent_policies ON compliance_events.parent_policy_id = parent_policies.id
// LEFT JOIN policies ON compliance_events.policy_id = policies.id
// WHERE (policies.name=$1 OR policies.name=$2) AND (policies.kind=$3)
// ORDER BY compliance_events.timestamp desc
// LIMIT 20
// OFFSET 0 ROWS;
return fmt.Sprintf(`%s%s
ORDER BY %s %s
LIMIT %d
OFFSET %d ROWS;`,
generateGetComplianceEventsQuery(queryArgs.IncludeSpec),
whereClause,
strings.Join(queryArgs.Sort, ", "),
queryArgs.Direction,
queryArgs.PerPage,
(queryArgs.Page-1)*queryArgs.PerPage,
)
}

func getComplianceEventsCSV(db *sql.DB, w http.ResponseWriter, r *http.Request) {
queryArgs, err := parseQueryArgs(r.URL.Query(), true)
if err != nil {
writeErrMsgJSON(w, err.Error(), http.StatusBadRequest)

return
}

// Note that the where clause could be an empty string if no filters were passed in the query arguments.
whereClause, filterValues := getWhereClause(queryArgs)

query := getComplianceEventsQuery(whereClause, queryArgs)

rows, err := db.QueryContext(r.Context(), query, filterValues...)
if err == nil {
err = rows.Err()
}

if err != nil {
log.Error(err, "Failed to query for compliance events")
writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

return
}

defer rows.Close()

headers := getCsvHeader(queryArgs.IncludeSpec)

writer := csv.NewWriter(w)

err = writer.Write(headers)
if err != nil {
log.Error(err, "Failed to write csv header")
writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

return
}

for rows.Next() {
ce, err := scanIntoComplianceEvent(rows, queryArgs.IncludeSpec)
if err != nil {
log.Error(err, "Failed to unmarshal the database results")
writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

return
}

stringValues := convertToCsvLine(ce, queryArgs.IncludeSpec)

err = writer.Write(stringValues)
if err != nil {
log.Error(err, "Failed to write csv list")
writeErrMsgJSON(w, "Internal Error", http.StatusInternalServerError)

return
}
}

w.Header().Set("Content-Disposition", "attachment; filename=reports.csv")
w.Header().Set("Content-Type", "text/csv")
// It's going to be divided into chunks. if the user don't get it all at once,
// the user can receive one by one in the meantime
w.Header().Set("Transfer-Encoding", "chunked")

writer.Flush()
}

func convertToCsvLine(ce *ComplianceEvent, includeSpec bool) []string {
nilString := ""

if ce.ParentPolicy == nil {
ce.ParentPolicy = &ParentPolicy{
KeyID: 0,
Name: "",
Namespace: "",
Categories: nil,
Controls: nil,
Standards: nil,
}
}

if ce.Event.ReportedBy == nil {
ce.Event.ReportedBy = &nilString
}

if ce.Policy.Severity == nil {
ce.Policy.Severity = &nilString
}

if ce.Policy.Namespace == nil {
ce.Policy.Namespace = &nilString
}

values := []string{
convertToString(ce.EventID),
convertToString(ce.Event.Compliance),
convertToString(ce.Event.Message),
convertToString(ce.Event.Metadata),
convertToString(*ce.Event.ReportedBy),
convertToString(ce.Event.Timestamp),
convertToString(ce.Cluster.ClusterID),
convertToString(ce.Cluster.Name),
convertToString(ce.ParentPolicy.KeyID),
convertToString(ce.ParentPolicy.Name),
convertToString(ce.ParentPolicy.Namespace),
convertToString(ce.ParentPolicy.Categories),
convertToString(ce.ParentPolicy.Controls),
convertToString(ce.ParentPolicy.Standards),
convertToString(ce.Policy.KeyID),
convertToString(ce.Policy.APIGroup),
convertToString(ce.Policy.Kind),
convertToString(ce.Policy.Name),
convertToString(*ce.Policy.Namespace),
convertToString(*ce.Policy.Severity),
}

if includeSpec {
values = append(values, convertToString(ce.Policy.Spec))
}

return values
}

func convertToString(v interface{}) string {
switch vv := v.(type) {
case *string:
if vv == nil {
return ""
}

return *vv
case string:
return vv
case int32:
// All int32 related id
if int(vv) == 0 {
return ""
}

return strconv.Itoa(int(vv))
case time.Time:
return vv.String()
case pq.StringArray:
// nil will be []
return strings.Join(vv, ", ")
case bool:
return strconv.FormatBool(vv)
case JSONMap:
if vv == nil {
return ""
}

jsonByte, err := json.MarshalIndent(vv, "", " ")
if err != nil {
return ""
}

return string(jsonByte)
default:
// case nil:
return fmt.Sprintf("%v", vv)
}
}

func getClusterForeignKey(ctx context.Context, db *sql.DB, cluster Cluster) (int32, error) {
// Check cache
key, ok := clusterKeyCache.Load(cluster.ClusterID)
Expand Down
Loading

0 comments on commit a8f7c05

Please sign in to comment.