diff --git a/association.go b/association.go index f0f1d7be..507e3aad 100644 --- a/association.go +++ b/association.go @@ -264,6 +264,7 @@ func createClientWithContext(ctx context.Context, config Config) (*Association, select { case <-ctx.Done(): a.log.Errorf("[%s] client handshake canceled: state=%s", a.name, getAssociationStateString(a.getState())) + a.Close() // nolint:errcheck,gosec return nil, ctx.Err() case err := <-a.handshakeCompletedCh: if err != nil { diff --git a/association_test.go b/association_test.go index d7d71150..38108233 100644 --- a/association_test.go +++ b/association_test.go @@ -2997,3 +2997,64 @@ func TestAssociation_Abort(t *testing.T) { assert.Equal(t, i, 0, "expected no data read") assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason") } + +// TestAssociation_createClientWithContext tests that the client is closed when the context is canceled. +func TestAssociation_createClientWithContext(t *testing.T) { + checkGoroutineLeaks(t) + + udp1, udp2, err := createUDPConnPair(t) + require.NoError(t, err) + + loggerFactory := logging.NewDefaultLoggerFactory() + + errCh1 := make(chan error) + errCh2 := make(chan error) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + + go func() { + _, err2 := createClientWithContext(ctx, Config{ + NetConn: udp1, + LoggerFactory: loggerFactory, + }) + if err2 != nil { + errCh1 <- err2 + } else { + errCh1 <- nil + } + }() + + go func() { + _, err2 := createClientWithContext(ctx, Config{ + NetConn: udp2, + LoggerFactory: loggerFactory, + }) + if err2 != nil { + errCh2 <- err2 + } else { + errCh2 <- nil + } + }() + + // Cancel the context immediately + cancel() + + var err1 error + var err2 error +loop: + for { + select { + case err1 = <-errCh1: + if err1 != nil && err2 != nil { + break loop + } + case err2 = <-errCh2: + if err1 != nil && err2 != nil { + break loop + } + } + } + + assert.Error(t, err1, "context canceled") + assert.Error(t, err2, "context canceled") +}