Skip to content
This repository has been archived by the owner on Oct 2, 2022. It is now read-only.

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Janos Pasztor committed Nov 11, 2020
1 parent f5b8055 commit 292feea
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 72 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
type Client interface {
// Post queries the configured endpoint with the path, sending the requestBody and providing the
// response in the responseBody structure. It returns the HTTP status code and any potential errors.
Post(path string, requestBody interface{}, responseBody interface{}) (int, error)
Post(path string, requestBody interface{}, responseBody interface{}) (statusCode int, err error)
}

// ClientConfiguration is the configuration structure for HTTP clients
Expand Down
1 change: 1 addition & 0 deletions client_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func (c *client) request(
if err != nil {
return 0, err
}
defer func() { _ = resp.Body.Close() }()

decoder := json.NewDecoder(resp.Body)
decoder.DisallowUnknownFields()
Expand Down
8 changes: 6 additions & 2 deletions handler_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (h *handler) ServeHTTP(goWriter goHttp.ResponseWriter, goRequest *goHttp.Re
&response,
); err != nil {
if errors.Is(err, &badRequestResponse) {
response = *err.(*serverResponse)
response = badRequestResponse
} else {
h.logger.Warningf("handler returned error response (%w)", err)
response = internalErrorResponse
Expand All @@ -64,7 +64,11 @@ func (h *handler) ServeHTTP(goWriter goHttp.ResponseWriter, goRequest *goHttp.Re
bytes, err := json.Marshal(response.body)
if err != nil {
h.logger.Errorf("failed to marshal response %v (%w)", response, err)
bytes, _ = json.Marshal(internalErrorResponse)
response = internalErrorResponse
bytes, err = json.Marshal(internalErrorResponse.body)
if err != nil {
panic(fmt.Errorf("bug: failed to marshal internal server error JSON response (%w)", err))
}
}
goWriter.WriteHeader(int(response.statusCode))
goWriter.Header().Add("Content-Type", "application/json")
Expand Down
45 changes: 13 additions & 32 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"fmt"
"math/big"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -262,11 +261,11 @@ func createCA() (*rsa.PrivateKey, *x509.Certificate, []byte, error) {
}
caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create private key (%v)", err)
return nil, nil, nil, fmt.Errorf("failed to create private key (%w)", err)
}
caCert, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create CA certificate (%v)", err)
return nil, nil, nil, fmt.Errorf("failed to create CA certificate (%w)", err)
}
caPEM := new(bytes.Buffer)
if err := pem.Encode(
Expand All @@ -276,7 +275,7 @@ func createCA() (*rsa.PrivateKey, *x509.Certificate, []byte, error) {
Bytes: caCert,
},
); err != nil {
return nil, nil, nil, fmt.Errorf("failed to encode CA cert (%v)", err)
return nil, nil, nil, fmt.Errorf("failed to encode CA cert (%w)", err)
}
return caPrivateKey, ca, caPEM.Bytes(), nil
}
Expand Down Expand Up @@ -355,42 +354,24 @@ func runRequest(
}

errorChannel := make(chan error, 2)
clientDone := make(chan bool, 1)
responseStatus := 0
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
if err := server.Run(); err != nil {
errorChannel <- err
}
close(errorChannel)
}()
<-ready
go func() {
defer wg.Done()
if responseStatus, err = client.Post(
"",
&Request{Message: message},
&response,
); err != nil {
errorChannel <- err
}
clientDone <- true
}()
<-clientDone
if responseStatus, err = client.Post(
"",
&Request{Message: message},
&response,
); err != nil {
errorChannel <- err
}
server.Shutdown(context.Background())
wg.Wait()
finished := false
for {
select {
case err := <-errorChannel:
return response, 0, err
default:
finished = true
}
if finished {
break
}
if err, ok := <-errorChannel; ok {
return response, 0, err
}
return response, responseStatus, nil
}
81 changes: 45 additions & 36 deletions server_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,43 +23,10 @@ func NewServer(

var tlsConfig *tls.Config
if config.Cert != "" && config.Key != "" {
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP521, tls.CurveP384, tls.CurveP256},
PreferServerCipherSuites: true,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_AES_256_GCM_SHA384,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
}

clientCert, err := loadPem(config.Cert)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate (%w)", err)
}
clientKey, err := loadPem(config.Key)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate (%w)", err)
}
cert, err := tls.X509KeyPair(clientCert, clientKey)
var err error
tlsConfig, err = createServerTlsConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to load certificate or key (%w)", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}

if config.ClientCaCert != "" {
clientCaCert, err := loadPem(config.ClientCaCert)
if err != nil {
return nil, fmt.Errorf("failed to load CA certificate (%w)", err)
}

caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(clientCaCert)
tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
return nil, err
}
} else {
if config.Cert != "" {
Expand All @@ -84,3 +51,45 @@ func NewServer(
onReady: onReady,
}, nil
}

func createServerTlsConfig(config ServerConfiguration) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP521, tls.CurveP384, tls.CurveP256},
PreferServerCipherSuites: true,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_AES_256_GCM_SHA384,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
}

clientCert, err := loadPem(config.Cert)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate (%w)", err)
}
clientKey, err := loadPem(config.Key)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate (%w)", err)
}
cert, err := tls.X509KeyPair(clientCert, clientKey)
if err != nil {
return nil, fmt.Errorf("failed to load certificate or key (%w)", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}

if config.ClientCaCert != "" {
clientCaCert, err := loadPem(config.ClientCaCert)
if err != nil {
return nil, fmt.Errorf("failed to load CA certificate (%w)", err)
}

caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(clientCaCert)
tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
return tlsConfig, nil
}
3 changes: 2 additions & 1 deletion server_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -56,7 +57,7 @@ func (s *server) Run() error {
} else {
err = s.srv.Serve(ln)
}
if err != nil && err != goHttp.ErrServerClosed {
if err != nil && !errors.Is(err, goHttp.ErrServerClosed) {
return err
}
return nil
Expand Down

0 comments on commit 292feea

Please sign in to comment.