Skip to content

Commit

Permalink
added sort by expiry flag
Browse files Browse the repository at this point in the history
  • Loading branch information
pete911 committed Aug 2, 2024
1 parent 1e4f1ae commit 4ca345b
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ certinfo [flags] [<file>|<host:port> ...]
| -no-expired | do not print expired certificates |
| -pem | whether to print pem as well |
| -pem-only | whether to print only pem (useful for downloading certs from host) |
| -sort-expiry | sort certificates by expiration date |
| -version | certinfo version |
| -help | help |
+---------------+---------------------------------------------------------------------------------------------------+
Expand Down
3 changes: 3 additions & 0 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Flags struct {
Expiry bool
NoDuplicate bool
NoExpired bool
SortExpiry bool
Insecure bool
Chains bool
Pem bool
Expand All @@ -33,6 +34,8 @@ func ParseFlags() (Flags, error) {
"do not print duplicate certificates")
flagSet.BoolVar(&flags.NoExpired, "no-expired", getBoolEnv("CERTINFO_NO_EXPIRED", false),
"do not print expired certificates")
flagSet.BoolVar(&flags.SortExpiry, "sort-expiry", getBoolEnv("CERTINFO_SORT_EXPIRY", false),
"sort certificates by expiration date")
flagSet.BoolVar(&flags.Insecure, "insecure", getBoolEnv("CERTINFO_INSECURE", false),
"whether a client verifies the server's certificate chain and host name (only applicable for host)")
flagSet.BoolVar(&flags.Chains, "chains", getBoolEnv("CERTINFO_CHAINS", false),
Expand Down
16 changes: 13 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ func main() {
if flags.NoDuplicate {
certificatesFiles = certificatesFiles.RemoveDuplicates()
}
if flags.SortExpiry {
certificatesFiles = certificatesFiles.SortByExpiry()
}
if flags.Expiry {
PrintCertificatesExpiry(certificatesFiles)
return
Expand Down Expand Up @@ -87,11 +90,18 @@ func loadFromArgs(args []string, insecure bool) cert.CertificateLocations {
close(out)
}()

var certificateLocations cert.CertificateLocations
// load certificates from the channel
certsByArgs := make(map[string]cert.CertificateLocation)
for location := range out {
certificateLocations = append(certificateLocations, location)
certsByArgs[location.Path] = location
}

// sort certificates by input arguments
var certsSortedByArgs cert.CertificateLocations
for _, arg := range args {
certsSortedByArgs = append(certsSortedByArgs, certsByArgs[arg])
}
return certificateLocations
return certsSortedByArgs
}

func isTCPNetworkAddress(arg string) bool {
Expand Down
8 changes: 8 additions & 0 deletions pkg/cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"github.com/icza/gox/timex"
"slices"
"strings"
"time"
)
Expand Down Expand Up @@ -39,6 +40,13 @@ func (c Certificates) RemoveDuplicates() Certificates {
return out
}

func (c Certificates) SortByExpiry() Certificates {
slices.SortFunc(c, func(a, b Certificate) int {
return a.x509Certificate.NotAfter.Compare(b.x509Certificate.NotAfter)
})
return c
}

type Certificate struct {
// position of certificate in the chain, starts with 1
position int
Expand Down
21 changes: 21 additions & 0 deletions pkg/cert/cert_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package cert

import (
"crypto/x509"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"strings"
"testing"
"time"
)

func TestFromBytes(t *testing.T) {
Expand Down Expand Up @@ -34,6 +36,25 @@ func TestCertificates_RemoveDuplicates(t *testing.T) {
})
}

func TestCertificates_SortByExpiry(t *testing.T) {
t.Run("given multiple certificates, when they have different expiry, then they are sorted", func(t *testing.T) {
certificates := Certificates{
// using version to validate tests
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(0, 6, 3), Version: 1}},
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(1, 6, 2), Version: 3}},
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(1, 6, 21), Version: 4}},
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(1, 3, 3), Version: 2}},
}

sortedCertificates := certificates.SortByExpiry()
require.Equal(t, 4, len(sortedCertificates))
assert.Equal(t, 1, sortedCertificates[0].x509Certificate.Version)
assert.Equal(t, 2, sortedCertificates[1].x509Certificate.Version)
assert.Equal(t, 3, sortedCertificates[2].x509Certificate.Version)
assert.Equal(t, 4, sortedCertificates[3].x509Certificate.Version)
})
}

func Test_expiryFormat(t *testing.T) {
t.Run("given certificate expiry is more than a year then year is returned as well", func(t *testing.T) {
v := expiryFormat(getTime(3, 2, 7, 5, 25))
Expand Down
43 changes: 37 additions & 6 deletions pkg/cert/location.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package cert
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"golang.design/x/clipboard"
"io"
"net"
"os"
"slices"
"time"
)

Expand All @@ -31,9 +33,33 @@ func (c CertificateLocations) RemoveDuplicates() CertificateLocations {
return out
}

func (c CertificateLocations) SortByExpiry() CertificateLocations {
var out CertificateLocations
// sort certificates in every location
for i := range c {
out = append(out, c[i].SortByExpiry())
}

// sort locations by first certificate (they have been already sorted)
slices.SortFunc(out, func(a, b CertificateLocation) int {
if len(a.Certificates) == 0 && len(b.Certificates) == 0 {
return 0
}
if len(a.Certificates) == 0 {
return 1
}
if len(b.Certificates) == 0 {
return -1
}
return a.Certificates[0].x509Certificate.NotAfter.Compare(b.Certificates[0].x509Certificate.NotAfter)
})
return out
}

type CertificateLocation struct {
TLSVersion uint16 // only applicable for network certificates
Path string
Error error
Certificates Certificates
VerifiedChains []Certificates // only applicable for network certificates
}
Expand All @@ -52,11 +78,16 @@ func (c CertificateLocation) RemoveDuplicates() CertificateLocation {
return c
}

func (c CertificateLocation) SortByExpiry() CertificateLocation {
c.Certificates = c.Certificates.SortByExpiry()
return c
}

func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) CertificateLocation {

conn, err := tls.DialWithDialer(&net.Dialer{Timeout: tlsDialTimeout}, "tcp", addr, &tls.Config{InsecureSkipVerify: tlsSkipVerify})
if err != nil {
return CertificateLocation{Path: fmt.Sprintf("%s: %v", addr, err)}
return CertificateLocation{Path: addr, Error: err}
}

connectionState := conn.ConnectionState()
Expand All @@ -79,7 +110,7 @@ func LoadCertificatesFromFile(fileName string) CertificateLocation {

b, err := os.ReadFile(fileName)
if err != nil {
return CertificateLocation{Path: fmt.Sprintf("%s: %v", fileName, err)}
return CertificateLocation{Path: fileName, Error: err}
}
return loadCertificate(fileName, b)
}
Expand All @@ -88,20 +119,20 @@ func LoadCertificateFromStdin() CertificateLocation {

content, err := io.ReadAll(os.Stdin)
if err != nil {
return CertificateLocation{Path: fmt.Sprintf("stdin: %v", err)}
return CertificateLocation{Path: "stdin", Error: err}
}
return loadCertificate("stdin", content)
}

func LoadCertificateFromClipboard() CertificateLocation {

if err := clipboard.Init(); err != nil {
return CertificateLocation{Path: fmt.Sprintf("clipboard: %v", err)}
return CertificateLocation{Path: "clipboard", Error: err}
}

content := clipboard.Read(clipboard.FmtText)
if content == nil {
return CertificateLocation{Path: "clipboard is empty"}
return CertificateLocation{Path: "clipboard", Error: errors.New("clipboard is empty")}
}
return loadCertificate("clipboard", content)
}
Expand All @@ -110,7 +141,7 @@ func loadCertificate(fileName string, data []byte) CertificateLocation {

certificates, err := FromBytes(bytes.TrimSpace(data))
if err != nil {
return CertificateLocation{Path: fmt.Sprintf("%s: %v", fileName, err)}
return CertificateLocation{Path: fileName, Error: err}
}

return CertificateLocation{
Expand Down
40 changes: 40 additions & 0 deletions pkg/cert/location_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package cert
import (
"bytes"
"crypto/tls"
"crypto/x509"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -58,3 +60,41 @@ func Test_loadCertificateFromClipboard(t *testing.T) {
assert.Equal(t, "CN=DigiCert Global Root G2,OU=www.digicert.com,O=DigiCert Inc,C=US", cert.Certificates[0].SubjectString())
})
}

func TestCertificateLocation_SortByExpiry(t *testing.T) {
t.Run("given valid certificate in clipboard then cert is loaded", func(t *testing.T) {
locations := CertificateLocations{
{
Path: "three",
Certificates: Certificates{
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(3, 2, 3)}},
},
},
{
Path: "one",
Certificates: Certificates{
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(1, 6, 2)}},
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(1, 6, 21)}},
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(0, 6, 3)}},
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(1, 3, 3)}},
},
},
{
Path: "four",
},
{
Path: "two",
Certificates: Certificates{
{x509Certificate: &x509.Certificate{NotAfter: time.Now().AddDate(0, 7, 3)}},
},
},
}

sortedLocations := locations.SortByExpiry()
require.Equal(t, 4, len(sortedLocations))
assert.Equal(t, "one", sortedLocations[0].Path)
assert.Equal(t, "two", sortedLocations[1].Path)
assert.Equal(t, "three", sortedLocations[2].Path)
assert.Equal(t, "four", sortedLocations[3].Path)
})
}
14 changes: 11 additions & 3 deletions print.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import (
func PrintCertificatesLocations(certificateLocations []cert.CertificateLocation, printChains, printPem bool) {

for _, certificateLocation := range certificateLocations {
if certificateLocation.Error != nil {
fmt.Printf("--- [%s: %v] ---\n", certificateLocation.Name(), certificateLocation.Error)
fmt.Println()
continue
}

fmt.Printf("--- [%s] ---\n", certificateLocation.Name())
printCertificates(certificateLocation.Certificates, printPem)

Expand Down Expand Up @@ -56,11 +62,13 @@ func PrintPemOnly(certificateLocations []cert.CertificateLocation, printChains b
func PrintCertificatesExpiry(certificateLocations []cert.CertificateLocation) {

for _, certificateLocation := range certificateLocations {
fmt.Printf("--- [%s] ---\n", certificateLocation.Name())
if len(certificateLocation.Certificates) == 0 {
// in case of error (no certificates), print new line
if certificateLocation.Error != nil {
fmt.Printf("--- [%s: %v] ---\n", certificateLocation.Name(), certificateLocation.Error)
fmt.Println()
continue
}

fmt.Printf("--- [%s] ---\n", certificateLocation.Name())
for _, certificate := range certificateLocation.Certificates {

fmt.Printf("Subject: %s\n", certificate.SubjectString())
Expand Down

0 comments on commit 4ca345b

Please sign in to comment.