Skip to content

Commit

Permalink
execute args in go routines
Browse files Browse the repository at this point in the history
  • Loading branch information
pete911 committed Aug 2, 2024
1 parent b1e05d7 commit 1e4f1ae
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 53 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ certinfo [flags] [<file>|<host:port> ...]
+---------------+---------------------------------------------------------------------------------------------------+
```

If you need to run against multiple hosts, it is faster to execute command with multiple arguments e.g.
`certinfo -insecure -expiry google.com:443 amazon.com:443 ...` rather than executing command multiple times. Args are
executed concurrently and much faster.

Flags can be set as env. variable as well (`CERTINFO_<FLAG>=true` e.g. `CERTINFO_INSECURE=true`) and can be then
overridden with a flag.

Expand Down
60 changes: 28 additions & 32 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"strconv"
"strings"
"sync"
)

var Version = "dev"
Expand Down Expand Up @@ -45,40 +46,15 @@ func LoadCertificatesLocations(flags Flags) cert.CertificateLocations {

var certificateLocations cert.CertificateLocations
if flags.Clipboard {
certificateLocation, err := cert.LoadCertificateFromClipboard()
if err != nil {
printCertFileError("clipboard", err)
return nil
}
certificateLocations = append(certificateLocations, certificateLocation)
certificateLocations = append(certificateLocations, cert.LoadCertificateFromClipboard())
}

if len(flags.Args) > 0 {
for _, arg := range flags.Args {

var certificateLocation cert.CertificateLocation
var err error
if isTCPNetworkAddress(arg) {
certificateLocation, err = cert.LoadCertificatesFromNetwork(arg, flags.Insecure)
} else {
certificateLocation, err = cert.LoadCertificatesFromFile(arg)
}

if err != nil {
printCertFileError(arg, err)
continue
}
certificateLocations = append(certificateLocations, certificateLocation)
}
certificateLocations = append(certificateLocations, loadFromArgs(flags.Args, flags.Insecure)...)
}

if isStdin() {
certificateLocation, err := cert.LoadCertificateFromStdin()
if err != nil {
printCertFileError("stdin", err)
return nil
}
certificateLocations = append(certificateLocations, certificateLocation)
certificateLocations = append(certificateLocations, cert.LoadCertificateFromStdin())
}

if len(certificateLocations) > 0 {
Expand All @@ -91,11 +67,31 @@ func LoadCertificatesLocations(flags Flags) cert.CertificateLocations {
return nil
}

func printCertFileError(fileName string, err error) {
func loadFromArgs(args []string, insecure bool) cert.CertificateLocations {

out := make(chan cert.CertificateLocation)
go func() {
var wg sync.WaitGroup
for _, arg := range args {
wg.Add(1)
go func() {
defer wg.Done()
if isTCPNetworkAddress(arg) {
out <- cert.LoadCertificatesFromNetwork(arg, insecure)
return
}
out <- cert.LoadCertificatesFromFile(arg)
}()
}
wg.Wait()
close(out)
}()

fmt.Printf("--- [%s] ---\n", fileName)
fmt.Println(err)
fmt.Println()
var certificateLocations cert.CertificateLocations
for location := range out {
certificateLocations = append(certificateLocations, location)
}
return certificateLocations
}

func isTCPNetworkAddress(arg string) bool {
Expand Down
28 changes: 13 additions & 15 deletions pkg/cert/location.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cert
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"golang.design/x/clipboard"
"io"
Expand Down Expand Up @@ -53,11 +52,11 @@ func (c CertificateLocation) RemoveDuplicates() CertificateLocation {
return c
}

func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) (CertificateLocation, error) {
func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) CertificateLocation {

conn, err := tls.DialWithDialer(&net.Dialer{Timeout: tlsDialTimeout}, "tcp", addr, &tls.Config{InsecureSkipVerify: tlsSkipVerify})
if err != nil {
return CertificateLocation{}, fmt.Errorf("tcp connection failed: %w", err)
return CertificateLocation{Path: fmt.Sprintf("%s: %v", addr, err)}
}

connectionState := conn.ConnectionState()
Expand All @@ -73,52 +72,51 @@ func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) (CertificateLo
Path: addr,
Certificates: FromX509Certificates(x509Certificates),
VerifiedChains: verifiedChains,
}, nil
}
}

func LoadCertificatesFromFile(fileName string) (CertificateLocation, error) {
func LoadCertificatesFromFile(fileName string) CertificateLocation {

b, err := os.ReadFile(fileName)
if err != nil {
return CertificateLocation{}, fmt.Errorf("skipping %s file: %w", fileName, err)
return CertificateLocation{Path: fmt.Sprintf("%s: %v", fileName, err)}
}
return loadCertificate(fileName, b)
}

func LoadCertificateFromStdin() (CertificateLocation, error) {
func LoadCertificateFromStdin() CertificateLocation {

content, err := io.ReadAll(os.Stdin)
if err != nil {
return CertificateLocation{}, fmt.Errorf("reading stdin: %w", err)
return CertificateLocation{Path: fmt.Sprintf("stdin: %v", err)}
}
return loadCertificate("stdin", content)
}

func LoadCertificateFromClipboard() (CertificateLocation, error) {
func LoadCertificateFromClipboard() CertificateLocation {

if err := clipboard.Init(); err != nil {
return CertificateLocation{}, fmt.Errorf("unable to load from clipboard: %w", err)
return CertificateLocation{Path: fmt.Sprintf("clipboard: %v", err)}
}

content := clipboard.Read(clipboard.FmtText)
if content == nil {
return CertificateLocation{}, errors.New("clipboard is empty")
return CertificateLocation{Path: "clipboard is empty"}
}

return loadCertificate("clipboard", content)
}

func loadCertificate(fileName string, data []byte) (CertificateLocation, error) {
func loadCertificate(fileName string, data []byte) CertificateLocation {

certificates, err := FromBytes(bytes.TrimSpace(data))
if err != nil {
return CertificateLocation{}, fmt.Errorf("file %s: %w", fileName, err)
return CertificateLocation{Path: fmt.Sprintf("%s: %v", fileName, err)}
}

return CertificateLocation{
Path: fileName,
Certificates: certificates,
}, nil
}
}

func nameFormat(name string, tlsVersion uint16) string {
Expand Down
15 changes: 9 additions & 6 deletions pkg/cert/location_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@ func Test_nameFormat(t *testing.T) {
func Test_loadCertificate(t *testing.T) {
t.Run("given valid certificate then cert location is loaded", func(t *testing.T) {
certificate := loadTestFile(t, "cert.pem")
_, err := loadCertificate("test", certificate)
require.NoError(t, err)
cert := loadCertificate("test", certificate)
require.Equal(t, 1, len(cert.Certificates))
assert.Equal(t, "CN=DigiCert Global Root G2,OU=www.digicert.com,O=DigiCert Inc,C=US", cert.Certificates[0].SubjectString())
})

t.Run("given certificate with extra spaces then cert location is loaded", func(t *testing.T) {
certificate := loadTestFile(t, "cert.pem")
certificate = bytes.Join([][]byte{[]byte(" "), certificate}, []byte(""))
_, err := loadCertificate("test", certificate)
require.NoError(t, err)
cert := loadCertificate("test", certificate)
require.Equal(t, 1, len(cert.Certificates))
assert.Equal(t, "CN=DigiCert Global Root G2,OU=www.digicert.com,O=DigiCert Inc,C=US", cert.Certificates[0].SubjectString())
})
}

Expand All @@ -51,7 +53,8 @@ func Test_loadCertificateFromClipboard(t *testing.T) {
certificate := loadTestFile(t, "cert.pem")
clipboard.Write(clipboard.FmtText, certificate)

_, err := LoadCertificateFromClipboard()
require.NoError(t, err)
cert := LoadCertificateFromClipboard()
require.Equal(t, 1, len(cert.Certificates))
assert.Equal(t, "CN=DigiCert Global Root G2,OU=www.digicert.com,O=DigiCert Inc,C=US", cert.Certificates[0].SubjectString())
})
}
4 changes: 4 additions & 0 deletions print.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func PrintCertificatesExpiry(certificateLocations []cert.CertificateLocation) {

for _, certificateLocation := range certificateLocations {
fmt.Printf("--- [%s] ---\n", certificateLocation.Name())
if len(certificateLocation.Certificates) == 0 {
// in case of error (no certificates), print new line
fmt.Println()
}
for _, certificate := range certificateLocation.Certificates {

fmt.Printf("Subject: %s\n", certificate.SubjectString())
Expand Down

0 comments on commit 1e4f1ae

Please sign in to comment.