From 1e4f1aeb839d18d0d0786eec83ee3db9acfa8601 Mon Sep 17 00:00:00 2001 From: "peter.reisinger" Date: Fri, 2 Aug 2024 10:39:49 +0100 Subject: [PATCH] execute args in go routines --- README.md | 4 +++ main.go | 60 ++++++++++++++++++--------------------- pkg/cert/location.go | 28 +++++++++--------- pkg/cert/location_test.go | 15 ++++++---- print.go | 4 +++ 5 files changed, 58 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 8725656..37b7385 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,10 @@ certinfo [flags] [| ...] +---------------+---------------------------------------------------------------------------------------------------+ ``` +If you need to run against multiple hosts, it is faster to execute command with multiple arguments e.g. +`certinfo -insecure -expiry google.com:443 amazon.com:443 ...` rather than executing command multiple times. Args are +executed concurrently and much faster. + Flags can be set as env. variable as well (`CERTINFO_=true` e.g. `CERTINFO_INSECURE=true`) and can be then overridden with a flag. diff --git a/main.go b/main.go index 67cb9be..6e1be72 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "os" "strconv" "strings" + "sync" ) var Version = "dev" @@ -45,40 +46,15 @@ func LoadCertificatesLocations(flags Flags) cert.CertificateLocations { var certificateLocations cert.CertificateLocations if flags.Clipboard { - certificateLocation, err := cert.LoadCertificateFromClipboard() - if err != nil { - printCertFileError("clipboard", err) - return nil - } - certificateLocations = append(certificateLocations, certificateLocation) + certificateLocations = append(certificateLocations, cert.LoadCertificateFromClipboard()) } if len(flags.Args) > 0 { - for _, arg := range flags.Args { - - var certificateLocation cert.CertificateLocation - var err error - if isTCPNetworkAddress(arg) { - certificateLocation, err = cert.LoadCertificatesFromNetwork(arg, flags.Insecure) - } else { - certificateLocation, err = cert.LoadCertificatesFromFile(arg) - } - - if err != nil { - printCertFileError(arg, err) - continue - } - certificateLocations = append(certificateLocations, certificateLocation) - } + certificateLocations = append(certificateLocations, loadFromArgs(flags.Args, flags.Insecure)...) } if isStdin() { - certificateLocation, err := cert.LoadCertificateFromStdin() - if err != nil { - printCertFileError("stdin", err) - return nil - } - certificateLocations = append(certificateLocations, certificateLocation) + certificateLocations = append(certificateLocations, cert.LoadCertificateFromStdin()) } if len(certificateLocations) > 0 { @@ -91,11 +67,31 @@ func LoadCertificatesLocations(flags Flags) cert.CertificateLocations { return nil } -func printCertFileError(fileName string, err error) { +func loadFromArgs(args []string, insecure bool) cert.CertificateLocations { + + out := make(chan cert.CertificateLocation) + go func() { + var wg sync.WaitGroup + for _, arg := range args { + wg.Add(1) + go func() { + defer wg.Done() + if isTCPNetworkAddress(arg) { + out <- cert.LoadCertificatesFromNetwork(arg, insecure) + return + } + out <- cert.LoadCertificatesFromFile(arg) + }() + } + wg.Wait() + close(out) + }() - fmt.Printf("--- [%s] ---\n", fileName) - fmt.Println(err) - fmt.Println() + var certificateLocations cert.CertificateLocations + for location := range out { + certificateLocations = append(certificateLocations, location) + } + return certificateLocations } func isTCPNetworkAddress(arg string) bool { diff --git a/pkg/cert/location.go b/pkg/cert/location.go index ab045fe..26c2f3c 100644 --- a/pkg/cert/location.go +++ b/pkg/cert/location.go @@ -3,7 +3,6 @@ package cert import ( "bytes" "crypto/tls" - "errors" "fmt" "golang.design/x/clipboard" "io" @@ -53,11 +52,11 @@ func (c CertificateLocation) RemoveDuplicates() CertificateLocation { return c } -func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) (CertificateLocation, error) { +func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) CertificateLocation { conn, err := tls.DialWithDialer(&net.Dialer{Timeout: tlsDialTimeout}, "tcp", addr, &tls.Config{InsecureSkipVerify: tlsSkipVerify}) if err != nil { - return CertificateLocation{}, fmt.Errorf("tcp connection failed: %w", err) + return CertificateLocation{Path: fmt.Sprintf("%s: %v", addr, err)} } connectionState := conn.ConnectionState() @@ -73,52 +72,51 @@ func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) (CertificateLo Path: addr, Certificates: FromX509Certificates(x509Certificates), VerifiedChains: verifiedChains, - }, nil + } } -func LoadCertificatesFromFile(fileName string) (CertificateLocation, error) { +func LoadCertificatesFromFile(fileName string) CertificateLocation { b, err := os.ReadFile(fileName) if err != nil { - return CertificateLocation{}, fmt.Errorf("skipping %s file: %w", fileName, err) + return CertificateLocation{Path: fmt.Sprintf("%s: %v", fileName, err)} } return loadCertificate(fileName, b) } -func LoadCertificateFromStdin() (CertificateLocation, error) { +func LoadCertificateFromStdin() CertificateLocation { content, err := io.ReadAll(os.Stdin) if err != nil { - return CertificateLocation{}, fmt.Errorf("reading stdin: %w", err) + return CertificateLocation{Path: fmt.Sprintf("stdin: %v", err)} } return loadCertificate("stdin", content) } -func LoadCertificateFromClipboard() (CertificateLocation, error) { +func LoadCertificateFromClipboard() CertificateLocation { if err := clipboard.Init(); err != nil { - return CertificateLocation{}, fmt.Errorf("unable to load from clipboard: %w", err) + return CertificateLocation{Path: fmt.Sprintf("clipboard: %v", err)} } content := clipboard.Read(clipboard.FmtText) if content == nil { - return CertificateLocation{}, errors.New("clipboard is empty") + return CertificateLocation{Path: "clipboard is empty"} } - return loadCertificate("clipboard", content) } -func loadCertificate(fileName string, data []byte) (CertificateLocation, error) { +func loadCertificate(fileName string, data []byte) CertificateLocation { certificates, err := FromBytes(bytes.TrimSpace(data)) if err != nil { - return CertificateLocation{}, fmt.Errorf("file %s: %w", fileName, err) + return CertificateLocation{Path: fmt.Sprintf("%s: %v", fileName, err)} } return CertificateLocation{ Path: fileName, Certificates: certificates, - }, nil + } } func nameFormat(name string, tlsVersion uint16) string { diff --git a/pkg/cert/location_test.go b/pkg/cert/location_test.go index 11d11aa..3a96620 100644 --- a/pkg/cert/location_test.go +++ b/pkg/cert/location_test.go @@ -30,15 +30,17 @@ func Test_nameFormat(t *testing.T) { func Test_loadCertificate(t *testing.T) { t.Run("given valid certificate then cert location is loaded", func(t *testing.T) { certificate := loadTestFile(t, "cert.pem") - _, err := loadCertificate("test", certificate) - require.NoError(t, err) + cert := loadCertificate("test", certificate) + require.Equal(t, 1, len(cert.Certificates)) + assert.Equal(t, "CN=DigiCert Global Root G2,OU=www.digicert.com,O=DigiCert Inc,C=US", cert.Certificates[0].SubjectString()) }) t.Run("given certificate with extra spaces then cert location is loaded", func(t *testing.T) { certificate := loadTestFile(t, "cert.pem") certificate = bytes.Join([][]byte{[]byte(" "), certificate}, []byte("")) - _, err := loadCertificate("test", certificate) - require.NoError(t, err) + cert := loadCertificate("test", certificate) + require.Equal(t, 1, len(cert.Certificates)) + assert.Equal(t, "CN=DigiCert Global Root G2,OU=www.digicert.com,O=DigiCert Inc,C=US", cert.Certificates[0].SubjectString()) }) } @@ -51,7 +53,8 @@ func Test_loadCertificateFromClipboard(t *testing.T) { certificate := loadTestFile(t, "cert.pem") clipboard.Write(clipboard.FmtText, certificate) - _, err := LoadCertificateFromClipboard() - require.NoError(t, err) + cert := LoadCertificateFromClipboard() + require.Equal(t, 1, len(cert.Certificates)) + assert.Equal(t, "CN=DigiCert Global Root G2,OU=www.digicert.com,O=DigiCert Inc,C=US", cert.Certificates[0].SubjectString()) }) } diff --git a/print.go b/print.go index 41fb63f..2b1a399 100644 --- a/print.go +++ b/print.go @@ -57,6 +57,10 @@ func PrintCertificatesExpiry(certificateLocations []cert.CertificateLocation) { for _, certificateLocation := range certificateLocations { fmt.Printf("--- [%s] ---\n", certificateLocation.Name()) + if len(certificateLocation.Certificates) == 0 { + // in case of error (no certificates), print new line + fmt.Println() + } for _, certificate := range certificateLocation.Certificates { fmt.Printf("Subject: %s\n", certificate.SubjectString())