diff --git a/client.go b/client.go index efda859..ae8dbde 100644 --- a/client.go +++ b/client.go @@ -9,7 +9,7 @@ import ( type Client interface { // Post queries the configured endpoint with the path, sending the requestBody and providing the // response in the responseBody structure. It returns the HTTP status code and any potential errors. - Post(path string, requestBody interface{}, responseBody interface{}) (int, error) + Post(path string, requestBody interface{}, responseBody interface{}) (statusCode int, err error) } // ClientConfiguration is the configuration structure for HTTP clients diff --git a/client_impl.go b/client_impl.go index 49244db..93da421 100644 --- a/client_impl.go +++ b/client_impl.go @@ -52,6 +52,7 @@ func (c *client) request( if err != nil { return 0, err } + defer func() { _ = resp.Body.Close() }() decoder := json.NewDecoder(resp.Body) decoder.DisallowUnknownFields() diff --git a/handler_impl.go b/handler_impl.go index 4798883..79c5706 100644 --- a/handler_impl.go +++ b/handler_impl.go @@ -55,7 +55,7 @@ func (h *handler) ServeHTTP(goWriter goHttp.ResponseWriter, goRequest *goHttp.Re &response, ); err != nil { if errors.Is(err, &badRequestResponse) { - response = *err.(*serverResponse) + response = badRequestResponse } else { h.logger.Warningf("handler returned error response (%w)", err) response = internalErrorResponse @@ -64,7 +64,11 @@ func (h *handler) ServeHTTP(goWriter goHttp.ResponseWriter, goRequest *goHttp.Re bytes, err := json.Marshal(response.body) if err != nil { h.logger.Errorf("failed to marshal response %v (%w)", response, err) - bytes, _ = json.Marshal(internalErrorResponse) + response = internalErrorResponse + bytes, err = json.Marshal(internalErrorResponse.body) + if err != nil { + panic(fmt.Errorf("bug: failed to marshal internal server error JSON response (%w)", err)) + } } goWriter.WriteHeader(int(response.statusCode)) goWriter.Header().Add("Content-Type", "application/json") diff --git a/integration_test.go b/integration_test.go index 0a6f3c8..20f8927 100644 --- a/integration_test.go +++ b/integration_test.go @@ -11,7 +11,6 @@ import ( "fmt" "math/big" "net" - "sync" "testing" "time" @@ -262,11 +261,11 @@ func createCA() (*rsa.PrivateKey, *x509.Certificate, []byte, error) { } caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to create private key (%v)", err) + return nil, nil, nil, fmt.Errorf("failed to create private key (%w)", err) } caCert, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivateKey.PublicKey, caPrivateKey) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to create CA certificate (%v)", err) + return nil, nil, nil, fmt.Errorf("failed to create CA certificate (%w)", err) } caPEM := new(bytes.Buffer) if err := pem.Encode( @@ -276,7 +275,7 @@ func createCA() (*rsa.PrivateKey, *x509.Certificate, []byte, error) { Bytes: caCert, }, ); err != nil { - return nil, nil, nil, fmt.Errorf("failed to encode CA cert (%v)", err) + return nil, nil, nil, fmt.Errorf("failed to encode CA cert (%w)", err) } return caPrivateKey, ca, caPEM.Bytes(), nil } @@ -355,42 +354,24 @@ func runRequest( } errorChannel := make(chan error, 2) - clientDone := make(chan bool, 1) responseStatus := 0 - wg := &sync.WaitGroup{} - wg.Add(2) go func() { - defer wg.Done() if err := server.Run(); err != nil { errorChannel <- err } + close(errorChannel) }() <-ready - go func() { - defer wg.Done() - if responseStatus, err = client.Post( - "", - &Request{Message: message}, - &response, - ); err != nil { - errorChannel <- err - } - clientDone <- true - }() - <-clientDone + if responseStatus, err = client.Post( + "", + &Request{Message: message}, + &response, + ); err != nil { + errorChannel <- err + } server.Shutdown(context.Background()) - wg.Wait() - finished := false - for { - select { - case err := <-errorChannel: - return response, 0, err - default: - finished = true - } - if finished { - break - } + if err, ok := <-errorChannel; ok { + return response, 0, err } return response, responseStatus, nil } diff --git a/server_factory.go b/server_factory.go index de955cf..0e895bf 100644 --- a/server_factory.go +++ b/server_factory.go @@ -23,43 +23,10 @@ func NewServer( var tlsConfig *tls.Config if config.Cert != "" && config.Key != "" { - tlsConfig = &tls.Config{ - MinVersion: tls.VersionTLS13, - CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP521, tls.CurveP384, tls.CurveP256}, - PreferServerCipherSuites: true, - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_AES_128_GCM_SHA256, - tls.TLS_AES_256_GCM_SHA384, - tls.TLS_CHACHA20_POLY1305_SHA256, - }, - } - - clientCert, err := loadPem(config.Cert) - if err != nil { - return nil, fmt.Errorf("failed to load client certificate (%w)", err) - } - clientKey, err := loadPem(config.Key) - if err != nil { - return nil, fmt.Errorf("failed to load client certificate (%w)", err) - } - cert, err := tls.X509KeyPair(clientCert, clientKey) + var err error + tlsConfig, err = createServerTlsConfig(config) if err != nil { - return nil, fmt.Errorf("failed to load certificate or key (%w)", err) - } - tlsConfig.Certificates = []tls.Certificate{cert} - - if config.ClientCaCert != "" { - clientCaCert, err := loadPem(config.ClientCaCert) - if err != nil { - return nil, fmt.Errorf("failed to load CA certificate (%w)", err) - } - - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(clientCaCert) - tlsConfig.ClientCAs = caCertPool - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + return nil, err } } else { if config.Cert != "" { @@ -84,3 +51,45 @@ func NewServer( onReady: onReady, }, nil } + +func createServerTlsConfig(config ServerConfiguration) (*tls.Config, error) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP521, tls.CurveP384, tls.CurveP256}, + PreferServerCipherSuites: true, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + } + + clientCert, err := loadPem(config.Cert) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate (%w)", err) + } + clientKey, err := loadPem(config.Key) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate (%w)", err) + } + cert, err := tls.X509KeyPair(clientCert, clientKey) + if err != nil { + return nil, fmt.Errorf("failed to load certificate or key (%w)", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + + if config.ClientCaCert != "" { + clientCaCert, err := loadPem(config.ClientCaCert) + if err != nil { + return nil, fmt.Errorf("failed to load CA certificate (%w)", err) + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(clientCaCert) + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + return tlsConfig, nil +} diff --git a/server_impl.go b/server_impl.go index 9445d00..7de0609 100644 --- a/server_impl.go +++ b/server_impl.go @@ -3,6 +3,7 @@ package http import ( "context" "crypto/tls" + "errors" "fmt" "io" "log" @@ -56,7 +57,7 @@ func (s *server) Run() error { } else { err = s.srv.Serve(ln) } - if err != nil && err != goHttp.ErrServerClosed { + if err != nil && !errors.Is(err, goHttp.ErrServerClosed) { return err } return nil