From 698aab1d431273664d0fd3c374e9c340443fed02 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Mon, 23 Dec 2024 23:02:30 +0300 Subject: [PATCH] * Fixed goroutine leak on failed execute call in query client --- CHANGELOG.md | 2 ++ internal/query/execute_query.go | 12 +++++++-- internal/query/result.go | 44 +++++++++++++++++++++++---------- internal/query/result_test.go | 2 +- 4 files changed, 44 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 75cd71df7..1ef91f105 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Fixed goroutine leak on failed execute call in query client + ## v3.95.4 * Fixed connections pool leak on closing sessions * Fixed an error in logging session deletion events diff --git a/internal/query/execute_query.go b/internal/query/execute_query.go index 36c5cfcbd..577d7f8d7 100644 --- a/internal/query/execute_query.go +++ b/internal/query/execute_query.go @@ -126,14 +126,22 @@ func execute( return nil, xerrors.WithStackTrace(err) } - executeCtx := xcontext.ValueOnly(ctx) + executeCtx, executeCancel := xcontext.WithCancel(xcontext.ValueOnly(ctx)) + defer func() { + if finalErr != nil { + executeCancel() + } + }() stream, err := c.ExecuteQuery(executeCtx, request, callOptions...) if err != nil { return nil, xerrors.WithStackTrace(err) } - r, err := newResult(ctx, stream, append(opts, withStatsCallback(settings.StatsCallback()))...) + r, err := newResult(ctx, stream, append(opts, + withStatsCallback(settings.StatsCallback()), + withOnClose(executeCancel), + )...) if err != nil { return nil, xerrors.WithStackTrace(err) } diff --git a/internal/query/result.go b/internal/query/result.go index ede423443..750710829 100644 --- a/internal/query/result.go +++ b/internal/query/result.go @@ -33,12 +33,13 @@ type ( } streamResult struct { stream Ydb_Query_V1.QueryService_ExecuteQueryClient - closeOnce func() + close func() lastPart *Ydb_Query.ExecuteQueryResponsePart resultSetIndex int64 closed chan struct{} trace *trace.Query statsCallback func(queryStats stats.QueryStats) + onClose []func() onNextPartErr []func(err error) onTxMeta []func(txMeta *Ydb_Query.TransactionMeta) } @@ -98,6 +99,12 @@ func withStatsCallback(callback func(queryStats stats.QueryStats)) resultOption } } +func withOnClose(onClose func()) resultOption { + return func(s *streamResult) { + s.onClose = append(s.onClose, onClose) + } +} + func onNextPartErr(callback func(err error)) resultOption { return func(s *streamResult) { s.onNextPartErr = append(s.onNextPartErr, callback) @@ -115,15 +122,19 @@ func newResult( stream Ydb_Query_V1.QueryService_ExecuteQueryClient, opts ...resultOption, ) (_ *streamResult, finalErr error) { - r := streamResult{ - stream: stream, - closed: make(chan struct{}), - resultSetIndex: -1, - } - r.closeOnce = sync.OnceFunc(func() { - close(r.closed) - r.stream = nil - }) + var ( + closed = make(chan struct{}) + r = streamResult{ + stream: stream, + onClose: []func(){ + func() { + close(closed) + }, + }, + closed: closed, + resultSetIndex: -1, + } + ) for _, opt := range opts { if opt != nil { @@ -131,6 +142,13 @@ func newResult( } } + r.close = sync.OnceFunc(func() { + for _, onClose := range r.onClose { + onClose() + } + r.stream = nil + }) + if r.trace != nil { onDone := trace.QueryOnResultNew(r.trace, &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.newResult"), @@ -177,7 +195,7 @@ func (r *streamResult) nextPart(ctx context.Context) ( default: part, err = nextPart(r.stream) if err != nil { - r.closeOnce() + r.close() for _, callback := range r.onNextPartErr { callback(err) @@ -208,7 +226,7 @@ func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) ( } func (r *streamResult) Close(ctx context.Context) (finalErr error) { - defer r.closeOnce() + defer r.close() if r.trace != nil { onDone := trace.QueryOnResultClose(r.trace, &ctx, @@ -261,7 +279,7 @@ func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err err r.statsCallback(stats.FromQueryStats(part.GetExecStats())) } if part.GetResultSetIndex() < r.resultSetIndex { - r.closeOnce() + r.close() if part.GetResultSetIndex() <= 0 && r.resultSetIndex > 0 { return nil, xerrors.WithStackTrace(io.EOF) } diff --git a/internal/query/result_test.go b/internal/query/result_test.go index 98bfb41da..5baee9000 100644 --- a/internal/query/result_test.go +++ b/internal/query/result_test.go @@ -539,7 +539,7 @@ func TestResultNextResultSet(t *testing.T) { require.EqualValues(t, 1, rs.rowIndex) } t.Log("explicit interrupt stream") - r.closeOnce() + r.close() { t.Log("next (row=3)") _, err := rs.nextRow(context.Background())