diff --git a/CHANGELOG.md b/CHANGELOG.md index 67c88a141..2f27a9b87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- SetHost API for Query (CASSGO-4) + ### Changed - Don't restrict server authenticator unless PasswordAuthentictor.AllowedAuthenticators is provided (CASSGO-19) diff --git a/cassandra_test.go b/cassandra_test.go index ec6969190..482583573 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,7 +32,6 @@ import ( "context" "errors" "fmt" - "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -45,6 +44,8 @@ import ( "time" "unicode" + "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" ) @@ -3303,7 +3304,6 @@ func TestUnsetColBatch(t *testing.T) { } var id, mInt, count int var mText string - if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil { t.Fatalf("Failed to select with err: %v", err) } else if count != 2 { @@ -3338,3 +3338,46 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestQuery_SetHost(t *testing.T) { + // This test ensures that queries are sent to the specified host only + + session := createSession(t) + defer session.Close() + + hosts, err := session.GetHosts() + if err != nil { + t.Fatal(err) + } + + for _, expectedHost := range hosts { + const iterations = 5 + for i := 0; i < iterations; i++ { + var actualHostID string + err := session.Query("SELECT host_id FROM system.local"). + SetHost(expectedHost). + Scan(&actualHostID) + if err != nil { + t.Fatal(err) + } + + if expectedHost.HostID() != actualHostID { + t.Fatalf("Expected query to be executed on host %s, but it was executed on %s", + expectedHost.HostID(), + actualHostID, + ) + } + } + } + + // ensuring that the driver properly handles the case + // when specified host for the query is down + host := hosts[0] + host.state = NodeDown + err = session.Query("SELECT host_id FROM system.local"). + SetHost(host). + Exec() + if !errors.Is(err, ErrNoConnections) { + t.Fatalf("Expected error to be: %v, but got %v", ErrNoConnections, err) + } +} diff --git a/query_executor.go b/query_executor.go index d6be02e53..ca46a5563 100644 --- a/query_executor.go +++ b/query_executor.go @@ -41,6 +41,7 @@ type ExecutableQuery interface { Keyspace() string Table() string IsIdempotent() bool + GetHost() *HostInfo withContext(context.Context) ExecutableQuery @@ -83,12 +84,27 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S } func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { - hostIter := q.policy.Pick(qry) + var hostIter NextHost + + // checking if the host is specified for the query, + // if it is, the query should be executed at the specified host + host := qry.GetHost() + if host != nil { + hostIter = func() SelectedHost { + return (*selectedHost)(host) + } + } + + // if host is not specified for the query, + // then a host will be picked by HostSelectionPolicy + if hostIter == nil { + hostIter = q.policy.Pick(qry) + } // check if the query is not marked as idempotent, if // it is, we force the policy to NonSpeculative sp := qry.speculativeExecutionPolicy() - if !qry.IsIdempotent() || sp.Attempts() == 0 { + if host != nil || !qry.IsIdempotent() || sp.Attempts() == 0 { return q.do(qry.Context(), qry, hostIter), nil } @@ -129,12 +145,17 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { selectedHost := hostIter() rt := qry.retryPolicy() + specifiedHost := qry.GetHost() var lastErr error var iter *Iter for selectedHost != nil { host := selectedHost.Info() - if host == nil || !host.IsUp() { + if specifiedHost != nil && host != nil && !host.IsUp() { + return &Iter{err: ErrNoConnections} + } + + if (host == nil || !host.IsUp()) && specifiedHost == nil { selectedHost = hostIter() continue } @@ -166,7 +187,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne // Exit if the query was successful // or query is not idempotent or no retry policy defined - if iter.err == nil || !qry.IsIdempotent() || rt == nil { + // Also, if there is specified host for the query to be executed on + // and query execution is failed we should exit + if iter.err == nil || specifiedHost != nil || !qry.IsIdempotent() || rt == nil { return iter } diff --git a/session.go b/session.go index d04a13672..f36b84e5a 100644 --- a/session.go +++ b/session.go @@ -943,6 +943,10 @@ type Query struct { // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo + + // host specifies the host on which the query should be executed. + // If it is nil, then the host is picked by HostSelectionPolicy + host *HostInfo } type queryRoutingInfo struct { @@ -1430,6 +1434,18 @@ func (q *Query) releaseAfterExecution() { q.decRefCount() } +// SetHosts allows to define on which host the query should be executed. +// If host == nil, then the HostSelectionPolicy will be used to pick a host. +func (q *Query) SetHost(host *HostInfo) *Query { + q.host = host + return q +} + +// GetHost returns host on which query should be executed. +func (q *Query) GetHost() *HostInfo { + return q.host +} + // Iter represents an iterator that can be used to iterate over all rows that // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. @@ -2045,6 +2061,10 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } +func (b *Batch) GetHost() *HostInfo { + return nil +} + type BatchType byte const ( @@ -2177,6 +2197,15 @@ func (t *traceWriter) Trace(traceId []byte) { } } +// GetHosts returns a list of hosts found via queries to system.local and system.peers +func (s *Session) GetHosts() ([]*HostInfo, error) { + hosts, _, err := s.hostSource.GetHosts() + if err != nil { + return nil, err + } + return hosts, nil +} + type ObservedQuery struct { Keyspace string Statement string