diff --git a/pkg/core/workflow_execute.go b/pkg/core/workflow_execute.go
index 22b1b813f7..55d19dd677 100644
--- a/pkg/core/workflow_execute.go
+++ b/pkg/core/workflow_execute.go
@@ -96,10 +96,10 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
 				firstMatched = true
 			}
 		}
+		if w.Options.HostErrorsCache != nil {
+			w.Options.HostErrorsCache.MarkFailedOrRemove(w.Options.ProtocolType.String(), ctx.Input, err)
+		}
 		if err != nil {
-			if w.Options.HostErrorsCache != nil {
-				w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err)
-			}
 			if len(template.Executers) == 1 {
 				mainErr = err
 			} else {
diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache.go b/pkg/protocols/common/hosterrorscache/hosterrorscache.go
index bca4803e8a..3943eef7e6 100644
--- a/pkg/protocols/common/hosterrorscache/hosterrorscache.go
+++ b/pkg/protocols/common/hosterrorscache/hosterrorscache.go
@@ -1,6 +1,7 @@
 package hosterrorscache
 
 import (
+	"errors"
 	"net"
 	"net/url"
 	"regexp"
@@ -20,10 +21,12 @@ import (
 // CacheInterface defines the signature of the hosterrorscache so that
 // users of Nuclei as embedded lib may implement their own cache
 type CacheInterface interface {
-	SetVerbose(verbose bool)                                           // log verbosely
-	Close()                                                            // close the cache
-	Check(protoType string, ctx *contextargs.Context) bool             // return true if the host should be skipped
-	MarkFailed(protoType string, ctx *contextargs.Context, err error)  // record a failure (and cause) for the host
+	SetVerbose(verbose bool)                                                  // log verbosely
+	Close()                                                                   // close the cache
+	Check(protoType string, ctx *contextargs.Context) bool                    // return true if the host should be skipped
+	Remove(ctx *contextargs.Context)                                          // remove a host from the cache
+	MarkFailed(protoType string, ctx *contextargs.Context, err error)         // record a failure (and cause) for the host
+	MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host or remove it
 }
 
 var (
@@ -47,16 +50,20 @@ type cacheItem struct {
 	errors         atomic.Int32
 	isPermanentErr bool
 	cause          error // optional cause
+	mu             sync.Mutex
 }
 
 const DefaultMaxHostsCount = 10000
 
 // New returns a new host max errors cache
 func New(maxHostError, maxHostsCount int, trackError []string) *Cache {
-	gc := gcache.New[string, *cacheItem](maxHostsCount).
-		ARC().
-		Build()
-	return &Cache{failedTargets: gc, MaxHostError: maxHostError, TrackError: trackError}
+	gc := gcache.New[string, *cacheItem](maxHostsCount).ARC().Build()
+
+	return &Cache{
+		failedTargets: gc,
+		MaxHostError:  maxHostError,
+		TrackError:    trackError,
+	}
 }
 
 // SetVerbose sets the cache to log at verbose level
@@ -118,47 +125,108 @@ func (c *Cache) NormalizeCacheValue(value string) string {
 func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool {
 	finalValue := c.GetKeyFromContext(ctx, nil)
 
-	existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
+	cache, err := c.failedTargets.GetIFPresent(finalValue)
 	if err != nil {
 		return false
 	}
-	if existingCacheItem.isPermanentErr {
+
+	cache.mu.Lock()
+	defer cache.mu.Unlock()
+
+	if cache.isPermanentErr {
 		// skipping permanent errors is expected so verbose instead of info
-		gologger.Verbose().Msgf("Skipped %s from target list as found unresponsive permanently: %s", finalValue, existingCacheItem.cause)
+		gologger.Verbose().Msgf("Skipped %s from target list as found unresponsive permanently: %s", finalValue, cache.cause)
 		return true
 	}
-	if existingCacheItem.errors.Load() >= int32(c.MaxHostError) {
-		existingCacheItem.Do(func() {
-			gologger.Info().Msgf("Skipped %s from target list as found unresponsive %d times", finalValue, existingCacheItem.errors.Load())
+	if cache.errors.Load() >= int32(c.MaxHostError) {
+		cache.Do(func() {
+			gologger.Info().Msgf("Skipped %s from target list as found unresponsive %d times", finalValue, cache.errors.Load())
 		})
 		return true
 	}
+
 	return false
 }
 
+// Remove removes a host from the cache
+func (c *Cache) Remove(ctx *contextargs.Context) {
+	key := c.GetKeyFromContext(ctx, nil)
+	_ = c.failedTargets.Remove(key) // remove even if the cache entry is not present
+}
+
 // MarkFailed marks a host as failed previously
+//
+// Deprecated: Use MarkFailedOrRemove instead.
 func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
-	if !c.checkError(protoType, err) {
+	if err == nil {
 		return
 	}
-	finalValue := c.GetKeyFromContext(ctx, err)
-	existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
-	if err != nil || existingCacheItem == nil {
-		newItem := &cacheItem{errors: atomic.Int32{}}
-		newItem.errors.Store(1)
-		if errkit.IsKind(err, errkit.ErrKindNetworkPermanent) {
-			// skip this address altogether
-			// permanent errors are always permanent hence this is created once
-			// and never updated so no need to synchronize
-			newItem.isPermanentErr = true
-			newItem.cause = err
-		}
-		_ = c.failedTargets.Set(finalValue, newItem)
+
+	c.MarkFailedOrRemove(protoType, ctx, err)
+}
+
+// MarkFailedOrRemove marks a host as failed previously or removes it
+func (c *Cache) MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) {
+	if err != nil && !c.checkError(protoType, err) {
 		return
 	}
-	existingCacheItem.errors.Add(1)
-	_ = c.failedTargets.Set(finalValue, existingCacheItem)
+
+	if err == nil {
+		// Remove the host from cache
+		//
+		// NOTE(dwisiswant0): The decision was made to completely remove the
+		// cached entry for the host instead of simply decrementing the error
+		// count (using `(atomic.Int32).Swap` to update the value to `N-1`).
+		// This approach was chosen because the error handling logic operates
+		// concurrently, and decrementing the count could lead to UB (unexpected
+		// behavior) even when the error is `nil`.
+		//
+		// To clarify, consider the following scenario where the error
+		// encountered does NOT belong to the permanent network error category
+		// (`errkit.ErrKindNetworkPermanent`):
+		//
+		// 1. Iteration 1: A timeout error occurs, and the error count for the
+		//    host is incremented.
+		// 2. Iteration 2: Another timeout error is encountered, leading to
+		//    another increment in the host's error count.
+		// 3. Iteration 3: A third timeout error happens, which increments the
+		//    error count further. At this point, the host is flagged as
+		//    unresponsive.
+		// 4. Iteration 4: The host becomes reachable (no error or a transient
+		//    issue resolved). Instead of performing a no-op and leaving the
+		//    host in the cache, the host entry is removed entirely to reset its
+		//    state.
+		// 5. Iteration 5: A subsequent timeout error occurs after the host was
+		//    removed and re-added to the cache. The error count is reset and
+		//    starts from 1 again.
+		//
+		// This removal strategy ensures the cache is updated dynamically to
+		// reflect the current state of the host without persisting stale or
+		// irrelevant error counts that could interfere with future error
+		// handling and tracking logic.
+		c.Remove(ctx)
+
+		return
+	}
+
+	cacheKey := c.GetKeyFromContext(ctx, err)
+	cache, cacheErr := c.failedTargets.GetIFPresent(cacheKey)
+	if errors.Is(cacheErr, gcache.KeyNotFoundError) {
+		cache = &cacheItem{errors: atomic.Int32{}}
+	}
+
+	cache.mu.Lock()
+	defer cache.mu.Unlock()
+
+	if errkit.IsKind(err, errkit.ErrKindNetworkPermanent) {
+		cache.isPermanentErr = true
+	}
+
+	cache.cause = err
+	cache.errors.Add(1)
+
+	_ = c.failedTargets.Set(cacheKey, cache)
 }
 
 // GetKeyFromContext returns the key for the cache from the context
diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go b/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
index 9977b968d9..e0046ff412 100644
--- a/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
+++ b/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
@@ -2,7 +2,7 @@ package hosterrorscache
 
 import (
 	"context"
-	"fmt"
+	"errors"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -17,28 +17,40 @@ const (
 func TestCacheCheck(t *testing.T) {
 	cache := New(3, DefaultMaxHostsCount, nil)
+	err := errors.New("net/http: timeout awaiting response headers")
+
+	t.Run("increment host error", func(t *testing.T) {
+		ctx := newCtxArgs(t.Name())
+		for i := 1; i < 3; i++ {
+			cache.MarkFailed(protoType, ctx, err)
+			got := cache.Check(protoType, ctx)
+			require.Falsef(t, got, "got %v in iteration %d", got, i)
+		}
+	})
 
-	for i := 0; i < 100; i++ {
-		cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
-		got := cache.Check(protoType, newCtxArgs("test"))
-		if i < 2 {
-			// till 3 the host is not flagged to skip
-			require.False(t, got)
-		} else {
-			// above 3 it must remain flagged to skip
-			require.True(t, got)
+	t.Run("flagged", func(t *testing.T) {
+		ctx := newCtxArgs(t.Name())
+		for i := 1; i <= 3; i++ {
+			cache.MarkFailed(protoType, ctx, err)
 		}
-	}
-	value := cache.Check(protoType, newCtxArgs("test"))
-	require.Equal(t, true, value, "could not get checked value")
+		got := cache.Check(protoType, ctx)
+		require.True(t, got)
+	})
+
+	t.Run("mark failed or remove", func(t *testing.T) {
+		ctx := newCtxArgs(t.Name())
+		cache.MarkFailedOrRemove(protoType, ctx, nil) // nil error should remove the host from cache
+		got := cache.Check(protoType, ctx)
+		require.False(t, got)
+	})
 }
 
 func TestTrackErrors(t *testing.T) {
 	cache := New(3, DefaultMaxHostsCount, []string{"custom error"})
 	for i := 0; i < 100; i++ {
-		cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
+		cache.MarkFailed(protoType, newCtxArgs("custom"), errors.New("got: nested: custom error"))
 		got := cache.Check(protoType, newCtxArgs("custom"))
 		if i < 2 {
 			// till 3 the host is not flagged to skip
@@ -74,6 +86,20 @@ func TestCacheItemDo(t *testing.T) {
 	require.Equal(t, count, 1)
 }
 
+func TestRemove(t *testing.T) {
+	cache := New(3, DefaultMaxHostsCount, nil)
+	ctx := newCtxArgs(t.Name())
+	err := errors.New("net/http: timeout awaiting response headers")
+
+	for i := 0; i < 100; i++ {
+		cache.MarkFailed(protoType, ctx, err)
+	}
+
+	require.True(t, cache.Check(protoType, ctx))
+	cache.Remove(ctx)
+	require.False(t, cache.Check(protoType, ctx))
+}
+
 func TestCacheMarkFailed(t *testing.T) {
 	cache := New(3, DefaultMaxHostsCount, nil)
 
@@ -90,7 +116,7 @@ func TestCacheMarkFailed(t *testing.T) {
 
 	for _, test := range tests {
 		normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
-		cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
+		cache.MarkFailed(protoType, newCtxArgs(test.host), errors.New("no address found for host"))
 		failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
 		require.Nil(t, err)
 		require.NotNil(t, failedTarget)
@@ -126,7 +152,7 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
 			wg.Add(1)
 			go func() {
 				defer wg.Done()
-				cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
+				cache.MarkFailed(protoType, newCtxArgs(currentTest.host), errors.New("net/http: timeout awaiting response headers"))
 			}()
 		}
 	}
@@ -144,6 +170,26 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
 	}
 }
 
+func TestCacheCheckConcurrent(t *testing.T) {
+	cache := New(3, DefaultMaxHostsCount, nil)
+	ctx := newCtxArgs(t.Name())
+
+	wg := sync.WaitGroup{}
+	for i := 1; i <= 100; i++ {
+		wg.Add(1)
+		i := i
+		go func() {
+			defer wg.Done()
+			cache.MarkFailed(protoType, ctx, errors.New("no address found for host"))
+			if i >= 3 {
+				got := cache.Check(protoType, ctx)
+				require.True(t, got)
+			}
+		}()
+	}
+	wg.Wait()
+}
+
 func newCtxArgs(value string) *contextargs.Context {
 	ctx := contextargs.NewWithInput(context.TODO(), value)
 	return ctx
diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go
index 4ce9e57f55..bf5862d276 100644
--- a/pkg/protocols/http/request.go
+++ b/pkg/protocols/http/request.go
@@ -149,11 +149,8 @@ func (request *Request) executeRaceRequest(input *contextargs.Context, previous
 	// look for unresponsive hosts and cancel inflight requests as well
 	spmHandler.SetOnResultCallback(func(err error) {
-		if err == nil {
-			return
-		}
 		// marks thsi host as unresponsive if applicable
-		request.markUnresponsiveAddress(input, err)
+		request.markHostError(input, err)
 		if request.isUnresponsiveAddress(input) {
 			// stop all inflight requests
 			spmHandler.Cancel()
@@ -234,11 +231,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
 	// look for unresponsive hosts and cancel inflight requests as well
 	spmHandler.SetOnResultCallback(func(err error) {
-		if err == nil {
-			return
-		}
 		// marks thsi host as unresponsive if applicable
-		request.markUnresponsiveAddress(input, err)
+		request.markHostError(input, err)
 		if request.isUnresponsiveAddress(input) {
 			// stop all inflight requests
 			spmHandler.Cancel()
@@ -378,11 +372,8 @@ func (request *Request) executeTurboHTTP(input *contextargs.Context, dynamicValu
 	// look for unresponsive hosts and cancel inflight requests as well
 	spmHandler.SetOnResultCallback(func(err error) {
-		if err == nil {
-			return
-		}
 		// marks thsi host as unresponsive if applicable
-		request.markUnresponsiveAddress(input, err)
+		request.markHostError(input, err)
 		if request.isUnresponsiveAddress(input) {
 			// stop all inflight requests
 			spmHandler.Cancel()
@@ -551,12 +542,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
 		}
 		if execReqErr != nil {
 			// if applicable mark the host as unresponsive
-			request.markUnresponsiveAddress(updatedInput, execReqErr)
 			requestErr = errorutil.NewWithErr(execReqErr).Msgf("got err while executing %v", generatedHttpRequest.URL())
 			request.options.Progress.IncrementFailedRequestsBy(1)
 		} else {
 			request.options.Progress.IncrementRequests()
 		}
+		request.markHostError(updatedInput, execReqErr)
 
 		// If this was a match, and we want to stop at first match, skip all further requests.
 		shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch
@@ -1199,13 +1190,10 @@ func (request *Request) newContext(input *contextargs.Context) context.Context {
 	return input.Context()
 }
 
-// markUnresponsiveAddress checks if the error is a unreponsive host error and marks it
-func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err error) {
-	if err == nil {
-		return
-	}
+// markHostError checks if the error is an unresponsive host error and marks it
+func (request *Request) markHostError(input *contextargs.Context, err error) {
 	if request.options.HostErrorsCache != nil {
-		request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
+		request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err)
 	}
 }
diff --git a/pkg/protocols/http/request_fuzz.go b/pkg/protocols/http/request_fuzz.go
index da68300f87..7175a7514b 100644
--- a/pkg/protocols/http/request_fuzz.go
+++ b/pkg/protocols/http/request_fuzz.go
@@ -223,11 +223,11 @@ func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest,
 		return false
 	}
 	if requestErr != nil {
-		if request.options.HostErrorsCache != nil {
-			request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, requestErr)
-		}
 		gologger.Verbose().Msgf("[%s] Error occurred in request: %s\n", request.options.TemplateID, requestErr)
 	}
+	if request.options.HostErrorsCache != nil {
+		request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, requestErr)
+	}
 	request.options.Progress.IncrementRequests()
 
 	// If this was a match, and we want to stop at first match, skip all further requests.
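
Reviewer note: the HTTP call-site changes above all follow one pattern: every request outcome, including a nil error, is now reported to the host errors cache, and MarkFailedOrRemove decides whether to increment the failure count or clear the entry. A minimal sketch of that flow from an embedded-library user's point of view is below; the runProbe helper, the "example.com" target, and the "http" proto-type string are illustrative assumptions, and the import paths assume the nuclei v3 module layout.

```go
package main

import (
	"context"
	"errors"
	"log"

	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache"
)

// runProbe is a hypothetical stand-in for a protocol request; it only exists
// to illustrate the success/failure reporting flow and is not part of this PR.
func runProbe(target string) error {
	return errors.New("net/http: timeout awaiting response headers")
}

func main() {
	// allow up to 3 consecutive failures per host, as in the package tests
	cache := hosterrorscache.New(3, hosterrorscache.DefaultMaxHostsCount, nil)
	defer cache.Close()

	ctx := contextargs.NewWithInput(context.Background(), "example.com")

	for i := 0; i < 5; i++ {
		// skip the host once it has been flagged as unresponsive
		if cache.Check("http", ctx) {
			log.Println("host flagged as unresponsive, skipping")
			break
		}

		err := runProbe("example.com")
		// report every outcome: a non-nil error increments the counter,
		// a nil error removes the host entry and resets its state
		cache.MarkFailedOrRemove("http", ctx, err)
	}
}
```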
diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go
index 3579acd3b8..f7b11fbb53 100644
--- a/pkg/protocols/network/request.go
+++ b/pkg/protocols/network/request.go
@@ -291,9 +291,9 @@ func (request *Request) executeRequestWithPayloads(variables map[string]interfac
 	} else {
 		conn, err = request.dialer.Dial(input.Context(), "tcp", actualAddress)
 	}
+	// adds it to unresponsive address list if applicable
+	request.markHostError(updatedTarget, err)
 	if err != nil {
-		// adds it to unresponsive address list if applicable
-		request.markUnresponsiveAddress(updatedTarget, err)
 		request.options.Output.Request(request.options.TemplatePath, address, request.Type().String(), err)
 		request.options.Progress.IncrementFailedRequestsBy(1)
 		return errors.Wrap(err, "could not connect to server")
@@ -524,13 +524,10 @@ func ConnReadNWithTimeout(conn net.Conn, n int64, timeout time.Duration) ([]byte
 	return b[:count], nil
 }
 
-// markUnresponsiveAddress checks if the error is a unreponsive host error and marks it
-func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err error) {
-	if err == nil {
-		return
-	}
+// markHostError checks if the error is an unresponsive host error and marks it
+func (request *Request) markHostError(input *contextargs.Context, err error) {
 	if request.options.HostErrorsCache != nil {
-		request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
+		request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err)
 	}
 }
diff --git a/pkg/templates/cluster.go b/pkg/templates/cluster.go
index 63b065d346..03ad79c605 100644
--- a/pkg/templates/cluster.go
+++ b/pkg/templates/cluster.go
@@ -273,8 +273,8 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
 			}
 		}
 	})
-	if err != nil && e.options.HostErrorsCache != nil {
-		e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
+	if e.options.HostErrorsCache != nil {
+		e.options.HostErrorsCache.MarkFailedOrRemove(e.options.ProtocolType.String(), ctx.Input, err)
 	}
 	return results, err
 }
@@ -309,8 +309,8 @@ func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.R
 		ctx.LogError(err)
 	}
 
-	if err != nil && e.options.HostErrorsCache != nil {
-		e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
+	if e.options.HostErrorsCache != nil {
+		e.options.HostErrorsCache.MarkFailedOrRemove(e.options.ProtocolType.String(), ctx.Input, err)
 	}
 	return scanCtx.GenerateResult(), err
 }
diff --git a/pkg/tmplexec/generic/exec.go b/pkg/tmplexec/generic/exec.go
index c8303f70d9..c017810e75 100644
--- a/pkg/tmplexec/generic/exec.go
+++ b/pkg/tmplexec/generic/exec.go
@@ -84,11 +84,11 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error {
 		})
 		if err != nil {
 			ctx.LogError(err)
-			if g.options.HostErrorsCache != nil {
-				g.options.HostErrorsCache.MarkFailed(g.options.ProtocolType.String(), ctx.Input, err)
-			}
 			gologger.Warning().Msgf("[%s] Could not execute request for %s: %s\n", g.options.TemplateID, ctx.Input.MetaInput.PrettyPrint(), err)
 		}
+		if g.options.HostErrorsCache != nil {
+			g.options.HostErrorsCache.MarkFailedOrRemove(g.options.ProtocolType.String(), ctx.Input, err)
+		}
 		// If a match was found and stop at first match is set, break out of the loop and return
 		if g.results.Load() && (g.options.StopAtFirstMatch || g.options.Options.StopAtFirstMatch) {
 			break
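
Note for embedders: because CacheInterface gains Remove and MarkFailedOrRemove, any custom cache supplied to Nuclei as an embedded library must now implement both methods or it will stop compiling against this change. A rough sketch of a conforming implementation follows; it assumes the nuclei v3 import paths and keys entries on ctx.MetaInput.Input (the real cache derives its key via GetKeyFromContext, which also normalizes host and port). SimpleCache and NewSimpleCache are hypothetical names used only for illustration.

```go
package mycache

import (
	"sync"

	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache"
)

// compile-time check that SimpleCache satisfies the extended interface
var _ hosterrorscache.CacheInterface = (*SimpleCache)(nil)

// SimpleCache counts consecutive failures per input and clears them on success.
type SimpleCache struct {
	mu       sync.Mutex
	failures map[string]int
	maxError int
}

func NewSimpleCache(maxError int) *SimpleCache {
	return &SimpleCache{failures: make(map[string]int), maxError: maxError}
}

func (c *SimpleCache) SetVerbose(verbose bool) {}
func (c *SimpleCache) Close()                  {}

// Check reports whether the host has reached the failure threshold.
func (c *SimpleCache) Check(protoType string, ctx *contextargs.Context) bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.failures[ctx.MetaInput.Input] >= c.maxError
}

// Remove drops any recorded failures for the host.
func (c *SimpleCache) Remove(ctx *contextargs.Context) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.failures, ctx.MetaInput.Input)
}

// MarkFailed mirrors the deprecated behavior by delegating to MarkFailedOrRemove.
func (c *SimpleCache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
	if err == nil {
		return
	}
	c.MarkFailedOrRemove(protoType, ctx, err)
}

// MarkFailedOrRemove increments the failure count on error and resets it on success.
func (c *SimpleCache) MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) {
	if err == nil {
		c.Remove(ctx)
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.failures[ctx.MetaInput.Input]++
}
```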