diff --git a/cassandra_test.go b/cassandra_test.go index 3b0c61053..6850d26c1 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -3303,7 +3303,6 @@ func TestUnsetColBatch(t *testing.T) { } var id, mInt, count int var mText string - if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil { t.Fatalf("Failed to select with err: %v", err) } else if count != 2 { @@ -3338,3 +3337,35 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestQuery_SetHost(t *testing.T) { + // This test ensures that queries are sent to the specified host only + + session := createSession(t) + defer session.Close() + + hosts, err := session.GetHosts() + if err != nil { + t.Fatal(err) + } + + for _, expectedHost := range hosts { + const iterations = 5 + for i := 0; i < iterations; i++ { + var actualHostID string + err := session.Query("SELECT host_id FROM system.local"). + SetHost(expectedHost). + Scan(&actualHostID) + if err != nil { + t.Fatal(err) + } + + if expectedHost.HostID() != actualHostID { + t.Fatalf("Expected query to be executed on host %s, but it was executed on %s", + expectedHost.HostID(), + actualHostID, + ) + } + } + } +} diff --git a/query_executor.go b/query_executor.go index 03687361a..c623fc9dd 100644 --- a/query_executor.go +++ b/query_executor.go @@ -41,6 +41,7 @@ type ExecutableQuery interface { Keyspace() string Table() string IsIdempotent() bool + GetHost() *HostInfo withContext(context.Context) ExecutableQuery @@ -83,12 +84,27 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S } func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { - hostIter := q.policy.Pick(qry) + var hostIter NextHost + + // checking if the host is specified for the query, + // if it is, the query should be executed at the specified host + host := qry.GetHost() + if host != nil { + hostIter = func() SelectedHost { + return (*selectedHost)(host) + } + } + + // if host is not specified for the query, + // then a host will be picked by HostSelectionPolicy + if hostIter == nil { + hostIter = q.policy.Pick(qry) + } // check if the query is not marked as idempotent, if // it is, we force the policy to NonSpeculative sp := qry.speculativeExecutionPolicy() - if !qry.IsIdempotent() || sp.Attempts() == 0 { + if host != nil || !qry.IsIdempotent() || sp.Attempts() == 0 { return q.do(qry.Context(), qry, hostIter), nil } @@ -129,12 +145,13 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { selectedHost := hostIter() rt := qry.retryPolicy() + specifiedHost := qry.GetHost() var lastErr error var iter *Iter for selectedHost != nil { host := selectedHost.Info() - if host == nil || !host.IsUp() { + if (host == nil || !host.IsUp()) && specifiedHost == nil { selectedHost = hostIter() continue } @@ -166,7 +183,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne // Exit if the query was successful // or no retry policy defined or retry attempts were reached - if iter.err == nil || !qry.IsIdempotent() || rt == nil || !rt.Attempt(qry) { + // Also, if there is specified host for the query to be executed on + // and query execution is failed we should exit + if iter.err == nil || specifiedHost != nil || !qry.IsIdempotent() || rt == nil || !rt.Attempt(qry) { return iter } lastErr = iter.err diff --git a/session.go b/session.go index b884735c2..29ba2d11e 100644 --- a/session.go +++ b/session.go @@ -936,6 +936,10 @@ type Query struct { // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo + + // host specifies the host on which the query should be executed. + // If it is nil, then the host is picked by HostSelectionPolicy + host *HostInfo } type queryRoutingInfo struct { @@ -1423,6 +1427,18 @@ func (q *Query) releaseAfterExecution() { q.decRefCount() } +// SetHosts allows to define on which host the query should be executed. +// If host == nil, then the HostSelectionPolicy will be used to pick a host. +func (q *Query) SetHost(host *HostInfo) *Query { + q.host = host + return q +} + +// GetHost returns host on which query should be executed. +func (q *Query) GetHost() *HostInfo { + return q.host +} + // Iter represents an iterator that can be used to iterate over all rows that // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. @@ -2030,6 +2046,10 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } +func (b *Batch) GetHost() *HostInfo { + return nil +} + type BatchType byte const ( @@ -2162,6 +2182,15 @@ func (t *traceWriter) Trace(traceId []byte) { } } +// GetHosts returns a list of hosts found via queries to system.local and system.peers +func (s *Session) GetHosts() ([]*HostInfo, error) { + hosts, _, err := s.hostSource.GetHosts() + if err != nil { + return nil, err + } + return hosts, nil +} + type ObservedQuery struct { Keyspace string Statement string