From d83e8a2674e660497b60fed23ab8ec3d3b772085 Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Thu, 21 Nov 2024 13:12:53 +0200 Subject: [PATCH] CASSGO-4 SetHost API for Query This patch provides mechanism that allows users to specify on which node the query will be executed. It is not a tipycal use case, but it makes sense with virtual tables which are available since C* 5.0.0. Patch by Bohdan Siryk; Reviewed by for CASSGO-4 --- cassandra_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++--- query_executor.go | 31 ++++++++++++++++++++++++++---- session.go | 29 ++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 7 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 3b0c61053..8afbe7600 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,7 +32,6 @@ import ( "context" "errors" "fmt" - "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -45,7 +44,9 @@ import ( "time" "unicode" - inf "gopkg.in/inf.v0" + "github.com/stretchr/testify/require" + + "gopkg.in/inf.v0" ) func TestEmptyHosts(t *testing.T) { @@ -3303,7 +3304,6 @@ func TestUnsetColBatch(t *testing.T) { } var id, mInt, count int var mText string - if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil { t.Fatalf("Failed to select with err: %v", err) } else if count != 2 { @@ -3338,3 +3338,46 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestQuery_SetHost(t *testing.T) { + // This test ensures that queries are sent to the specified host only + + session := createSession(t) + defer session.Close() + + hosts, err := session.GetHosts() + if err != nil { + t.Fatal(err) + } + + for _, expectedHost := range hosts { + const iterations = 5 + for i := 0; i < iterations; i++ { + var actualHostID string + err := session.Query("SELECT host_id FROM system.local"). + SetHost(expectedHost). + Scan(&actualHostID) + if err != nil { + t.Fatal(err) + } + + if expectedHost.HostID() != actualHostID { + t.Fatalf("Expected query to be executed on host %s, but it was executed on %s", + expectedHost.HostID(), + actualHostID, + ) + } + } + } + + // ensuring that the driver properly handles the case + // when specified host for the query is down + host := hosts[0] + host.state = NodeDown + err = session.Query("SELECT host_id FROM system.local"). + SetHost(host). + Exec() + if !errors.Is(err, ErrNoConnections) { + t.Fatalf("Expected error to be: %v, but got %v", ErrNoConnections, err) + } +} diff --git a/query_executor.go b/query_executor.go index 03687361a..60d71bf0d 100644 --- a/query_executor.go +++ b/query_executor.go @@ -41,6 +41,7 @@ type ExecutableQuery interface { Keyspace() string Table() string IsIdempotent() bool + GetHost() *HostInfo withContext(context.Context) ExecutableQuery @@ -83,12 +84,27 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S } func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { - hostIter := q.policy.Pick(qry) + var hostIter NextHost + + // checking if the host is specified for the query, + // if it is, the query should be executed at the specified host + host := qry.GetHost() + if host != nil { + hostIter = func() SelectedHost { + return (*selectedHost)(host) + } + } + + // if host is not specified for the query, + // then a host will be picked by HostSelectionPolicy + if hostIter == nil { + hostIter = q.policy.Pick(qry) + } // check if the query is not marked as idempotent, if // it is, we force the policy to NonSpeculative sp := qry.speculativeExecutionPolicy() - if !qry.IsIdempotent() || sp.Attempts() == 0 { + if host != nil || !qry.IsIdempotent() || sp.Attempts() == 0 { return q.do(qry.Context(), qry, hostIter), nil } @@ -129,12 +145,17 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { selectedHost := hostIter() rt := qry.retryPolicy() + specifiedHost := qry.GetHost() var lastErr error var iter *Iter for selectedHost != nil { host := selectedHost.Info() - if host == nil || !host.IsUp() { + if specifiedHost != nil && host != nil && !host.IsUp() { + return &Iter{err: ErrNoConnections} + } + + if (host == nil || !host.IsUp()) && specifiedHost == nil { selectedHost = hostIter() continue } @@ -166,7 +187,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne // Exit if the query was successful // or no retry policy defined or retry attempts were reached - if iter.err == nil || !qry.IsIdempotent() || rt == nil || !rt.Attempt(qry) { + // Also, if there is specified host for the query to be executed on + // and query execution is failed we should exit + if iter.err == nil || specifiedHost != nil || !qry.IsIdempotent() || rt == nil || !rt.Attempt(qry) { return iter } lastErr = iter.err diff --git a/session.go b/session.go index b884735c2..29ba2d11e 100644 --- a/session.go +++ b/session.go @@ -936,6 +936,10 @@ type Query struct { // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo + + // host specifies the host on which the query should be executed. + // If it is nil, then the host is picked by HostSelectionPolicy + host *HostInfo } type queryRoutingInfo struct { @@ -1423,6 +1427,18 @@ func (q *Query) releaseAfterExecution() { q.decRefCount() } +// SetHosts allows to define on which host the query should be executed. +// If host == nil, then the HostSelectionPolicy will be used to pick a host. +func (q *Query) SetHost(host *HostInfo) *Query { + q.host = host + return q +} + +// GetHost returns host on which query should be executed. +func (q *Query) GetHost() *HostInfo { + return q.host +} + // Iter represents an iterator that can be used to iterate over all rows that // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. @@ -2030,6 +2046,10 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } +func (b *Batch) GetHost() *HostInfo { + return nil +} + type BatchType byte const ( @@ -2162,6 +2182,15 @@ func (t *traceWriter) Trace(traceId []byte) { } } +// GetHosts returns a list of hosts found via queries to system.local and system.peers +func (s *Session) GetHosts() ([]*HostInfo, error) { + hosts, _, err := s.hostSource.GetHosts() + if err != nil { + return nil, err + } + return hosts, nil +} + type ObservedQuery struct { Keyspace string Statement string