From 38a0423c85602c3fb2468a4c902971b337d53983 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Wed, 3 Jul 2024 13:15:00 +0300 Subject: [PATCH] refactoring of internal/pool --- balancers/balancers.go | 20 +- config/config.go | 7 - driver.go | 79 ++-- internal/balancer/balancer.go | 207 +++++------ internal/balancer/balancer_test.go | 124 ------- internal/balancer/config/routerconfig.go | 4 +- internal/balancer/connections_state.go | 170 +++++---- internal/balancer/connections_state_test.go | 46 +-- internal/balancer/local_dc_test.go | 2 +- internal/conn/config.go | 4 - internal/conn/conn.go | 354 ++++-------------- internal/conn/conn_test.go | 2 +- internal/conn/error.go | 25 -- internal/conn/error_test.go | 79 ---- internal/conn/grpc_client_stream.go | 103 +++--- internal/conn/pool.go | 253 ------------- internal/endpoint/diff.go | 6 +- internal/endpoint/diff_test.go | 2 +- internal/pool/defaults.go | 8 - internal/pool/pool.go | 383 ++++++-------------- internal/pool/pool_test.go | 71 +++- internal/pool/{stats => }/stats.go | 3 +- internal/pool/trace.go | 27 +- internal/query/client.go | 48 +-- internal/query/session.go | 4 +- internal/xerrors/operation.go | 6 + internal/xerrors/transport.go | 20 + internal/xmath/xmath.go | 8 +- internal/xsync/map.go | 3 +- internal/xsync/map_test.go | 34 +- internal/xsync/once.go | 20 +- log/query.go | 52 --- metrics/query.go | 2 - options.go | 15 +- query/client.go | 2 +- query/stats.go | 4 +- sugar/query_test.go | 6 +- trace/driver.go | 26 +- trace/driver_gtrace.go | 38 +- trace/query.go | 31 -- trace/query_gtrace.go | 127 +------ with.go | 2 +- 42 files changed, 711 insertions(+), 1716 deletions(-) delete mode 100644 internal/conn/error.go delete mode 100644 internal/conn/error_test.go delete mode 100644 internal/conn/pool.go rename internal/pool/{stats => }/stats.go (69%) diff --git a/balancers/balancers.go b/balancers/balancers.go index d8f856874..56bd5604a 100644 --- a/balancers/balancers.go +++ b/balancers/balancers.go @@ -5,7 +5,7 @@ import ( "strings" balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xslices" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" ) @@ -29,8 +29,8 @@ func SingleConn() *balancerConfig.Config { type filterLocalDC struct{} -func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Conn) bool { - return c.Endpoint().Location() == info.SelfLocation +func (filterLocalDC) Allow(info balancerConfig.Info, e endpoint.Info) bool { + return e.Location() == info.SelfLocation } func (filterLocalDC) String() string { @@ -59,8 +59,8 @@ func PreferLocalDCWithFallBack(balancer *balancerConfig.Config) *balancerConfig. type filterLocations []string -func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Conn) bool { - location := strings.ToUpper(c.Endpoint().Location()) +func (locations filterLocations) Allow(_ balancerConfig.Info, e endpoint.Info) bool { + location := strings.ToUpper(e.Location()) for _, l := range locations { if location == l { return true @@ -127,10 +127,10 @@ type Endpoint interface { LocalDC() bool } -type filterFunc func(info balancerConfig.Info, c conn.Conn) bool +type filterFunc func(info balancerConfig.Info, e endpoint.Info) bool -func (p filterFunc) Allow(info balancerConfig.Info, c conn.Conn) bool { - return p(info, c) +func (p filterFunc) Allow(info balancerConfig.Info, e endpoint.Info) bool { + return p(info, e) } func (p filterFunc) String() string { @@ -140,8 +140,8 @@ func (p filterFunc) String() string { // Prefer creates balancer which use endpoints by filter // Balancer "balancer" defines balancing algorithm between endpoints selected with filter func Prefer(balancer *balancerConfig.Config, filter func(endpoint Endpoint) bool) *balancerConfig.Config { - balancer.Filter = filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool { - return filter(c.Endpoint()) + balancer.Filter = filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { + return filter(e) }) return balancer diff --git a/config/config.go b/config/config.go index 370dd283d..1fc748086 100644 --- a/config/config.go +++ b/config/config.go @@ -58,13 +58,6 @@ func (c *Config) Meta() *meta.Meta { return c.meta } -// ConnectionTTL defines interval for parking grpc connections. -// -// If ConnectionTTL is zero - connections are not park. -func (c *Config) ConnectionTTL() time.Duration { - return c.connectionTTL -} - // Secure is a flag for secure connection func (c *Config) Secure() bool { return c.secure diff --git a/driver.go b/driver.go index 4710c4d22..14d51b873 100644 --- a/driver.go +++ b/driver.go @@ -20,7 +20,6 @@ import ( internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery" discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" internalQuery "github.com/ydb-platform/ydb-go-sdk/v3/internal/query" queryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/config" internalRatelimiter "github.com/ydb-platform/ydb-go-sdk/v3/internal/ratelimiter" @@ -51,8 +50,6 @@ var _ Connection = (*Driver)(nil) // Driver type provide access to YDB service clients type Driver struct { - ctxCancel context.CancelFunc - userInfo *dsn.UserInfo logger log.Logger @@ -90,7 +87,7 @@ type Driver struct { databaseSQLOptions []xsql.ConnectorOption - pool *conn.Pool + // pool *conn.Pool mtx sync.Mutex balancer *balancer.Balancer @@ -120,13 +117,10 @@ func (d *Driver) Close(ctx context.Context) (finalErr error) { defer func() { onDone(finalErr) }() - d.ctxCancel() d.mtx.Lock() defer d.mtx.Unlock() - d.ctxCancel() - defer func() { for _, f := range d.onClose { f(d) @@ -151,7 +145,7 @@ func (d *Driver) Close(ctx context.Context) (finalErr error) { d.query.Close, d.topic.Close, d.balancer.Close, - d.pool.Release, + // d.pool.Release, ) var issues []error @@ -185,44 +179,44 @@ func (d *Driver) Secure() bool { // Table returns table client func (d *Driver) Table() table.Client { - return d.table.Get() + return d.table.Must() } // Query returns query client // // Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental func (d *Driver) Query() *internalQuery.Client { - return d.query.Get() + return d.query.Must() } // Scheme returns scheme client func (d *Driver) Scheme() scheme.Client { - return d.scheme.Get() + return d.scheme.Must() } // Coordination returns coordination client func (d *Driver) Coordination() coordination.Client { - return d.coordination.Get() + return d.coordination.Must() } // Ratelimiter returns ratelimiter client func (d *Driver) Ratelimiter() ratelimiter.Client { - return d.ratelimiter.Get() + return d.ratelimiter.Must() } // Discovery returns discovery client func (d *Driver) Discovery() discovery.Client { - return d.discovery.Get() + return d.discovery.Must() } // Scripting returns scripting client func (d *Driver) Scripting() scripting.Client { - return d.scripting.Get() + return d.scripting.Must() } // Topic returns topic client func (d *Driver) Topic() topic.Client { - return d.topic.Get() + return d.topic.Must() } // Open connects to database by DSN and return driver runtime holder @@ -308,16 +302,11 @@ func New(ctx context.Context, opts ...Option) (_ *Driver, err error) { //nolint: //nolint:cyclop, nonamedreturns, funlen func newConnectionFromOptions(ctx context.Context, opts ...Option) (_ *Driver, err error) { - ctx, driverCtxCancel := xcontext.WithCancel(xcontext.ValueOnly(ctx)) - defer func() { - if err != nil { - driverCtxCancel() - } - }() + ctx, cancel := xcontext.WithCancel(xcontext.ValueOnly(ctx)) + defer cancel() d := &Driver{ - children: make(map[uint64]*Driver), - ctxCancel: driverCtxCancel, + children: make(map[uint64]*Driver), } if caFile, has := os.LookupEnv("YDB_SSL_ROOT_CERTIFICATES_FILE"); has { @@ -398,16 +387,16 @@ func (d *Driver) connect(ctx context.Context) (err error) { )) } - if d.pool == nil { - d.pool = conn.NewPool(ctx, d.config) - } + //if d.pool == nil { + // d.pool = conn.NewPool(ctx, d.config) + //} - d.balancer, err = balancer.New(ctx, d.config, d.pool, d.discoveryOptions...) + d.balancer, err = balancer.New(ctx, d.config /*d.pool,*/, d.discoveryOptions...) if err != nil { return xerrors.WithStackTrace(err) } - d.table = xsync.OnceValue(func() *internalTable.Client { + d.table = xsync.OnceValue(func() (*internalTable.Client, error) { return internalTable.New(xcontext.ValueOnly(ctx), d.balancer, tableConfig.New( @@ -419,10 +408,10 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.tableOptions..., )..., ), - ) + ), nil }) - d.query = xsync.OnceValue(func() *internalQuery.Client { + d.query = xsync.OnceValue(func() (*internalQuery.Client, error) { return internalQuery.New(xcontext.ValueOnly(ctx), d.balancer, queryConfig.New( @@ -434,13 +423,13 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.queryOptions..., )..., ), - ) + ), nil }) if err != nil { return xerrors.WithStackTrace(err) } - d.scheme = xsync.OnceValue(func() *internalScheme.Client { + d.scheme = xsync.OnceValue(func() (*internalScheme.Client, error) { return internalScheme.New(xcontext.ValueOnly(ctx), d.balancer, schemeConfig.New( @@ -453,10 +442,10 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.schemeOptions..., )..., ), - ) + ), nil }) - d.coordination = xsync.OnceValue(func() *internalCoordination.Client { + d.coordination = xsync.OnceValue(func() (*internalCoordination.Client, error) { return internalCoordination.New(xcontext.ValueOnly(ctx), d.balancer, coordinationConfig.New( @@ -468,10 +457,10 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.coordinationOptions..., )..., ), - ) + ), nil }) - d.ratelimiter = xsync.OnceValue(func() *internalRatelimiter.Client { + d.ratelimiter = xsync.OnceValue(func() (*internalRatelimiter.Client, error) { return internalRatelimiter.New(xcontext.ValueOnly(ctx), d.balancer, ratelimiterConfig.New( @@ -483,12 +472,12 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.ratelimiterOptions..., )..., ), - ) + ), nil }) - d.discovery = xsync.OnceValue(func() *internalDiscovery.Client { + d.discovery = xsync.OnceValue(func() (*internalDiscovery.Client, error) { return internalDiscovery.New(xcontext.ValueOnly(ctx), - d.pool.Get(endpoint.New(d.config.Endpoint())), + d.balancer, discoveryConfig.New( append( // prepend common params from root config @@ -502,10 +491,10 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.discoveryOptions..., )..., ), - ) + ), nil }) - d.scripting = xsync.OnceValue(func() *internalScripting.Client { + d.scripting = xsync.OnceValue(func() (*internalScripting.Client, error) { return internalScripting.New(xcontext.ValueOnly(ctx), d.balancer, scriptingConfig.New( @@ -517,10 +506,10 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.scriptingOptions..., )..., ), - ) + ), nil }) - d.topic = xsync.OnceValue(func() *topicclientinternal.Client { + d.topic = xsync.OnceValue(func() (*topicclientinternal.Client, error) { return topicclientinternal.New(xcontext.ValueOnly(ctx), d.balancer, d.config.Credentials(), @@ -532,7 +521,7 @@ func (d *Driver) connect(ctx context.Context) (err error) { }, d.topicOptions..., )..., - ) + ), nil }) return nil diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go index 2be99f006..6e952a1c9 100644 --- a/internal/balancer/balancer.go +++ b/internal/balancer/balancer.go @@ -3,9 +3,10 @@ package balancer import ( "context" "fmt" - "sort" + "sync/atomic" "google.golang.org/grpc" + grpcCodes "google.golang.org/grpc/codes" "github.com/ydb-platform/ydb-go-sdk/v3/config" balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" @@ -35,21 +36,24 @@ type discoveryClient interface { type Balancer struct { driverConfig *config.Config config balancerConfig.Config - pool *conn.Pool discoveryClient discoveryClient discoveryRepeater repeater.Repeater localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) - mu xsync.RWMutex - connectionsState *connectionsState + conns xsync.Map[string, *xsync.Once[conn.Conn]] + banned xsync.Map[string, struct{}] - onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info) + state atomic.Pointer[connectionsState] + + onApplyDiscoveredEndpointsMtx xsync.RWMutex + onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info) } func (b *Balancer) OnUpdate(onApplyDiscoveredEndpoints func(ctx context.Context, endpoints []endpoint.Info)) { - b.mu.WithLock(func() { - b.onApplyDiscoveredEndpoints = append(b.onApplyDiscoveredEndpoints, onApplyDiscoveredEndpoints) - }) + b.onApplyDiscoveredEndpointsMtx.RLock() + defer b.onApplyDiscoveredEndpointsMtx.RUnlock() + + b.onApplyDiscoveredEndpoints = append(b.onApplyDiscoveredEndpoints, onApplyDiscoveredEndpoints) } func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) { @@ -121,41 +125,13 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) { return nil } -func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Conn) ( - nodes []trace.EndpointInfo, - added []trace.EndpointInfo, - dropped []trace.EndpointInfo, -) { - nodes = make([]trace.EndpointInfo, 0, len(newestEndpoints)) - added = make([]trace.EndpointInfo, 0, len(previousConns)) - dropped = make([]trace.EndpointInfo, 0, len(previousConns)) - var ( - newestMap = make(map[string]struct{}, len(newestEndpoints)) - previousMap = make(map[string]struct{}, len(previousConns)) - ) - sort.Slice(newestEndpoints, func(i, j int) bool { - return newestEndpoints[i].Address() < newestEndpoints[j].Address() - }) - sort.Slice(previousConns, func(i, j int) bool { - return previousConns[i].Endpoint().Address() < previousConns[j].Endpoint().Address() - }) - for _, e := range previousConns { - previousMap[e.Endpoint().Address()] = struct{}{} - } - for _, e := range newestEndpoints { - nodes = append(nodes, e.Copy()) - newestMap[e.Address()] = struct{}{} - if _, has := previousMap[e.Address()]; !has { - added = append(added, e.Copy()) - } - } - for _, c := range previousConns { - if _, has := newestMap[c.Endpoint().Address()]; !has { - dropped = append(dropped, c.Endpoint().Copy()) - } +func s2s[T1, T2 any](in []T1, f func(T1) T2) (out []T2) { + out = make([]T2, len(in)) + for i := range in { + out[i] = f(in[i]) } - return nodes, added, dropped + return out } func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []endpoint.Endpoint, localDC string) { @@ -166,36 +142,76 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"), b.config.DetectLocalDC, ) - previousConns []conn.Conn ) - defer func() { - nodes, added, dropped := endpointsDiff(endpoints, previousConns) - onDone(nodes, added, dropped, localDC) - }() + state := newConnectionsState( + endpoints, + b.config.Filter, + balancerConfig.Info{SelfLocation: localDC}, + b.config.AllowFallback, + func(e endpoint.Endpoint) bool { + return !b.banned.Has(e.Address()) + }, + ) + + _, added, dropped := endpoint.Diff(endpoints, b.state.Swap(state).All()) - connections := endpointsToConnections(b.pool, endpoints) - for _, c := range connections { - b.pool.Allow(ctx, c) - c.Endpoint().Touch() + for _, endpoint := range dropped { + if cc, ok := b.conns.LoadAndDelete(endpoint.Address()); ok { + _ = cc.Close(ctx) + } } - info := balancerConfig.Info{SelfLocation: localDC} - state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback) + for _, endpoint := range added { + b.conns.Store(endpoint.Address(), xsync.OnceValue[conn.Conn](func() (conn.Conn, error) { + cc, err := conn.New(endpoint, b.driverConfig, + conn.WithOnTransportError(func(ctx context.Context, cc conn.Conn, cause error) { + if xerrors.IsTransportError(cause, + //grpcCodes.OK, + //grpcCodes.ResourceExhausted, + //grpcCodes.Unavailable, + grpcCodes.Canceled, + grpcCodes.Unknown, + grpcCodes.InvalidArgument, + grpcCodes.DeadlineExceeded, + grpcCodes.NotFound, + grpcCodes.AlreadyExists, + grpcCodes.PermissionDenied, + grpcCodes.FailedPrecondition, + grpcCodes.Aborted, + grpcCodes.OutOfRange, + grpcCodes.Unimplemented, + grpcCodes.Internal, + grpcCodes.DataLoss, + grpcCodes.Unauthenticated, + ) { + b.banned.Store(cc.Endpoint().Address(), struct{}{}) + } + }), + ) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } - endpointsInfo := make([]endpoint.Info, len(endpoints)) - for i, e := range endpoints { - endpointsInfo[i] = e + return cc, nil + })) } - b.mu.WithLock(func() { - if b.connectionsState != nil { - previousConns = b.connectionsState.all - } - b.connectionsState = state + b.banned.Clear() + + infos := s2s(endpoints, func(e endpoint.Endpoint) endpoint.Info { return e }) + + b.onApplyDiscoveredEndpointsMtx.WithRLock(func() { for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints { - onApplyDiscoveredEndpoints(ctx, endpointsInfo) + onApplyDiscoveredEndpoints(ctx, infos) } }) + + onDone( + s2s(endpoints, func(e endpoint.Endpoint) trace.EndpointInfo { return e }), + s2s(added, func(e endpoint.Endpoint) trace.EndpointInfo { return e }), + s2s(dropped, func(e endpoint.Endpoint) trace.EndpointInfo { return e }), + localDC, + ) } func (b *Balancer) Close(ctx context.Context) (err error) { @@ -221,7 +237,6 @@ func (b *Balancer) Close(ctx context.Context) (err error) { func New( ctx context.Context, driverConfig *config.Config, - pool *conn.Pool, opts ...discoveryConfig.Option, ) (b *Balancer, finalErr error) { var ( @@ -244,14 +259,14 @@ func New( b = &Balancer{ driverConfig: driverConfig, - pool: pool, localDCDetector: detectLocalDC, } - d := internalDiscovery.New(ctx, pool.Get( - endpoint.New(driverConfig.Endpoint()), - ), discoveryConfig) + cc, err := conn.New(endpoint.New(driverConfig.Endpoint()), driverConfig) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } - b.discoveryClient = d + b.discoveryClient = internalDiscovery.New(ctx, cc, discoveryConfig) if config := driverConfig.Balancer(); config == nil { b.config = balancerConfig.Config{} @@ -318,16 +333,6 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc return xerrors.WithStackTrace(err) } - defer func() { - if err == nil { - if cc.GetState() == conn.Banned { - b.pool.Allow(ctx, cc) - } - } else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) { - b.pool.Ban(ctx, cc, err) - } - }() - if ctx, err = b.driverConfig.Meta().Context(ctx); err != nil { return xerrors.WithStackTrace(err) } @@ -351,21 +356,14 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc return nil } -func (b *Balancer) connections() *connectionsState { - b.mu.RLock() - defer b.mu.RUnlock() - - return b.connectionsState -} - func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { onDone := trace.DriverOnBalancerChooseEndpoint( b.driverConfig.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn"), ) defer func() { - if err == nil { - onDone(c.Endpoint(), nil) + if c != nil { + onDone(c.Endpoint(), err) } else { onDone(nil, err) } @@ -375,32 +373,27 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { return nil, xerrors.WithStackTrace(err) } - var ( - state = b.connections() - failedCount int - ) + state := b.state.Load() defer func() { - if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil { + if err != nil || (len(state.all)*2 < len(b.state.Load().all) && b.discoveryRepeater != nil) { b.discoveryRepeater.Force() } }() - c, failedCount = state.GetConnection(ctx) - if c == nil { - return nil, xerrors.WithStackTrace( - fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount), - ) - } - - return c, nil -} - -func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn { - conns := make([]conn.Conn, 0, len(endpoints)) - for _, e := range endpoints { - conns = append(conns, p.Get(e)) + for i := 0; ; i++ { + e := state.Next(ctx) + if e == nil { + return nil, xerrors.WithStackTrace( + fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, i+1), + ) + } + cc := b.conns.Must(e.Address()) + c, err = cc.Get() + if err == nil { + return c, nil + } + b.banned.Store(e.Address(), struct{}{}) + state = state.exclude(e) } - - return conns } diff --git a/internal/balancer/balancer_test.go b/internal/balancer/balancer_test.go index 356952f38..db2249d36 100644 --- a/internal/balancer/balancer_test.go +++ b/internal/balancer/balancer_test.go @@ -1,125 +1 @@ package balancer - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/mock" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" - "github.com/ydb-platform/ydb-go-sdk/v3/trace" -) - -func TestEndpointsDiff(t *testing.T) { - for _, tt := range []struct { - newestEndpoints []endpoint.Endpoint - previousConns []conn.Conn - nodes []trace.EndpointInfo - added []trace.EndpointInfo - dropped []trace.EndpointInfo - }{ - { - newestEndpoints: []endpoint.Endpoint{ - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "3"}, - &mock.Endpoint{AddrField: "2"}, - &mock.Endpoint{AddrField: "0"}, - }, - previousConns: []conn.Conn{ - &mock.Conn{AddrField: "2"}, - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "0"}, - &mock.Conn{AddrField: "3"}, - }, - nodes: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "0"}, - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "2"}, - &mock.Endpoint{AddrField: "3"}, - }, - added: []trace.EndpointInfo{}, - dropped: []trace.EndpointInfo{}, - }, - { - newestEndpoints: []endpoint.Endpoint{ - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "3"}, - &mock.Endpoint{AddrField: "2"}, - &mock.Endpoint{AddrField: "0"}, - }, - previousConns: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "0"}, - &mock.Conn{AddrField: "3"}, - }, - nodes: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "0"}, - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "2"}, - &mock.Endpoint{AddrField: "3"}, - }, - added: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "2"}, - }, - dropped: []trace.EndpointInfo{}, - }, - { - newestEndpoints: []endpoint.Endpoint{ - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "3"}, - &mock.Endpoint{AddrField: "0"}, - }, - previousConns: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, - &mock.Conn{AddrField: "0"}, - &mock.Conn{AddrField: "3"}, - }, - nodes: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "0"}, - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "3"}, - }, - added: []trace.EndpointInfo{}, - dropped: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "2"}, - }, - }, - { - newestEndpoints: []endpoint.Endpoint{ - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "3"}, - &mock.Endpoint{AddrField: "0"}, - }, - previousConns: []conn.Conn{ - &mock.Conn{AddrField: "4"}, - &mock.Conn{AddrField: "7"}, - &mock.Conn{AddrField: "8"}, - }, - nodes: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "0"}, - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "3"}, - }, - added: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "0"}, - &mock.Endpoint{AddrField: "1"}, - &mock.Endpoint{AddrField: "3"}, - }, - dropped: []trace.EndpointInfo{ - &mock.Endpoint{AddrField: "4"}, - &mock.Endpoint{AddrField: "7"}, - &mock.Endpoint{AddrField: "8"}, - }, - }, - } { - t.Run(xtest.CurrentFileLine(), func(t *testing.T) { - nodes, added, dropped := endpointsDiff(tt.newestEndpoints, tt.previousConns) - require.Equal(t, tt.nodes, nodes) - require.Equal(t, tt.added, added) - require.Equal(t, tt.dropped, dropped) - }) - } -} diff --git a/internal/balancer/config/routerconfig.go b/internal/balancer/config/routerconfig.go index 0d1eb6703..220926c20 100644 --- a/internal/balancer/config/routerconfig.go +++ b/internal/balancer/config/routerconfig.go @@ -3,7 +3,7 @@ package config import ( "fmt" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" ) @@ -47,6 +47,6 @@ type Info struct { } type Filter interface { - Allow(info Info, c conn.Conn) bool + Allow(info Info, e endpoint.Info) bool String() string } diff --git a/internal/balancer/connections_state.go b/internal/balancer/connections_state.go index b9f2c9043..101393e1a 100644 --- a/internal/balancer/connections_state.go +++ b/internal/balancer/connections_state.go @@ -4,146 +4,183 @@ import ( "context" balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand" ) type connectionsState struct { - connByNodeID map[uint32]conn.Conn + index map[uint32]endpoint.Endpoint - prefer []conn.Conn - fallback []conn.Conn - all []conn.Conn + checkEndpoint func(endpoint.Endpoint) bool + + prefer []endpoint.Endpoint + fallback []endpoint.Endpoint + all []endpoint.Endpoint rand xrand.Rand } func newConnectionsState( - conns []conn.Conn, + endpoints []endpoint.Endpoint, filter balancerConfig.Filter, info balancerConfig.Info, allowFallback bool, + checkEndpoint func(endpoint.Endpoint) bool, ) *connectionsState { - res := &connectionsState{ - connByNodeID: connsToNodeIDMap(conns), - rand: xrand.New(xrand.WithLock()), + s := &connectionsState{ + index: endpointsToNodeIDMap(endpoints), + rand: xrand.New(xrand.WithLock()), + checkEndpoint: checkEndpoint, } - res.prefer, res.fallback = sortPreferConnections(conns, filter, info, allowFallback) + s.prefer, s.fallback = sortPreferEndpoints(endpoints, filter, info, allowFallback) if allowFallback { - res.all = conns + s.all = endpoints } else { - res.all = res.prefer + s.all = s.prefer } - return res + return s } func (s *connectionsState) PreferredCount() int { return len(s.prefer) } -func (s *connectionsState) GetConnection(ctx context.Context) (_ conn.Conn, failedCount int) { - if err := ctx.Err(); err != nil { - return nil, 0 - } - - if c := s.preferConnection(ctx); c != nil { - return c, 0 +func (s *connectionsState) All() []endpoint.Endpoint { + if s == nil { + return nil } - try := func(conns []conn.Conn) conn.Conn { - c, tryFailed := s.selectRandomConnection(conns, false) - failedCount += tryFailed + return s.all +} - return c +func (s *connectionsState) Next(ctx context.Context) endpoint.Endpoint { + if err := ctx.Err(); err != nil { + return nil } - if c := try(s.prefer); c != nil { - return c, failedCount + if e := s.preferEndpoint(ctx); e != nil { + return e } - if c := try(s.fallback); c != nil { - return c, failedCount + if e := s.selectRandomEndpoint(s.prefer); e != nil { + return e } - c, _ := s.selectRandomConnection(s.all, true) + if e := s.selectRandomEndpoint(s.fallback); e != nil { + return e + } - return c, failedCount + return s.selectRandomEndpoint(s.all) } -func (s *connectionsState) preferConnection(ctx context.Context) conn.Conn { - if nodeID, hasPreferEndpoint := endpoint.ContextNodeID(ctx); hasPreferEndpoint { - c := s.connByNodeID[nodeID] - if c != nil && isOkConnection(c, true) { - return c +func (s *connectionsState) preferEndpoint(ctx context.Context) endpoint.Endpoint { + if nodeID, has := endpoint.ContextNodeID(ctx); has { + e := s.index[nodeID] + if e != nil && s.checkEndpoint(e) { + return e } } return nil } -func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned bool) (c conn.Conn, failedConns int) { - connCount := len(conns) - if connCount == 0 { +func (s *connectionsState) selectRandomEndpoint(endpoints []endpoint.Endpoint) endpoint.Endpoint { + count := len(endpoints) + if count == 0 { // return for empty list need for prevent panic in fast path - return nil, 0 + return nil } // fast path - if c := conns[s.rand.Int(connCount)]; isOkConnection(c, allowBanned) { - return c, 0 + if e := endpoints[s.rand.Int(count)]; s.checkEndpoint(e) { + return e } // shuffled indexes slices need for guarantee about every connection will check - indexes := make([]int, connCount) + indexes := make([]int, count) for index := range indexes { indexes[index] = index } - s.rand.Shuffle(connCount, func(i, j int) { + + s.rand.Shuffle(count, func(i, j int) { indexes[i], indexes[j] = indexes[j], indexes[i] }) for _, index := range indexes { - c := conns[index] - if isOkConnection(c, allowBanned) { - return c, 0 + e := endpoints[index] + if s.checkEndpoint(e) { + return e + } + } + + return nil +} + +func excludeS(in []endpoint.Endpoint, exclude endpoint.Endpoint) (out []endpoint.Endpoint) { + out = make([]endpoint.Endpoint, 0, len(in)) + + for i := range in { + if in[i].Address() != exclude.Address() { + out = append(out, in[i]) } - failedConns++ } - return nil, failedConns + return out } -func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) { - if len(conns) == 0 { +func excludeM(in map[uint32]endpoint.Endpoint, exclude endpoint.Endpoint) (out map[uint32]endpoint.Endpoint) { + out = make(map[uint32]endpoint.Endpoint, len(in)) + + for i := range in { + if in[i].Address() != exclude.Address() { + out[in[i].NodeID()] = in[i] + } + } + + return out +} + +func (s *connectionsState) exclude(e endpoint.Endpoint) *connectionsState { + return &connectionsState{ + index: excludeM(s.index, e), + checkEndpoint: s.checkEndpoint, + prefer: excludeS(s.prefer, e), + fallback: excludeS(s.fallback, e), + all: excludeS(s.all, e), + rand: s.rand, + } +} + +func endpointsToNodeIDMap(endpoints []endpoint.Endpoint) (index map[uint32]endpoint.Endpoint) { + if len(endpoints) == 0 { return nil } - nodes = make(map[uint32]conn.Conn, len(conns)) - for _, c := range conns { - nodes[c.Endpoint().NodeID()] = c + index = make(map[uint32]endpoint.Endpoint, len(endpoints)) + for _, c := range endpoints { + index[c.NodeID()] = c } - return nodes + return index } -func sortPreferConnections( - conns []conn.Conn, +func sortPreferEndpoints( + endpoints []endpoint.Endpoint, filter balancerConfig.Filter, info balancerConfig.Info, allowFallback bool, -) (prefer, fallback []conn.Conn) { +) (prefer, fallback []endpoint.Endpoint) { if filter == nil { - return conns, nil + return endpoints, nil } - prefer = make([]conn.Conn, 0, len(conns)) + prefer = make([]endpoint.Endpoint, 0, len(endpoints)) if allowFallback { - fallback = make([]conn.Conn, 0, len(conns)) + fallback = make([]endpoint.Endpoint, 0, len(endpoints)) } - for _, c := range conns { + for _, c := range endpoints { if filter.Allow(info, c) { prefer = append(prefer, c) } else if allowFallback { @@ -153,14 +190,3 @@ func sortPreferConnections( return prefer, fallback } - -func isOkConnection(c conn.Conn, bannedIsOk bool) bool { - switch c.GetState() { - case conn.Online, conn.Created, conn.Offline: - return true - case conn.Banned: - return bannedIsOk - default: - return false - } -} diff --git a/internal/balancer/connections_state_test.go b/internal/balancer/connections_state_test.go index 2ab07df84..fa77a81e7 100644 --- a/internal/balancer/connections_state_test.go +++ b/internal/balancer/connections_state_test.go @@ -150,7 +150,7 @@ func TestSortPreferConnections(t *testing.T) { for _, test := range table { t.Run(test.name, func(t *testing.T) { - prefer, fallback := sortPreferConnections(test.source, test.filter, balancerConfig.Info{}, test.allowFallback) + prefer, fallback := sortPreferEndpoints(test.source, test.filter, balancerConfig.Info{}, test.allowFallback) require.Equal(t, test.prefer, prefer) require.Equal(t, test.fallback, fallback) }) @@ -161,24 +161,24 @@ func TestSelectRandomConnection(t *testing.T) { s := newConnectionsState(nil, nil, balancerConfig.Info{}, false) t.Run("Empty", func(t *testing.T) { - c, failedCount := s.selectRandomConnection(nil, false) + c, failedCount := s.selectRandomEndpoint(nil, false) require.Nil(t, c) require.Equal(t, 0, failedCount) }) t.Run("One", func(t *testing.T) { for _, goodState := range []conn.State{conn.Online, conn.Offline, conn.Created} { - c, failedCount := s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: goodState}}, false) + c, failedCount := s.selectRandomEndpoint([]conn.Conn{&mock.Conn{AddrField: "asd", State: goodState}}, false) require.Equal(t, &mock.Conn{AddrField: "asd", State: goodState}, c) require.Equal(t, 0, failedCount) } }) t.Run("OneBanned", func(t *testing.T) { - c, failedCount := s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, false) + c, failedCount := s.selectRandomEndpoint([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, false) require.Nil(t, c) require.Equal(t, 1, failedCount) - c, failedCount = s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, true) + c, failedCount = s.selectRandomEndpoint([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, true) require.Equal(t, &mock.Conn{AddrField: "asd", State: conn.Banned}, c) require.Equal(t, 0, failedCount) }) @@ -190,7 +190,7 @@ func TestSelectRandomConnection(t *testing.T) { first := 0 second := 0 for i := 0; i < 100; i++ { - c, _ := s.selectRandomConnection(conns, false) + c, _ := s.selectRandomEndpoint(conns, false) if c.Endpoint().Address() == "1" { first++ } else { @@ -208,7 +208,7 @@ func TestSelectRandomConnection(t *testing.T) { } totalFailed := 0 for i := 0; i < 100; i++ { - c, failed := s.selectRandomConnection(conns, false) + c, failed := s.selectRandomEndpoint(conns, false) require.Nil(t, c) totalFailed += failed } @@ -224,7 +224,7 @@ func TestSelectRandomConnection(t *testing.T) { second := 0 failed := 0 for i := 0; i < 100; i++ { - c, checkFailed := s.selectRandomConnection(conns, false) + c, checkFailed := s.selectRandomEndpoint(conns, false) failed += checkFailed switch c.Endpoint().Address() { case "1": @@ -252,10 +252,10 @@ func TestNewState(t *testing.T) { name: "Empty", state: newConnectionsState(nil, nil, balancerConfig.Info{}, false), res: &connectionsState{ - connByNodeID: nil, - prefer: nil, - fallback: nil, - all: nil, + index: nil, + prefer: nil, + fallback: nil, + all: nil, }, }, { @@ -265,7 +265,7 @@ func TestNewState(t *testing.T) { &mock.Conn{AddrField: "2", NodeIDField: 2}, }, nil, balancerConfig.Info{}, false), res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ + index: map[uint32]conn.Conn{ 1: &mock.Conn{AddrField: "1", NodeIDField: 1}, 2: &mock.Conn{AddrField: "2", NodeIDField: 2}, }, @@ -291,7 +291,7 @@ func TestNewState(t *testing.T) { return info.SelfLocation == c.Endpoint().Location() }), balancerConfig.Info{SelfLocation: "t"}, false), res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ + index: map[uint32]conn.Conn{ 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, @@ -319,7 +319,7 @@ func TestNewState(t *testing.T) { return info.SelfLocation == c.Endpoint().Location() }), balancerConfig.Info{SelfLocation: "t"}, true), res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ + index: map[uint32]conn.Conn{ 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, @@ -352,7 +352,7 @@ func TestNewState(t *testing.T) { return info.SelfLocation == c.Endpoint().Location() }), balancerConfig.Info{SelfLocation: "t"}, true), res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ + index: map[uint32]conn.Conn{ 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, @@ -388,7 +388,7 @@ func TestNewState(t *testing.T) { func TestConnection(t *testing.T) { t.Run("Empty", func(t *testing.T) { s := newConnectionsState(nil, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(context.Background()) + c, failed := s.Next(context.Background()) require.Nil(t, c) require.Equal(t, 0, failed) }) @@ -397,7 +397,7 @@ func TestConnection(t *testing.T) { &mock.Conn{AddrField: "1", State: conn.Online}, &mock.Conn{AddrField: "2", State: conn.Online}, }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(context.Background()) + c, failed := s.Next(context.Background()) require.NotNil(t, c) require.Equal(t, 0, failed) }) @@ -406,7 +406,7 @@ func TestConnection(t *testing.T) { &mock.Conn{AddrField: "1", State: conn.Online}, &mock.Conn{AddrField: "2", State: conn.Banned}, }, nil, balancerConfig.Info{}, false) - c, _ := s.GetConnection(context.Background()) + c, _ := s.Next(context.Background()) require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online}, c) }) t.Run("AllBanned", func(t *testing.T) { @@ -419,7 +419,7 @@ func TestConnection(t *testing.T) { preferred := 0 fallback := 0 for i := 0; i < 100; i++ { - c, failed := s.GetConnection(context.Background()) + c, failed := s.Next(context.Background()) require.NotNil(t, c) require.Equal(t, 2, failed) if c.Endpoint().Address() == "t1" { @@ -439,7 +439,7 @@ func TestConnection(t *testing.T) { }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { return c.Endpoint().Location() == info.SelfLocation }), balancerConfig.Info{SelfLocation: "t"}, true) - c, failed := s.GetConnection(context.Background()) + c, failed := s.Next(context.Background()) require.Equal(t, &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, c) require.Equal(t, 1, failed) }) @@ -448,7 +448,7 @@ func TestConnection(t *testing.T) { &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(endpoint.WithNodeID(context.Background(), 2)) + c, failed := s.Next(endpoint.WithNodeID(context.Background(), 2)) require.Equal(t, &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, c) require.Equal(t, 0, failed) }) @@ -457,7 +457,7 @@ func TestConnection(t *testing.T) { &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, &mock.Conn{AddrField: "2", State: conn.Unknown, NodeIDField: 2}, }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(endpoint.WithNodeID(context.Background(), 2)) + c, failed := s.Next(endpoint.WithNodeID(context.Background(), 2)) require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, c) require.Equal(t, 0, failed) }) diff --git a/internal/balancer/local_dc_test.go b/internal/balancer/local_dc_test.go index 2eab1e9a8..0924c5a18 100644 --- a/internal/balancer/local_dc_test.go +++ b/internal/balancer/local_dc_test.go @@ -151,7 +151,7 @@ func TestLocalDCDiscovery(t *testing.T) { require.NoError(t, err) for i := 0; i < 100; i++ { - conn, _ := r.connections().GetConnection(ctx) + conn, _ := r.connections().Next(ctx) require.Equal(t, "b:234", conn.Endpoint().Address()) require.Equal(t, "b", conn.Endpoint().Location()) } diff --git a/internal/conn/config.go b/internal/conn/config.go index df82f3a85..50ebfa851 100644 --- a/internal/conn/config.go +++ b/internal/conn/config.go @@ -1,16 +1,12 @@ package conn import ( - "time" - "google.golang.org/grpc" "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) type Config interface { - DialTimeout() time.Duration - ConnectionTTL() time.Duration Trace() *trace.Driver GrpcDialOptions() []grpc.DialOption } diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 431be9bff..6118c1d49 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -3,16 +3,15 @@ package conn import ( "context" "fmt" - "sync" "sync/atomic" "time" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/meta" "github.com/ydb-platform/ydb-go-sdk/v3/internal/operation" @@ -36,29 +35,20 @@ var ( type Conn interface { grpc.ClientConnInterface + closer.Closer Endpoint() endpoint.Endpoint LastUsage() time.Time - - Ping(ctx context.Context) error - IsState(states ...State) bool - GetState() State - SetState(ctx context.Context, state State) State - Unban(ctx context.Context) State } type conn struct { - mtx sync.RWMutex + *grpc.ClientConn config Config // ro access - grpcConn *grpc.ClientConn done chan struct{} endpoint endpoint.Endpoint // ro access - closed bool - state atomic.Uint32 childStreams *xcontext.CancelsGuard lastUsage xsync.LastUsage - onClose []func(*conn) onTransportErrors []func(ctx context.Context, cc Conn, cause error) } @@ -66,36 +56,10 @@ func (c *conn) Address() string { return c.endpoint.Address() } -func (c *conn) Ping(ctx context.Context) error { - cc, err := c.realConn(ctx) - if err != nil { - return c.wrapError(err) - } - if !isAvailable(cc) { - return c.wrapError(errUnavailableConnection) - } - - return nil -} - func (c *conn) LastUsage() time.Time { - c.mtx.RLock() - defer c.mtx.RUnlock() - return c.lastUsage.Get() } -func (c *conn) IsState(states ...State) bool { - state := State(c.state.Load()) - for _, s := range states { - if s == state { - return true - } - } - - return false -} - func (c *conn) NodeID() uint32 { if c != nil { return c.endpoint.NodeID() @@ -104,36 +68,6 @@ func (c *conn) NodeID() uint32 { return 0 } -func (c *conn) park(ctx context.Context) (err error) { - onDone := trace.DriverOnConnPark( - c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).park"), - c.Endpoint(), - ) - defer func() { - onDone(err) - }() - - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.closed { - return nil - } - - if c.grpcConn == nil { - return nil - } - - err = c.close(ctx) - - if err != nil { - return c.wrapError(err) - } - - return nil -} - func (c *conn) Endpoint() endpoint.Endpoint { if c != nil { return c.endpoint @@ -142,141 +76,13 @@ func (c *conn) Endpoint() endpoint.Endpoint { return nil } -func (c *conn) SetState(ctx context.Context, s State) State { - return c.setState(ctx, s) -} - -func (c *conn) setState(ctx context.Context, s State) State { - if state := State(c.state.Swap(uint32(s))); state != s { - trace.DriverOnConnStateChange( - c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).setState"), - c.endpoint.Copy(), state, - )(s) - } - - return s -} - -func (c *conn) Unban(ctx context.Context) State { - var newState State - c.mtx.RLock() - cc := c.grpcConn //nolint:ifshort - c.mtx.RUnlock() - if isAvailable(cc) { - newState = Online - } else { - newState = Offline - } - - c.setState(ctx, newState) - - return newState -} - -func (c *conn) GetState() (s State) { - return State(c.state.Load()) -} - -func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { - if c.isClosed() { - return nil, c.wrapError(errClosedConnection) - } - - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.grpcConn != nil { - return c.grpcConn, nil - } - - if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout) - defer cancel() - } - - onDone := trace.DriverOnConnDial( - c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).realConn"), - c.endpoint.Copy(), - ) - defer func() { - onDone(err) - }() - - // prepend "ydb" scheme for grpc dns-resolver to find the proper scheme - // three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver. - address := "ydb:///" + c.endpoint.Address() - - cc, err = grpc.DialContext(ctx, address, append( //nolint:staticcheck,nolintlint - []grpc.DialOption{ - grpc.WithStatsHandler(statsHandler{}), - }, c.config.GrpcDialOptions()..., - )...) - if err != nil { - if xerrors.IsContextError(err) { - return nil, xerrors.WithStackTrace(err) - } - - defer func() { - c.onTransportError(ctx, err) - }() - - err = xerrors.Transport(err, - xerrors.WithAddress(address), - ) - - return nil, c.wrapError( - xerrors.Retryable(err, - xerrors.WithName("realConn"), - ), - ) - } - - c.grpcConn = cc - c.setState(ctx, Online) - - return c.grpcConn, nil -} - func (c *conn) onTransportError(ctx context.Context, cause error) { for _, onTransportError := range c.onTransportErrors { onTransportError(ctx, c, cause) } } -func isAvailable(raw *grpc.ClientConn) bool { - return raw != nil && raw.GetState() == connectivity.Ready -} - -// conn must be locked -func (c *conn) close(ctx context.Context) (err error) { - if c.grpcConn == nil { - return nil - } - err = c.grpcConn.Close() - c.grpcConn = nil - c.setState(ctx, Offline) - - return c.wrapError(err) -} - -func (c *conn) isClosed() bool { - c.mtx.RLock() - defer c.mtx.RUnlock() - - return c.closed -} - func (c *conn) Close(ctx context.Context) (err error) { - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.closed { - return nil - } - onDone := trace.DriverOnConnClose( c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).Close"), @@ -286,23 +92,17 @@ func (c *conn) Close(ctx context.Context) (err error) { onDone(err) }() - c.closed = true - - err = c.close(ctx) + c.childStreams.Cancel() - c.setState(ctx, Destroyed) - - for _, onClose := range c.onClose { - onClose(c) + err = c.ClientConn.Close() + if err != nil { + return xerrors.WithStackTrace(err) } - return c.wrapError(err) + return nil } -var ( - onTransportErrorStub = func(ctx context.Context, err error) {} - wrapErrorStub = func(err error) error { return err } -) +var onTransportErrorStub = func(ctx context.Context, err error) {} //nolint:funlen func invoke( @@ -312,7 +112,6 @@ func invoke( cc grpc.ClientConnInterface, onTransportError func(context.Context, error), address string, - wrapError func(err error) error, opts ...grpc.CallOption, ) ( opID string, @@ -332,28 +131,28 @@ func invoke( onTransportError = onTransportErrorStub } - if wrapError == nil { - wrapError = wrapErrorStub - } - err = cc.Invoke(ctx, method, req, reply, opts...) if err != nil { if xerrors.IsContextError(err) { return opID, issues, xerrors.WithStackTrace(err) } - defer onTransportError(ctx, err) + defer func() { + onTransportError(ctx, err) + }() if useWrapping { - err = xerrors.Transport(err, - xerrors.WithAddress(address), - xerrors.WithTraceID(traceID), - ) if sentMark.canRetry() { - return opID, issues, wrapError(xerrors.Retryable(err, xerrors.WithName("Invoke"))) + return opID, issues, xerrors.Retryable(xerrors.Transport(err, + xerrors.WithAddress(address), + xerrors.WithTraceID(traceID), + ), xerrors.WithName("Invoke")) } - return opID, issues, wrapError(err) + return opID, issues, xerrors.WithStackTrace(xerrors.Transport(err, + xerrors.WithAddress(address), + xerrors.WithTraceID(traceID), + )) } return opID, issues, err @@ -368,10 +167,10 @@ func invoke( if useWrapping { switch { case !t.GetOperation().GetReady(): - return opID, issues, wrapError(errOperationNotReady) + return opID, issues, xerrors.WithStackTrace(errOperationNotReady) case t.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS: - return opID, issues, wrapError( + return opID, issues, xerrors.WithStackTrace( xerrors.Operation( xerrors.FromOperation(t.GetOperation()), xerrors.WithAddress(address), @@ -386,7 +185,7 @@ func invoke( } if useWrapping { if t.GetStatus() != Ydb.StatusIds_SUCCESS { - return opID, issues, wrapError( + return opID, issues, xerrors.WithStackTrace( xerrors.Operation( xerrors.FromOperation(t), xerrors.WithAddress(address), @@ -415,19 +214,13 @@ func (c *conn) Invoke( stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).Invoke"), c.endpoint, trace.Method(method), ) - cc *grpc.ClientConn md = metadata.MD{} ) defer func() { meta.CallTrailerCallback(ctx, md) - onDone(err, issues, opID, c.GetState(), md) + onDone(err, issues, opID, c.ClientConn.GetState(), md) }() - cc, err = c.realConn(ctx) - if err != nil { - return c.wrapError(err) - } - stop := c.lastUsage.Start() defer stop() @@ -436,10 +229,9 @@ func (c *conn) Invoke( method, req, res, - cc, + c.ClientConn, c.onTransportError, c.Address(), - c.wrapError, append(opts, grpc.Trailer(&md))..., ) @@ -466,11 +258,6 @@ func (c *conn) NewStream( onDone(finalErr, c.GetState()) }() - cc, err := c.realConn(ctx) - if err != nil { - return nil, c.wrapError(err) - } - stop := c.lastUsage.Start() defer stop() @@ -497,7 +284,7 @@ func (c *conn) NewStream( sentMark: sentMark, } - s.stream, err = cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(s.finish))...) + s.stream, err = c.ClientConn.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(s.finish))...) if err != nil { if xerrors.IsContextError(err) { return nil, xerrors.WithStackTrace(err) @@ -507,44 +294,34 @@ func (c *conn) NewStream( c.onTransportError(ctx, err) }() - if useWrapping { - err = xerrors.Transport(err, - xerrors.WithAddress(c.Address()), - xerrors.WithTraceID(traceID), - ) - if sentMark.canRetry() { - return nil, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream"))) - } + if !useWrapping { + return nil, err + } - return nil, c.wrapError(err) + if sentMark.canRetry() { + return nil, xerrors.WithStackTrace(xerrors.Retryable( + xerrors.Transport(err, + xerrors.WithAddress(c.Address()), + xerrors.WithNodeID(c.NodeID()), + xerrors.WithTraceID(traceID), + ), + xerrors.WithName("NewStream")), + ) } - return nil, err + return nil, xerrors.WithStackTrace(xerrors.Transport(err, + xerrors.WithAddress(c.Address()), + xerrors.WithNodeID(c.NodeID()), + xerrors.WithTraceID(traceID), + )) } return s, nil } -func (c *conn) wrapError(err error) error { - if err == nil { - return nil - } - nodeErr := newConnError(c.endpoint.NodeID(), c.endpoint.Address(), err) - - return xerrors.WithStackTrace(nodeErr, xerrors.WithSkipDepth(1)) -} - type option func(c *conn) -func withOnClose(onClose func(*conn)) option { - return func(c *conn) { - if onClose != nil { - c.onClose = append(c.onClose, onClose) - } - } -} - -func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option { +func WithOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option { return func(c *conn) { if onTransportError != nil { c.onTransportErrors = append(c.onTransportErrors, onTransportError) @@ -552,31 +329,52 @@ func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, ca } } -func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { +func New(e endpoint.Endpoint, config Config, opts ...option) (_ *conn, err error) { c := &conn{ endpoint: e, config: config, done: make(chan struct{}), lastUsage: xsync.NewLastUsage(), childStreams: xcontext.NewCancelsGuard(), - onClose: []func(*conn){ - func(c *conn) { - c.childStreams.Cancel() - }, - }, } - c.state.Store(uint32(Created)) + for _, opt := range opts { if opt != nil { opt(c) } } - return c -} + // prepend "ydb" scheme for grpc dns-resolver to find the proper scheme + // three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver. + address := "ydb:///" + c.endpoint.Address() + + c.ClientConn, err = grpc.NewClient(address, append( //nolint:staticcheck,nolintlint + []grpc.DialOption{ + grpc.WithStatsHandler(statsHandler{}), + //grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + // err := invoker(ctx, method, req, reply, cc, opts...) + // if err != nil { + // return xerrors.WithStackTrace(xerrors.Transport(err, + // xerrors.WithNodeID(c.NodeID()), + // xerrors.WithAddress(c.endpoint.Address()), + // )) + // } + //}), + //grpc.WithStreamInterceptor(func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // + //}), + //grpc.WithChainStreamInterceptor(func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // + //}), + //grpc_middleware.ChainStreamClient(), + }, c.config.GrpcDialOptions()..., + )...) + + if err != nil { + return nil, xerrors.WithStackTrace(err) + } -func New(e endpoint.Endpoint, config Config, opts ...option) Conn { - return newConn(e, config, opts...) + return c, nil } var _ stats.Handler = statsHandler{} diff --git a/internal/conn/conn_test.go b/internal/conn/conn_test.go index af8f86cb6..27930c12a 100644 --- a/internal/conn/conn_test.go +++ b/internal/conn/conn_test.go @@ -30,7 +30,7 @@ type connMock struct { } func (c connMock) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error { - _, _, err := invoke(ctx, method, args, reply, c.cc, nil, "", nil, opts...) + _, _, err := invoke(ctx, method, args, reply, c.cc, nil, "", opts...) return err } diff --git a/internal/conn/error.go b/internal/conn/error.go deleted file mode 100644 index 216ac2f7b..000000000 --- a/internal/conn/error.go +++ /dev/null @@ -1,25 +0,0 @@ -package conn - -import "fmt" - -type connError struct { - nodeID uint32 - endpoint string - err error -} - -func newConnError(id uint32, endpoint string, err error) connError { - return connError{ - nodeID: id, - endpoint: endpoint, - err: err, - } -} - -func (n connError) Error() string { - return fmt.Sprintf("connError{node_id:%d,address:'%s'}: %v", n.nodeID, n.endpoint, n.err) -} - -func (n connError) Unwrap() error { - return n.err -} diff --git a/internal/conn/error_test.go b/internal/conn/error_test.go deleted file mode 100644 index df3927837..000000000 --- a/internal/conn/error_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package conn - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/require" - "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" -) - -func TestNodeErrorError(t *testing.T) { - testErr := errors.New("test") - nodeErr := newConnError(1, "localhost:1234", testErr) - message := nodeErr.Error() - - require.Equal(t, "connError{node_id:1,address:'localhost:1234'}: test", message) -} - -func TestNodeErrorUnwrap(t *testing.T) { - testErr := errors.New("test") - nodeErr := newConnError(1, "asd", testErr) - - unwrapped := errors.Unwrap(nodeErr) - require.Equal(t, testErr, unwrapped) -} - -func TestNodeErrorIs(t *testing.T) { - testErr := errors.New("test") - testErr2 := errors.New("test2") - nodeErr := newConnError(1, "localhost:1234", testErr) - - require.ErrorIs(t, nodeErr, testErr) - require.NotErrorIs(t, nodeErr, testErr2) -} - -type testType1Error struct { - msg string -} - -func (t testType1Error) Error() string { - return "1 - " + t.msg -} - -type testType2Error struct { - msg string -} - -func (t testType2Error) Error() string { - return "2 - " + t.msg -} - -func TestNodeErrorAs(t *testing.T) { - testErr := testType1Error{msg: "test"} - nodeErr := newConnError(1, "localhost:1234", testErr) - - target := testType1Error{} - require.ErrorAs(t, nodeErr, &target) - require.Equal(t, testErr, target) - - target2 := testType2Error{} - require.False(t, errors.As(nodeErr, &target2)) -} - -// https://github.com/ydb-platform/ydb-go-sdk/issues/1227 -func TestIssue1227NodeErrorUnwrapBadSession(t *testing.T) { - nodeErr := xerrors.WithStackTrace(newConnError(1, "localhost:1234", xerrors.Operation( - xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION), - ))) - - code, errType, backoffType, invalidObject := xerrors.Check(nodeErr) - - require.EqualValues(t, Ydb.StatusIds_BAD_SESSION, code) - require.EqualValues(t, xerrors.TypeRetryable, errType) - require.EqualValues(t, backoff.TypeNoBackoff, backoffType) - require.True(t, invalidObject) -} diff --git a/internal/conn/grpc_client_stream.go b/internal/conn/grpc_client_stream.go index ee106db96..db3919696 100644 --- a/internal/conn/grpc_client_stream.go +++ b/internal/conn/grpc_client_stream.go @@ -58,17 +58,17 @@ func (s *grpcClientStream) CloseSend() (err error) { return xerrors.WithStackTrace(err) } - if s.wrapping { - return s.wrapError( - xerrors.Transport( - err, - xerrors.WithAddress(s.parentConn.Address()), - xerrors.WithTraceID(s.traceID), - ), - ) + if !s.wrapping { + return err } - return s.wrapError(err) + return xerrors.WithStackTrace( + xerrors.Transport(err, + xerrors.WithAddress(s.parentConn.Address()), + xerrors.WithNodeID(s.parentConn.NodeID()), + xerrors.WithTraceID(s.traceID), + ), + ) } return nil @@ -100,17 +100,22 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) { }() if s.wrapping { - err = xerrors.Transport(err, - xerrors.WithAddress(s.parentConn.Address()), - xerrors.WithTraceID(s.traceID), - ) if s.sentMark.canRetry() { - return s.wrapError(xerrors.Retryable(err, + return xerrors.WithStackTrace(xerrors.Retryable( + xerrors.Transport(err, + xerrors.WithAddress(s.parentConn.Address()), + xerrors.WithNodeID(s.parentConn.NodeID()), + xerrors.WithTraceID(s.traceID), + ), xerrors.WithName("SendMsg"), )) } - return s.wrapError(err) + return xerrors.WithStackTrace(xerrors.Transport(err, + xerrors.WithAddress(s.parentConn.Address()), + xerrors.WithNodeID(s.parentConn.NodeID()), + xerrors.WithTraceID(s.traceID), + )) } return err @@ -150,53 +155,49 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) { //nolint:funlen return io.EOF } - if xerrors.IsContextError(err) { - return xerrors.WithStackTrace(err) - } - defer func() { s.parentConn.onTransportError(ctx, err) }() - if s.wrapping { - err = xerrors.Transport(err, - xerrors.WithAddress(s.parentConn.Address()), - ) - if s.sentMark.canRetry() { - return s.wrapError(xerrors.Retryable(err, - xerrors.WithName("RecvMsg"), - )) - } - - return s.wrapError(err) + if !s.wrapping { + return err } - return err - } + if xerrors.IsContextError(err) { + return xerrors.WithStackTrace(err) + } - if s.wrapping { - if operation, ok := m.(operation.Status); ok { - if status := operation.GetStatus(); status != Ydb.StatusIds_SUCCESS { - return s.wrapError( - xerrors.Operation( - xerrors.FromOperation(operation), - xerrors.WithAddress(s.parentConn.Address()), - ), - ) - } + if s.sentMark.canRetry() { + return xerrors.WithStackTrace(xerrors.Retryable( + xerrors.Transport(err, + xerrors.WithAddress(s.parentConn.Address()), + xerrors.WithNodeID(s.parentConn.NodeID()), + ), + xerrors.WithName("RecvMsg"), + )) } - } - return nil -} + return xerrors.WithStackTrace(xerrors.Transport(err, + xerrors.WithAddress(s.parentConn.Address()), + xerrors.WithNodeID(s.parentConn.NodeID()), + )) + } -func (s *grpcClientStream) wrapError(err error) error { - if err == nil { + if !s.wrapping { return nil } - return xerrors.WithStackTrace( - newConnError(s.parentConn.endpoint.NodeID(), s.parentConn.endpoint.Address(), err), - xerrors.WithSkipDepth(1), - ) + if operation, ok := m.(operation.Status); ok { + if status := operation.GetStatus(); status != Ydb.StatusIds_SUCCESS { + return xerrors.WithStackTrace( + xerrors.Operation( + xerrors.FromOperation(operation), + xerrors.WithAddress(s.parentConn.Address()), + xerrors.WithNodeID(s.parentConn.NodeID()), + ), + ) + } + } + + return nil } diff --git a/internal/conn/pool.go b/internal/conn/pool.go deleted file mode 100644 index 783b7a880..000000000 --- a/internal/conn/pool.go +++ /dev/null @@ -1,253 +0,0 @@ -package conn - -import ( - "context" - "sync" - "sync/atomic" - "time" - - "google.golang.org/grpc" - grpcCodes "google.golang.org/grpc/codes" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" - "github.com/ydb-platform/ydb-go-sdk/v3/trace" -) - -type connsKey struct { - address string - nodeID uint32 -} - -type Pool struct { - usages int64 - config Config - mtx xsync.RWMutex - opts []grpc.DialOption - conns map[connsKey]*conn - done chan struct{} -} - -func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { - p.mtx.Lock() - defer p.mtx.Unlock() - - var ( - address = endpoint.Address() - cc *conn - has bool - ) - - key := connsKey{address, endpoint.NodeID()} - - if cc, has = p.conns[key]; has { - return cc - } - - cc = newConn( - endpoint, - p.config, - withOnClose(p.remove), - withOnTransportError(p.Ban), - ) - - p.conns[key] = cc - - return cc -} - -func (p *Pool) remove(c *conn) { - p.mtx.Lock() - defer p.mtx.Unlock() - delete(p.conns, connsKey{c.Endpoint().Address(), c.Endpoint().NodeID()}) -} - -func (p *Pool) isClosed() bool { - select { - case <-p.done: - return true - default: - return false - } -} - -func (p *Pool) Ban(ctx context.Context, cc Conn, cause error) { - if p.isClosed() { - return - } - - if !xerrors.IsTransportError(cause, - grpcCodes.ResourceExhausted, - grpcCodes.Unavailable, - // grpcCodes.OK, - // grpcCodes.Canceled, - // grpcCodes.Unknown, - // grpcCodes.InvalidArgument, - // grpcCodes.DeadlineExceeded, - // grpcCodes.NotFound, - // grpcCodes.AlreadyExists, - // grpcCodes.PermissionDenied, - // grpcCodes.FailedPrecondition, - // grpcCodes.Aborted, - // grpcCodes.OutOfRange, - // grpcCodes.Unimplemented, - // grpcCodes.Internal, - // grpcCodes.DataLoss, - // grpcCodes.Unauthenticated, - ) { - return - } - - e := cc.Endpoint().Copy() - - p.mtx.RLock() - defer p.mtx.RUnlock() - - cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] - if !ok { - return - } - - trace.DriverOnConnBan( - p.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Ban"), - e, cc.GetState(), cause, - )(cc.SetState(ctx, Banned)) -} - -func (p *Pool) Allow(ctx context.Context, cc Conn) { - if p.isClosed() { - return - } - - e := cc.Endpoint().Copy() - - p.mtx.RLock() - defer p.mtx.RUnlock() - - cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] - if !ok { - return - } - - trace.DriverOnConnAllow( - p.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Allow"), - e, cc.GetState(), - )(cc.Unban(ctx)) -} - -func (p *Pool) Take(context.Context) error { - atomic.AddInt64(&p.usages, 1) - - return nil -} - -func (p *Pool) Release(ctx context.Context) (finalErr error) { - onDone := trace.DriverOnPoolRelease(p.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Release"), - ) - defer func() { - onDone(finalErr) - }() - - if atomic.AddInt64(&p.usages, -1) > 0 { - return nil - } - - close(p.done) - - var conns []closer.Closer - p.mtx.WithRLock(func() { - conns = make([]closer.Closer, 0, len(p.conns)) - for _, c := range p.conns { - conns = append(conns, c) - } - }) - - var ( - errCh = make(chan error, len(conns)) - wg sync.WaitGroup - ) - - wg.Add(len(conns)) - for _, c := range conns { - go func(c closer.Closer) { - defer wg.Done() - if err := c.Close(ctx); err != nil { - errCh <- err - } - }(c) - } - wg.Wait() - close(errCh) - - issues := make([]error, 0, len(conns)) - for err := range errCh { - issues = append(issues, err) - } - - if len(issues) > 0 { - return xerrors.WithStackTrace(xerrors.NewWithIssues("connection pool close failed", issues...)) - } - - return nil -} - -func (p *Pool) connParker(ctx context.Context, ttl, interval time.Duration) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-p.done: - return - case <-ticker.C: - for _, c := range p.collectConns() { - if time.Since(c.LastUsage()) > ttl { - switch c.GetState() { - case Online, Banned: - _ = c.park(ctx) - default: - // nop - } - } - } - } - } -} - -func (p *Pool) collectConns() []*conn { - p.mtx.RLock() - defer p.mtx.RUnlock() - conns := make([]*conn, 0, len(p.conns)) - for _, c := range p.conns { - conns = append(conns, c) - } - - return conns -} - -func NewPool(ctx context.Context, config Config) *Pool { - onDone := trace.DriverOnPoolNew(config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.NewPool"), - ) - defer onDone() - - p := &Pool{ - usages: 1, - config: config, - opts: config.GrpcDialOptions(), - conns: make(map[connsKey]*conn), - done: make(chan struct{}), - } - - if ttl := config.ConnectionTTL(); ttl > 0 { - go p.connParker(xcontext.ValueOnly(ctx), ttl, ttl/2) //nolint:gomnd - } - - return p -} diff --git a/internal/endpoint/diff.go b/internal/endpoint/diff.go index 243c19706..1a1ed62b5 100644 --- a/internal/endpoint/diff.go +++ b/internal/endpoint/diff.go @@ -6,7 +6,11 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/xmath" ) -func Diff(previous, newest []Endpoint) (steady, added, dropped []Endpoint) { +func Diff(newest []Endpoint, previous []Endpoint) ( + steady []Endpoint, + added []Endpoint, + dropped []Endpoint, +) { steady = make([]Endpoint, 0, xmath.Min(len(newest), len(previous))) added = make([]Endpoint, 0, len(newest)) dropped = make([]Endpoint, 0, len(previous)) diff --git a/internal/endpoint/diff_test.go b/internal/endpoint/diff_test.go index aaf57e9cb..a3f091fdd 100644 --- a/internal/endpoint/diff_test.go +++ b/internal/endpoint/diff_test.go @@ -118,7 +118,7 @@ func TestDiff(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - steady, added, dropped := Diff(tt.previous, tt.newest) + steady, added, dropped := Diff(tt.newest, tt.previous) require.Equal(t, endpointsToAddresses(tt.added), endpointsToAddresses(added)) require.Equal(t, endpointsToAddresses(tt.dropped), endpointsToAddresses(dropped)) require.Equal(t, endpointsToAddresses(tt.steady), endpointsToAddresses(steady)) diff --git a/internal/pool/defaults.go b/internal/pool/defaults.go index 2591e8438..1d10de92b 100644 --- a/internal/pool/defaults.go +++ b/internal/pool/defaults.go @@ -19,13 +19,5 @@ var defaultTrace = &Trace{ return func(info *WithDoneInfo) { } }, - OnPut: func(info *PutStartInfo) func(info *PutDoneInfo) { - return func(info *PutDoneInfo) { - } - }, - OnGet: func(info *GetStartInfo) func(info *GetDoneInfo) { - return func(info *GetDoneInfo) { - } - }, OnChange: func(info ChangeInfo) {}, } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 71995b054..2696dd5e6 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -2,17 +2,14 @@ package pool import ( "context" + "fmt" "time" - "golang.org/x/sync/errgroup" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool/stats" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" "github.com/ydb-platform/ydb-go-sdk/v3/retry" - "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) type ( @@ -23,13 +20,18 @@ type ( } safeStats struct { mu xsync.RWMutex - v stats.Stats - onChange func(stats.Stats) + v Stats + onChange func(Stats) } statsItemAddr struct { v *int onChange func(func()) } + lazyItem[PT Item[T], T any] struct { + mu xsync.Mutex + item PT + createItem func(ctx context.Context) (PT, error) + } Pool[PT Item[T], T any] struct { trace *Trace limit int @@ -38,16 +40,52 @@ type ( createTimeout time.Duration closeTimeout time.Duration - mu xsync.Mutex - idle []PT - index map[PT]struct{} + items chan *lazyItem[PT, T] done chan struct{} - - stats *safeStats } option[PT Item[T], T any] func(p *Pool[PT, T]) ) +func (p *Pool[PT, T]) Stats() Stats { + return Stats{ + Limit: p.limit, + Idle: len(p.items), + InUse: 0, + } +} + +func (item *lazyItem[PT, T]) get(ctx context.Context) (_ PT, err error) { + item.mu.Lock() + defer item.mu.Unlock() + + if item.item != nil { + if item.item.IsAlive() { + return item.item, nil + } + + _ = item.item.Close(ctx) + } + + item.item, err = item.createItem(ctx) + + return item.item, err +} + +func (item *lazyItem[PT, T]) close(ctx context.Context) error { + item.mu.Lock() + defer item.mu.Unlock() + + if item.item == nil { + return nil + } + + defer func() { + item.item = nil + }() + + return item.item.Close(ctx) +} + func (field statsItemAddr) Inc() { field.onChange(func() { *field.v++ @@ -60,28 +98,13 @@ func (field statsItemAddr) Dec() { }) } -func (s *safeStats) Get() stats.Stats { +func (s *safeStats) Get() Stats { s.mu.RLock() defer s.mu.RUnlock() return s.v } -func (s *safeStats) Index() statsItemAddr { - s.mu.RLock() - defer s.mu.RUnlock() - - return statsItemAddr{ - v: &s.v.Index, - onChange: func(f func()) { - s.mu.WithLock(f) - if s.onChange != nil { - s.onChange(s.Get()) - } - }, - } -} - func (s *safeStats) Idle() statsItemAddr { s.mu.RLock() defer s.mu.RUnlock() @@ -142,7 +165,7 @@ func WithTrace[PT Item[T], T any](t *Trace) option[PT, T] { } } -func New[PT Item[T], T any]( +func New[PT Item[T], T any]( //nolint:funlen ctx context.Context, opts ...option[PT, T], ) *Pool[PT, T] { @@ -170,262 +193,97 @@ func New[PT Item[T], T any]( }) }() - p.createItem = createItemWithTimeoutHandling(p.createItem, p) - - p.idle = make([]PT, 0, p.limit) - p.index = make(map[PT]struct{}, p.limit) - p.stats = &safeStats{ - v: stats.Stats{Limit: p.limit}, - onChange: p.trace.OnChange, - } - - return p -} - -// defaultCreateItem returns a new item -func defaultCreateItem[T any, PT Item[T]](ctx context.Context) (PT, error) { - var item T + createItem := p.createItem - return &item, nil -} + p.createItem = func(ctx context.Context) (PT, error) { + ctx, cancel := xcontext.WithDone(ctx, p.done) + defer cancel() -// createItemWithTimeoutHandling wraps the createItem function with timeout handling -func createItemWithTimeoutHandling[PT Item[T], T any]( - createItem func(ctx context.Context) (PT, error), - p *Pool[PT, T], -) func(ctx context.Context) (PT, error) { - return func(ctx context.Context) (PT, error) { var ( - ch = make(chan PT) - createErr error + createCtx = xcontext.ValueOnly(ctx) + cancelCreate context.CancelFunc ) - go func() { - defer close(ch) - createErr = createItemWithContext(ctx, p, createItem, ch) - }() - select { - case <-p.done: - return nil, xerrors.WithStackTrace(errClosedPool) - case <-ctx.Done(): - return nil, xerrors.WithStackTrace(ctx.Err()) - case item, has := <-ch: - if !has { - if ctxErr := ctx.Err(); ctxErr == nil && xerrors.IsContextError(createErr) { - return nil, xerrors.WithStackTrace(xerrors.Retryable(createErr)) - } - - return nil, xerrors.WithStackTrace(createErr) - } - - return item, nil + if t := p.createTimeout; t > 0 { + createCtx, cancelCreate = xcontext.WithTimeout(createCtx, t) + } else { + createCtx, cancelCreate = xcontext.WithCancel(createCtx) } - } -} + defer cancelCreate() -// createItemWithContext handles the creation of an item with context handling -func createItemWithContext[PT Item[T], T any]( - ctx context.Context, - p *Pool[PT, T], - createItem func(ctx context.Context) (PT, error), - ch chan PT, -) error { - var ( - createCtx = xcontext.ValueOnly(ctx) - cancelCreate context.CancelFunc - ) - - if d := p.createTimeout; d > 0 { - createCtx, cancelCreate = xcontext.WithTimeout(createCtx, d) - } else { - createCtx, cancelCreate = xcontext.WithCancel(createCtx) - } - defer cancelCreate() + newItem, err := createItem(createCtx) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } - newItem, err := createItem(createCtx) - if err != nil { - return xerrors.WithStackTrace(err) + return newItem, nil } - needCloseItem := true - defer func() { - if needCloseItem { - _ = p.closeItem(ctx, newItem) + p.items = make(chan *lazyItem[PT, T], p.limit) + for i := 0; i < p.limit; i++ { + p.items <- &lazyItem[PT, T]{ + createItem: p.createItem, } - }() - - select { - case <-p.done: - return xerrors.WithStackTrace(errClosedPool) - case <-ctx.Done(): - p.mu.Lock() - defer p.mu.Unlock() - - if len(p.index) < p.limit { - p.idle = append(p.idle, newItem) - p.index[newItem] = struct{}{} - p.stats.Index().Inc() - needCloseItem = false - } - - return xerrors.WithStackTrace(ctx.Err()) - case ch <- newItem: - needCloseItem = false - - return nil } -} -func (p *Pool[PT, T]) Stats() stats.Stats { - return p.stats.Get() + return p } -func (p *Pool[PT, T]) getItem(ctx context.Context) (_ PT, finalErr error) { - onDone := p.trace.OnGet(&GetStartInfo{ - Context: &ctx, - Call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/pool.(*Pool).getItem"), - }) - defer func() { - onDone(&GetDoneInfo{ - Error: finalErr, - }) - }() - - if err := ctx.Err(); err != nil { - return nil, xerrors.WithStackTrace(err) - } - - select { - case <-p.done: - return nil, xerrors.WithStackTrace(errClosedPool) - case <-ctx.Done(): - return nil, xerrors.WithStackTrace(ctx.Err()) - default: - var item PT - p.mu.WithLock(func() { - if len(p.idle) > 0 { - item, p.idle = p.idle[0], p.idle[1:] - p.stats.Idle().Dec() - } - }) - - if item != nil { - if item.IsAlive() { - return item, nil - } - _ = p.closeItem(ctx, item) - p.mu.WithLock(func() { - delete(p.index, item) - }) - p.stats.Index().Dec() - } - - item, err := p.createItem(ctx) - if err != nil { - return nil, xerrors.WithStackTrace(err) - } - - addedToIndex := false - p.mu.WithLock(func() { - if len(p.index) < p.limit { - p.index[item] = struct{}{} - addedToIndex = true - } - }) - if addedToIndex { - p.stats.Index().Inc() - } +// defaultCreateItem returns a new item +func defaultCreateItem[T any, PT Item[T]](ctx context.Context) (PT, error) { + var item T - return item, nil - } + return &item, nil } -func (p *Pool[PT, T]) putItem(ctx context.Context, item PT) (finalErr error) { - onDone := p.trace.OnPut(&PutStartInfo{ +func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item PT) error) (finalErr error) { + onDone := p.trace.OnTry(&TryStartInfo{ Context: &ctx, - Call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/pool.(*Pool).putItem"), + Call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/pool.(*Pool).try"), }) + defer func() { - onDone(&PutDoneInfo{ + onDone(&TryDoneInfo{ Error: finalErr, }) }() - if err := ctx.Err(); err != nil { - return xerrors.WithStackTrace(err) - } - select { + case <-ctx.Done(): + return xerrors.WithStackTrace(ctx.Err()) + case <-p.done: return xerrors.WithStackTrace(errClosedPool) - default: - if !item.IsAlive() { - _ = p.closeItem(ctx, item) - - p.mu.WithLock(func() { - delete(p.index, item) - }) - p.stats.Index().Dec() - return xerrors.WithStackTrace(errItemIsNotAlive) + case lease, ok := <-p.items: + if !ok { + return xerrors.WithStackTrace(errClosedPool) } - p.mu.WithLock(func() { - p.idle = append(p.idle, item) - }) - p.stats.Idle().Inc() - - return nil - } -} - -func (p *Pool[PT, T]) closeItem(ctx context.Context, item PT) error { - ctx = xcontext.ValueOnly(ctx) - - var cancel context.CancelFunc - if d := p.closeTimeout; d > 0 { - ctx, cancel = xcontext.WithTimeout(ctx, d) - } else { - ctx, cancel = xcontext.WithCancel(ctx) - } - defer cancel() - - return item.Close(ctx) -} + defer func() { + p.items <- lease + }() -func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item PT) error) (finalErr error) { - onDone := p.trace.OnTry(&TryStartInfo{ - Context: &ctx, - Call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/pool.(*Pool).try"), - }) - defer func() { - onDone(&TryDoneInfo{ - Error: finalErr, - }) - }() + item, err := lease.get(ctx) + if err != nil { + if ctx.Err() == nil { + return xerrors.WithStackTrace(xerrors.Retryable(err)) + } - item, err := p.getItem(ctx) - if err != nil { - if xerrors.IsYdb(err) { - return xerrors.WithStackTrace(xerrors.Retryable(err)) + return xerrors.WithStackTrace(err) } - return xerrors.WithStackTrace(err) - } - - defer func() { - _ = p.putItem(ctx, item) - }() + if !item.IsAlive() { + return xerrors.WithStackTrace(xerrors.Retryable(errItemIsNotAlive)) + } - p.stats.InUse().Inc() - defer p.stats.InUse().Dec() + err = f(ctx, item) + if err != nil { + return xerrors.WithStackTrace(err) + } - err = f(ctx, item) - if err != nil { - return xerrors.WithStackTrace(err) + return nil } - - return nil } func (p *Pool[PT, T]) With( @@ -454,13 +312,7 @@ func (p *Pool[PT, T]) With( } return nil - }, append(opts, retry.WithTrace(&trace.Retry{ - OnRetry: func(info trace.RetryLoopStartInfo) func(trace.RetryLoopDoneInfo) { - return func(info trace.RetryLoopDoneInfo) { - attempts = info.Attempts - } - }, - }))...) + }, opts...) if err != nil { return xerrors.WithStackTrace(err) } @@ -481,18 +333,23 @@ func (p *Pool[PT, T]) Close(ctx context.Context) (finalErr error) { close(p.done) - p.mu.Lock() - defer p.mu.Unlock() + errs := make([]error, 0, p.limit) - var g errgroup.Group - for item := range p.index { - item := item - g.Go(func() error { - return item.Close(ctx) - }) + for i := 0; i < p.limit; i++ { + select { + case <-ctx.Done(): + return xerrors.WithStackTrace(fmt.Errorf("%d items not closed: %w", p.limit-i, ctx.Err())) + case item := <-p.items: + if err := item.close(ctx); err != nil { + errs = append(errs, err) + } + } } - if err := g.Wait(); err != nil { - return xerrors.WithStackTrace(err) + + close(p.items) + + if len(errs) > 0 { + return xerrors.WithStackTrace(xerrors.Join(errs...)) } return nil diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 6ba7036c9..6b7baa9a7 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -189,6 +189,52 @@ func TestPool(t *testing.T) { }) }) }) + t.Run("Close", func(t *testing.T) { + t.Run("WithoutDeadline", func(t *testing.T) { + xtest.TestManyTimes(t, func(t testing.TB) { + require.NotPanics(t, func() { + p := New(rootCtx, + WithLimit[*testItem, testItem](10), + ) + wg := sync.WaitGroup{} + for range make([]struct{}, 1000) { + wg.Add(1) + go func() { + defer wg.Done() + _ = p.try(rootCtx, func(ctx context.Context, item *testItem) error { + return nil + }) + }() + } + p.Close(rootCtx) + require.Empty(t, p.items) + }) + }, xtest.StopAfter(3*time.Second)) + }) + t.Run("WithDeadline", func(t *testing.T) { + xtest.TestManyTimes(t, func(t testing.TB) { + require.NotPanics(t, func() { + p := New(rootCtx, + WithLimit[*testItem, testItem](10), + ) + wg := sync.WaitGroup{} + for range make([]struct{}, 1000) { + wg.Add(1) + go func() { + defer wg.Done() + _ = p.try(rootCtx, func(ctx context.Context, item *testItem) error { + return nil + }) + }() + } + ctx, cancel := context.WithTimeout(rootCtx, time.Second) + defer cancel() + p.Close(ctx) + require.Empty(t, p.items) + }) + }, xtest.StopAfter(3*time.Second)) + }) + }) t.Run("Item", func(t *testing.T) { t.Run("Close", func(t *testing.T) { xtest.TestManyTimes(t, func(t testing.TB) { @@ -302,12 +348,10 @@ func TestSafeStatsRace(t *testing.T) { go func() { defer wg.Done() require.NotPanics(t, func() { - switch rand.Int31n(4) { //nolint:gosec + switch rand.Int31n(3) { //nolint:gosec case 0: - s.Index().Inc() - case 1: s.InUse().Inc() - case 2: + case 1: s.Idle().Inc() default: s.Get() @@ -318,3 +362,22 @@ func TestSafeStatsRace(t *testing.T) { wg.Wait() }, xtest.StopAfter(5*time.Second)) } + +func BenchmarkPoolTry(b *testing.B) { + b.ReportAllocs() + ctx := xtest.Context(b) + p := New(ctx, + WithLimit[*testItem, testItem](100), + ) + defer func() { + _ = p.Close(ctx) + require.Empty(b, p.items) + }() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = p.try(ctx, func(ctx context.Context, item *testItem) error { + return nil + }) + } + }) +} diff --git a/internal/pool/stats/stats.go b/internal/pool/stats.go similarity index 69% rename from internal/pool/stats/stats.go rename to internal/pool/stats.go index dff03eaeb..a605e92c9 100644 --- a/internal/pool/stats/stats.go +++ b/internal/pool/stats.go @@ -1,8 +1,7 @@ -package stats +package pool type Stats struct { Limit int - Index int Idle int InUse int } diff --git a/internal/pool/trace.go b/internal/pool/trace.go index 40adef256..26f0d57d2 100644 --- a/internal/pool/trace.go +++ b/internal/pool/trace.go @@ -3,7 +3,6 @@ package pool import ( "context" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool/stats" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" ) @@ -13,8 +12,6 @@ type ( OnClose func(*CloseStartInfo) func(*CloseDoneInfo) OnTry func(*TryStartInfo) func(*TryDoneInfo) OnWith func(*WithStartInfo) func(*WithDoneInfo) - OnPut func(*PutStartInfo) func(*PutDoneInfo) - OnGet func(*GetStartInfo) func(*GetDoneInfo) OnChange func(ChangeInfo) } NewStartInfo struct { @@ -63,27 +60,5 @@ type ( Attempts int } - PutStartInfo struct { - // Context make available context in trace stack.Callerback function. - // Pointer to context provide replacement of context in trace stack.Callerback function. - // Warning: concurrent access to pointer on client side must be excluded. - // Safe replacement of context are provided only inside stack.Callerback function - Context *context.Context - Call stack.Caller - } - PutDoneInfo struct { - Error error - } - GetStartInfo struct { - // Context make available context in trace stack.Callerback function. - // Pointer to context provide replacement of context in trace stack.Callerback function. - // Warning: concurrent access to pointer on client side must be excluded. - // Safe replacement of context are provided only inside stack.Callerback function - Context *context.Context - Call stack.Caller - } - GetDoneInfo struct { - Error error - } - ChangeInfo = stats.Stats + ChangeInfo = Stats ) diff --git a/internal/query/client.go b/internal/query/client.go index 46a67046a..3f45b2396 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -9,7 +9,6 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool/stats" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" @@ -32,7 +31,7 @@ type ( sessionPool interface { closer.Closer - Stats() stats.Stats + Stats() pool.Stats With(ctx context.Context, f func(ctx context.Context, s *Session) error, opts ...retry.Option) error } poolStub struct { @@ -48,29 +47,28 @@ type ( } ) -func (pool *poolStub) Close(ctx context.Context) error { +func (p *poolStub) Close(ctx context.Context) error { return nil } -func (pool *poolStub) Stats() stats.Stats { - return stats.Stats{ +func (p *poolStub) Stats() pool.Stats { + return pool.Stats{ Limit: -1, - Index: 0, Idle: 0, - InUse: int(pool.InUse.Load()), + InUse: int(p.InUse.Load()), } } -func (pool *poolStub) With( +func (p *poolStub) With( ctx context.Context, f func(ctx context.Context, s *Session) error, opts ...retry.Option, ) error { - pool.InUse.Add(1) + p.InUse.Add(1) defer func() { - pool.InUse.Add(-1) + p.InUse.Add(-1) }() err := retry.Retry(ctx, func(ctx context.Context) (err error) { - s, err := pool.createSession(ctx) + s, err := p.createSession(ctx) if err != nil { return xerrors.WithStackTrace(err) } @@ -92,7 +90,7 @@ func (pool *poolStub) With( return nil } -func (c *Client) Stats() *stats.Stats { +func (c *Client) Stats() *pool.Stats { s := c.pool.Stats() return &s @@ -357,13 +355,19 @@ func New(ctx context.Context, balancer grpc.ClientConnInterface, cfg *config.Con ) defer onDone() - grpcClient := Ydb_Query_V1.NewQueryServiceClient(balancer) + var ( + grpcClient = Ydb_Query_V1.NewQueryServiceClient(balancer) + done = make(chan struct{}) + ) client := &Client{ config: cfg, client: grpcClient, - done: make(chan struct{}), + done: done, pool: newPool(ctx, cfg, func(ctx context.Context) (_ *Session, err error) { + ctx, cancel := xcontext.WithDone(ctx, done) + defer cancel() + var ( createCtx context.Context cancelCreate context.CancelFunc @@ -417,22 +421,8 @@ func poolTrace(t *trace.Query) *pool.Trace { onDone(info.Error, info.Attempts) } }, - OnPut: func(info *pool.PutStartInfo) func(*pool.PutDoneInfo) { - onDone := trace.QueryOnPoolPut(t, info.Context, info.Call) - - return func(info *pool.PutDoneInfo) { - onDone(info.Error) - } - }, - OnGet: func(info *pool.GetStartInfo) func(*pool.GetDoneInfo) { - onDone := trace.QueryOnPoolGet(t, info.Context, info.Call) - - return func(info *pool.GetDoneInfo) { - onDone(info.Error) - } - }, OnChange: func(info pool.ChangeInfo) { - trace.QueryOnPoolChange(t, info.Limit, info.Index, info.Idle, info.InUse) + trace.QueryOnPoolChange(t, info.Limit, info.Idle, info.InUse) }, } } diff --git a/internal/query/session.go b/internal/query/session.go index 32ecae2e6..9540dbc1c 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -183,7 +183,9 @@ func (s *Session) closeAndDeleteSession(cancelAttach context.CancelFunc) func(ct } else { ctx, cancel = xcontext.WithCancel(ctx) } - defer cancel() + defer func() { + cancel() + }() if err = deleteSession(ctx, s.grpcClient, s.id); err != nil { return xerrors.WithStackTrace(err) diff --git a/internal/xerrors/operation.go b/internal/xerrors/operation.go index 7020e4f42..e1868c7ef 100644 --- a/internal/xerrors/operation.go +++ b/internal/xerrors/operation.go @@ -3,6 +3,7 @@ package xerrors import ( "errors" "fmt" + "strconv" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Issue" @@ -17,6 +18,7 @@ type operationError struct { code Ydb.StatusIds_StatusCode issues issues address string + nodeID uint32 traceID string } @@ -119,6 +121,10 @@ func (e *operationError) Error() string { b.WriteString(", address = ") b.WriteString(e.address) } + if e.nodeID > 0 { + b.WriteString(", nodeID = ") + b.WriteString(strconv.FormatUint(uint64(e.nodeID), 10)) + } if len(e.issues) > 0 { b.WriteString(", issues = ") b.WriteString(e.issues.String()) diff --git a/internal/xerrors/transport.go b/internal/xerrors/transport.go index b66f7735c..f35c6b39e 100644 --- a/internal/xerrors/transport.go +++ b/internal/xerrors/transport.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "strconv" grpcCodes "google.golang.org/grpc/codes" grpcStatus "google.golang.org/grpc/status" @@ -16,6 +17,7 @@ type transportError struct { status *grpcStatus.Status err error address string + nodeID uint32 traceID string } @@ -47,6 +49,20 @@ func WithAddress(address string) addressOption { return addressOption(address) } +type nodeIDOption uint32 + +func (nodeID nodeIDOption) applyToTransportError(te *transportError) { + te.nodeID = uint32(nodeID) +} + +func (nodeID nodeIDOption) applyToOperationError(oe *operationError) { + oe.nodeID = uint32(nodeID) +} + +func WithNodeID(nodeID uint32) nodeIDOption { + return nodeIDOption(nodeID) +} + func (e *transportError) Error() string { var b bytes.Buffer b.WriteString(e.Name()) @@ -54,6 +70,10 @@ func (e *transportError) Error() string { if len(e.address) > 0 { b.WriteString(fmt.Sprintf(", address: %q", e.address)) } + if e.nodeID > 0 { + b.WriteString(", nodeID = ") + b.WriteString(strconv.FormatUint(uint64(e.nodeID), 10)) + } if len(e.traceID) > 0 { b.WriteString(fmt.Sprintf(", traceID: %q", e.traceID)) } diff --git a/internal/xmath/xmath.go b/internal/xmath/xmath.go index f4fc0fc7e..e3e9ad8c8 100644 --- a/internal/xmath/xmath.go +++ b/internal/xmath/xmath.go @@ -1,21 +1,21 @@ package xmath -func Min[T ordered](v T, values ...T) T { +import "cmp" + +func Min[T cmp.Ordered](v T, values ...T) T { for _, value := range values { if value < v { v = value } } - return v } -func Max[T ordered](v T, values ...T) T { +func Max[T cmp.Ordered](v T, values ...T) T { for _, value := range values { if value > v { v = value } } - return v } diff --git a/internal/xsync/map.go b/internal/xsync/map.go index ab7978d26..9b7c81115 100644 --- a/internal/xsync/map.go +++ b/internal/xsync/map.go @@ -55,14 +55,13 @@ func (m *Map[K, V]) LoadAndDelete(key K) (value V, ok bool) { func (m *Map[K, V]) Range(f func(key K, value V) bool) { m.m.Range(func(k, v any) bool { - return f(k.(K), v.(V)) //nolint:forcetypeassert + return f(k.(K), v.(V)) }) } func (m *Map[K, V]) Clear() { m.m.Range(func(k, v any) bool { m.m.Delete(k) - return true }) } diff --git a/internal/xsync/map_test.go b/internal/xsync/map_test.go index b334d45cb..3c97f2c3d 100644 --- a/internal/xsync/map_test.go +++ b/internal/xsync/map_test.go @@ -11,29 +11,16 @@ func TestMap(t *testing.T) { v, ok := m.Load(1) require.False(t, ok) m.Store(1, "one") - require.NotPanics(t, func() { - v = m.Must(1) - require.Equal(t, "one", v) - }) - require.Panics(t, func() { - v = m.Must(2) - }) - require.Panics(t, func() { - m.m.Store(2, 2) - v = m.Must(2) - }) - m.m.Delete(2) - v, ok = m.LoadAndDelete(2) - require.False(t, ok) - require.Equal(t, "", v) + v, ok = m.Load(1) + require.True(t, ok) + require.Equal(t, "one", v) m.Store(2, "two") v, ok = m.Load(2) require.True(t, ok) require.Equal(t, "two", v) - v, ok = m.LoadAndDelete(1) - require.True(t, ok) - require.Equal(t, "one", v) - require.False(t, m.Has(1)) + m.Delete(1) + v, ok = m.Load(1) + require.False(t, ok) m.Store(3, "three") v, ok = m.Load(3) require.True(t, ok) @@ -49,17 +36,8 @@ func TestMap(t *testing.T) { } else { unexp[key] = value } - return true }) require.Empty(t, exp) require.Empty(t, unexp) - m.Clear() - empty := true - m.Range(func(key int, value string) bool { - empty = false - - return false - }) - require.True(t, empty) } diff --git a/internal/xsync/once.go b/internal/xsync/once.go index 35f5ed0aa..139938145 100644 --- a/internal/xsync/once.go +++ b/internal/xsync/once.go @@ -20,13 +20,14 @@ func OnceFunc(f func(ctx context.Context) error) func(ctx context.Context) error } type Once[T closer.Closer] struct { - f func() T + f func() (T, error) once sync.Once mutex sync.RWMutex t T + err error } -func OnceValue[T closer.Closer](f func() T) *Once[T] { +func OnceValue[T closer.Closer](f func() (T, error)) *Once[T] { return &Once[T]{f: f} } @@ -46,16 +47,25 @@ func (v *Once[T]) Close(ctx context.Context) (err error) { return nil } -func (v *Once[T]) Get() T { +func (v *Once[T]) Get() (T, error) { v.once.Do(func() { v.mutex.Lock() defer v.mutex.Unlock() - v.t = v.f() + v.t, v.err = v.f() }) v.mutex.RLock() defer v.mutex.RUnlock() - return v.t + return v.t, v.err +} + +func (v *Once[T]) Must() T { + t, err := v.Get() + if err != nil { + panic(err) + } + + return t } diff --git a/log/query.go b/log/query.go index b49c75a90..a1900a766 100644 --- a/log/query.go +++ b/log/query.go @@ -153,58 +153,6 @@ func internalQuery( } } }, - OnPoolPut: func(info trace.QueryPoolPutStartInfo) func(trace.QueryPoolPutDoneInfo) { - if d.Details()&trace.QueryPoolEvents == 0 { - return nil - } - ctx := with(*info.Context, TRACE, "ydb", "query", "pool", "put") - l.Log(ctx, "start") - start := time.Now() - - return func(info trace.QueryPoolPutDoneInfo) { - if info.Error == nil { - l.Log(ctx, "done", - latencyField(start), - ) - } else { - lvl := WARN - if !xerrors.IsYdb(info.Error) { - lvl = DEBUG - } - l.Log(WithLevel(ctx, lvl), "failed", - latencyField(start), - Error(info.Error), - versionField(), - ) - } - } - }, - OnPoolGet: func(info trace.QueryPoolGetStartInfo) func(trace.QueryPoolGetDoneInfo) { - if d.Details()&trace.QueryPoolEvents == 0 { - return nil - } - ctx := with(*info.Context, TRACE, "ydb", "query", "pool", "get") - l.Log(ctx, "start") - start := time.Now() - - return func(info trace.QueryPoolGetDoneInfo) { - if info.Error == nil { - l.Log(ctx, "done", - latencyField(start), - ) - } else { - lvl := WARN - if !xerrors.IsYdb(info.Error) { - lvl = DEBUG - } - l.Log(WithLevel(ctx, lvl), "failed", - latencyField(start), - Error(info.Error), - versionField(), - ) - } - } - }, OnDo: func(info trace.QueryDoStartInfo) func(trace.QueryDoDoneInfo) { if d.Details()&trace.QueryEvents == 0 { return nil diff --git a/metrics/query.go b/metrics/query.go index af8d1998c..47da4ebb6 100644 --- a/metrics/query.go +++ b/metrics/query.go @@ -42,7 +42,6 @@ func query(config Config) (t trace.Query) { sizeConfig := poolConfig.WithSystem("size") limit := sizeConfig.GaugeVec("limit") idle := sizeConfig.GaugeVec("idle") - index := sizeConfig.GaugeVec("index") inUse := sizeConfig.WithSystem("in").GaugeVec("use") t.OnPoolChange = func(stats trace.QueryPoolChange) { @@ -53,7 +52,6 @@ func query(config Config) (t trace.Query) { limit.With(nil).Set(float64(stats.Limit)) idle.With(nil).Set(float64(stats.Idle)) inUse.With(nil).Set(float64(stats.InUse)) - index.With(nil).Set(float64(stats.Index)) } } } diff --git a/options.go b/options.go index 5bcdf0109..9631d3b7c 100644 --- a/options.go +++ b/options.go @@ -13,7 +13,6 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/credentials" balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/certificates" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" coordinationConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/coordination/config" discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" @@ -759,10 +758,10 @@ func withOnClose(onClose func(c *Driver)) Option { } } -func withConnPool(pool *conn.Pool) Option { - return func(ctx context.Context, c *Driver) error { - c.pool = pool - - return pool.Take(ctx) - } -} +//func withConnPool(pool *conn.Pool) Option { +// return func(ctx context.Context, c *Driver) error { +// c.pool = pool +// +// return pool.Take(ctx) +// } +//} diff --git a/query/client.go b/query/client.go index ecc9e37f0..06c937109 100644 --- a/query/client.go +++ b/query/client.go @@ -4,7 +4,7 @@ import ( "context" "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" - poolStats "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool/stats" + poolStats "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/retry/budget" "github.com/ydb-platform/ydb-go-sdk/v3/trace" diff --git a/query/stats.go b/query/stats.go index 871320415..b408c8626 100644 --- a/query/stats.go +++ b/query/stats.go @@ -1,7 +1,7 @@ package query import ( - "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool/stats" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool" ) // Stats returns stats of query client pool @@ -9,6 +9,6 @@ import ( // Deprecated: use client.Stats() method instead // Will be removed after Jan 2025. // Read about versioning policy: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#deprecated -func Stats(client Client) (*stats.Stats, error) { +func Stats(client Client) (*pool.Stats, error) { return client.Stats(), nil } diff --git a/sugar/query_test.go b/sugar/query_test.go index 195a4b6f5..2d94576a4 100644 --- a/sugar/query_test.go +++ b/sugar/query_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool/stats" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool" internalQuery "github.com/ydb-platform/ydb-go-sdk/v3/internal/query" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" @@ -25,7 +25,7 @@ type mockReadResultSetClient struct { rs query.ResultSet } -func (c *mockReadResultSetClient) Stats() *stats.Stats { +func (c *mockReadResultSetClient) Stats() *pool.Stats { return nil } @@ -63,7 +63,7 @@ type mockReadRowClient struct { row query.Row } -func (c *mockReadRowClient) Stats() *stats.Stats { +func (c *mockReadRowClient) Stats() *pool.Stats { return nil } diff --git a/trace/driver.go b/trace/driver.go index 31a55d088..d7b995b22 100644 --- a/trace/driver.go +++ b/trace/driver.go @@ -9,6 +9,8 @@ import ( "fmt" "strings" "time" + + "google.golang.org/grpc/connectivity" ) type ( @@ -129,14 +131,6 @@ func (m Method) Split() (service, method string) { return strings.TrimPrefix(string(m[:i]), "/"), string(m[i+1:]) } -// Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -type ConnState interface { - fmt.Stringer - - IsValid() bool - Code() int -} - // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals type EndpointInfo interface { fmt.Stringer @@ -164,11 +158,11 @@ type ( Context *context.Context Call call Endpoint EndpointInfo - State ConnState + State connectivity.State } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnStateChangeDoneInfo struct { - State ConnState + State connectivity.State } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverResolveStartInfo struct { @@ -321,12 +315,12 @@ type ( Context *context.Context Call call Endpoint EndpointInfo - State ConnState + State connectivity.State Cause error } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnBanDoneInfo struct { - State ConnState + State connectivity.State } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnAllowStartInfo struct { @@ -337,11 +331,11 @@ type ( Context *context.Context Call call Endpoint EndpointInfo - State ConnState + State connectivity.State } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnAllowDoneInfo struct { - State ConnState + State connectivity.State } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnInvokeStartInfo struct { @@ -359,7 +353,7 @@ type ( Error error Issues []Issue OpID string - State ConnState + State connectivity.State Metadata map[string][]string } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals @@ -376,7 +370,7 @@ type ( // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnNewStreamDoneInfo struct { Error error - State ConnState + State connectivity.State } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnStreamRecvMsgStartInfo struct { diff --git a/trace/driver_gtrace.go b/trace/driver_gtrace.go index 32c008d09..ffb3ccc9e 100644 --- a/trace/driver_gtrace.go +++ b/trace/driver_gtrace.go @@ -4,6 +4,8 @@ package trace import ( "context" + + "google.golang.org/grpc/connectivity" ) // driverComposeOptions is a holder of options @@ -1340,49 +1342,49 @@ func DriverOnResolve(t *Driver, call call, target string, resolved []string) fun } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnStateChange(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState) func(state ConnState) { +func DriverOnConnStateChange(t *Driver, c *context.Context, call call, endpoint EndpointInfo, s connectivity.State) func(connectivity.State) { var p DriverConnStateChangeStartInfo p.Context = c p.Call = call p.Endpoint = endpoint - p.State = state + p.State = s res := t.onConnStateChange(p) - return func(state ConnState) { + return func(s connectivity.State) { var p DriverConnStateChangeDoneInfo - p.State = state + p.State = s res(p) } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnInvoke(t *Driver, c *context.Context, call call, endpoint EndpointInfo, m Method) func(_ error, issues []Issue, opID string, state ConnState, metadata map[string][]string) { +func DriverOnConnInvoke(t *Driver, c *context.Context, call call, endpoint EndpointInfo, m Method) func(_ error, issues []Issue, opID string, _ connectivity.State, metadata map[string][]string) { var p DriverConnInvokeStartInfo p.Context = c p.Call = call p.Endpoint = endpoint p.Method = m res := t.onConnInvoke(p) - return func(e error, issues []Issue, opID string, state ConnState, metadata map[string][]string) { + return func(e error, issues []Issue, opID string, s connectivity.State, metadata map[string][]string) { var p DriverConnInvokeDoneInfo p.Error = e p.Issues = issues p.OpID = opID - p.State = state + p.State = s p.Metadata = metadata res(p) } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnNewStream(t *Driver, c *context.Context, call call, endpoint EndpointInfo, m Method) func(_ error, state ConnState) { +func DriverOnConnNewStream(t *Driver, c *context.Context, call call, endpoint EndpointInfo, m Method) func(error, connectivity.State) { var p DriverConnNewStreamStartInfo p.Context = c p.Call = call p.Endpoint = endpoint p.Method = m res := t.onConnNewStream(p) - return func(e error, state ConnState) { + return func(e error, s connectivity.State) { var p DriverConnNewStreamDoneInfo p.Error = e - p.State = state + p.State = s res(p) } } @@ -1444,31 +1446,31 @@ func DriverOnConnDial(t *Driver, c *context.Context, call call, endpoint Endpoin } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnBan(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState, cause error) func(state ConnState) { +func DriverOnConnBan(t *Driver, c *context.Context, call call, endpoint EndpointInfo, s connectivity.State, cause error) func(connectivity.State) { var p DriverConnBanStartInfo p.Context = c p.Call = call p.Endpoint = endpoint - p.State = state + p.State = s p.Cause = cause res := t.onConnBan(p) - return func(state ConnState) { + return func(s connectivity.State) { var p DriverConnBanDoneInfo - p.State = state + p.State = s res(p) } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnAllow(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState) func(state ConnState) { +func DriverOnConnAllow(t *Driver, c *context.Context, call call, endpoint EndpointInfo, s connectivity.State) func(connectivity.State) { var p DriverConnAllowStartInfo p.Context = c p.Call = call p.Endpoint = endpoint - p.State = state + p.State = s res := t.onConnAllow(p) - return func(state ConnState) { + return func(s connectivity.State) { var p DriverConnAllowDoneInfo - p.State = state + p.State = s res(p) } } diff --git a/trace/query.go b/trace/query.go index 59a41eeb8..97869bfc7 100644 --- a/trace/query.go +++ b/trace/query.go @@ -38,10 +38,6 @@ type ( // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals OnPoolWith func(QueryPoolWithStartInfo) func(QueryPoolWithDoneInfo) // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals - OnPoolPut func(QueryPoolPutStartInfo) func(QueryPoolPutDoneInfo) - // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals - OnPoolGet func(QueryPoolGetStartInfo) func(QueryPoolGetDoneInfo) - // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals OnPoolChange func(QueryPoolChange) // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals @@ -423,35 +419,8 @@ type ( Attempts int } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals - QueryPoolPutStartInfo struct { - // Context make available context in trace callback function. - // Pointer to context provide replacement of context in trace callback function. - // Warning: concurrent access to pointer on client side must be excluded. - // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - } - // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals - QueryPoolPutDoneInfo struct { - Error error - } - // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals - QueryPoolGetStartInfo struct { - // Context make available context in trace callback function. - // Pointer to context provide replacement of context in trace callback function. - // Warning: concurrent access to pointer on client side must be excluded. - // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - } - // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals - QueryPoolGetDoneInfo struct { - Error error - } - // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals QueryPoolChange struct { Limit int - Index int Idle int InUse int } diff --git a/trace/query_gtrace.go b/trace/query_gtrace.go index c6f953aed..b2aac1664 100644 --- a/trace/query_gtrace.go +++ b/trace/query_gtrace.go @@ -245,76 +245,6 @@ func (t *Query) Compose(x *Query, opts ...QueryComposeOption) *Query { } } } - { - h1 := t.OnPoolPut - h2 := x.OnPoolPut - ret.OnPoolPut = func(q QueryPoolPutStartInfo) func(QueryPoolPutDoneInfo) { - if options.panicCallback != nil { - defer func() { - if e := recover(); e != nil { - options.panicCallback(e) - } - }() - } - var r, r1 func(QueryPoolPutDoneInfo) - if h1 != nil { - r = h1(q) - } - if h2 != nil { - r1 = h2(q) - } - return func(q QueryPoolPutDoneInfo) { - if options.panicCallback != nil { - defer func() { - if e := recover(); e != nil { - options.panicCallback(e) - } - }() - } - if r != nil { - r(q) - } - if r1 != nil { - r1(q) - } - } - } - } - { - h1 := t.OnPoolGet - h2 := x.OnPoolGet - ret.OnPoolGet = func(q QueryPoolGetStartInfo) func(QueryPoolGetDoneInfo) { - if options.panicCallback != nil { - defer func() { - if e := recover(); e != nil { - options.panicCallback(e) - } - }() - } - var r, r1 func(QueryPoolGetDoneInfo) - if h1 != nil { - r = h1(q) - } - if h2 != nil { - r1 = h2(q) - } - return func(q QueryPoolGetDoneInfo) { - if options.panicCallback != nil { - defer func() { - if e := recover(); e != nil { - options.panicCallback(e) - } - }() - } - if r != nil { - r(q) - } - if r1 != nil { - r1(q) - } - } - } - } { h1 := t.OnPoolChange h2 := x.OnPoolChange @@ -1091,36 +1021,6 @@ func (t *Query) onPoolWith(q QueryPoolWithStartInfo) func(QueryPoolWithDoneInfo) } return res } -func (t *Query) onPoolPut(q QueryPoolPutStartInfo) func(QueryPoolPutDoneInfo) { - fn := t.OnPoolPut - if fn == nil { - return func(QueryPoolPutDoneInfo) { - return - } - } - res := fn(q) - if res == nil { - return func(QueryPoolPutDoneInfo) { - return - } - } - return res -} -func (t *Query) onPoolGet(q QueryPoolGetStartInfo) func(QueryPoolGetDoneInfo) { - fn := t.OnPoolGet - if fn == nil { - return func(QueryPoolGetDoneInfo) { - return - } - } - res := fn(q) - if res == nil { - return func(QueryPoolGetDoneInfo) { - return - } - } - return res -} func (t *Query) onPoolChange(q QueryPoolChange) { fn := t.OnPoolChange if fn == nil { @@ -1486,34 +1386,9 @@ func QueryOnPoolWith(t *Query, c *context.Context, call call) func(_ error, atte } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func QueryOnPoolPut(t *Query, c *context.Context, call call) func(error) { - var p QueryPoolPutStartInfo - p.Context = c - p.Call = call - res := t.onPoolPut(p) - return func(e error) { - var p QueryPoolPutDoneInfo - p.Error = e - res(p) - } -} -// Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func QueryOnPoolGet(t *Query, c *context.Context, call call) func(error) { - var p QueryPoolGetStartInfo - p.Context = c - p.Call = call - res := t.onPoolGet(p) - return func(e error) { - var p QueryPoolGetDoneInfo - p.Error = e - res(p) - } -} -// Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func QueryOnPoolChange(t *Query, limit int, index int, idle int, inUse int) { +func QueryOnPoolChange(t *Query, limit int, idle int, inUse int) { var p QueryPoolChange p.Limit = limit - p.Index = index p.Idle = idle p.InUse = inUse t.onPoolChange(p) diff --git a/with.go b/with.go index c226db991..dcdd758c2 100644 --- a/with.go +++ b/with.go @@ -28,7 +28,7 @@ func (d *Driver) with(ctx context.Context, opts ...Option) (*Driver, uint64, err delete(d.children, id) }), - withConnPool(d.pool), + //withConnPool(d.pool), ), opts..., )...,