From 647dacca770d33913c1b0695b3b71ba261bd341d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Senart?= Date: Fri, 2 Jul 2021 19:13:21 +0200 Subject: [PATCH] lib,cmd: add -connect-to flag Closes #692, #691, #575 Co-authored-by: dank@kegel.com Co-authored-by: Antonio M. Amaya --- README.md | 5 +++- attack.go | 3 +++ flags.go | 52 +++++++++++++++++++++++++++++++++++++ lib/attack.go | 42 ++++++++++++++++++++++++++++++ lib/attack_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 165 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4a952147..81e5a8dd 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,10 @@ attack command: TLS client PEM encoded certificate file -chunked Send body with chunked transfer encoding + -connect-to value + A mapping of (ip|host):port to use instead of a target URL's (ip|host):port. Can be repeated multiple times. + Identical src:port with different dst:port will round-robin over the different dst:port pairs. + Example: google.com:80:localhost:6060 -connections int Max open idle connections per target host (default 10000) -dns-ttl value @@ -178,7 +182,6 @@ examples: vegeta report -type=json results.bin > metrics.json cat results.bin | vegeta plot > plot.html cat results.bin | vegeta report -type="hist[0,100ms,200ms,300ms]" - ``` #### `-cpus` diff --git a/attack.go b/attack.go index f83c2cba..061285b5 100644 --- a/attack.go +++ b/attack.go @@ -62,6 +62,7 @@ func attackCmd() command { fs.StringVar(&opts.promAddr, "prometheus-addr", "", "Prometheus exporter listen address [empty = disabled]. Example: 0.0.0.0:8880") fs.Var(&dnsTTLFlag{&opts.dnsTTL}, "dns-ttl", "Cache DNS lookups for the given duration [-1 = disabled, 0 = forever]") fs.BoolVar(&opts.sessionTickets, "session-tickets", false, "Enable TLS session resumption using session tickets") + fs.Var(&connectToFlag{&opts.connectTo}, "connect-to", "A mapping of (ip|host):port to use instead of a target URL's (ip|host):port. Can be repeated multiple times.\nIdentical src:port with different dst:port will round-robin over the different dst:port pairs.\nExample: google.com:80:localhost:6060") systemSpecificFlags(fs, opts) return command{fs, func(args []string) error { @@ -108,6 +109,7 @@ type attackOpts struct { promAddr string dnsTTL time.Duration sessionTickets bool + connectTo map[string][]string } // attack validates the attack arguments, sets up the @@ -218,6 +220,7 @@ func attack(opts *attackOpts) (err error) { vegeta.ProxyHeader(proxyHdr), vegeta.ChunkedBody(opts.chunked), vegeta.DNSCaching(opts.dnsTTL), + vegeta.ConnectTo(opts.connectTo), vegeta.SessionTickets(opts.sessionTickets), ) diff --git a/flags.go b/flags.go index 7ac8fcb1..cbd47bbc 100644 --- a/flags.go +++ b/flags.go @@ -6,6 +6,7 @@ import ( "math" "net" "net/http" + "sort" "strconv" "strings" "time" @@ -153,3 +154,54 @@ func (f *dnsTTLFlag) String() string { } return f.ttl.String() } + +const connectToFormat = "src:port:dst:port" + +type connectToFlag struct { + addrMap *map[string][]string +} + +func (c *connectToFlag) String() string { + if c.addrMap == nil { + return "" + } + + addrMappings := make([]string, 0, len(*c.addrMap)) + for k, v := range *c.addrMap { + addrMappings = append(addrMappings, k+":"+strings.Join(v, ",")) + } + + sort.Strings(addrMappings) + return strings.Join(addrMappings, ";") +} + +func (c *connectToFlag) Set(s string) error { + if c.addrMap == nil { + return nil + } + + if *c.addrMap == nil { + *c.addrMap = make(map[string][]string) + } + + parts := strings.Split(s, ":") + if len(parts) != 4 { + return fmt.Errorf("invalid -connect-to %q, expected format: %s", s, connectToFormat) + } + srcAddr := parts[0] + ":" + parts[1] + dstAddr := parts[2] + ":" + parts[3] + + // Parse source address + if _, _, err := net.SplitHostPort(srcAddr); err != nil { + return fmt.Errorf("invalid source address expression [%s], expected address:port", srcAddr) + } + + // Parse destination address + if _, _, err := net.SplitHostPort(dstAddr); err != nil { + return fmt.Errorf("invalid destination address expression [%s], expected address:port", dstAddr) + } + + (*c.addrMap)[srcAddr] = append((*c.addrMap)[srcAddr], dstAddr) + + return nil +} diff --git a/lib/attack.go b/lib/attack.go index 0e43e91f..d2f1cf5a 100644 --- a/lib/attack.go +++ b/lib/attack.go @@ -28,6 +28,9 @@ type Attacker struct { maxWorkers uint64 maxBody int64 redirects int + seqmu sync.Mutex + seq uint64 + began time.Time chunked bool } @@ -272,6 +275,45 @@ func ProxyHeader(h http.Header) func(*Attacker) { } } +// ConnectTo returns a functional option which makes the attacker use the +// passed in map to translate target addr:port pairs. When used with DNSCaching, +// it must be used after it. +func ConnectTo(addrMap map[string][]string) func(*Attacker) { + return func(a *Attacker) { + if len(addrMap) == 0 { + return + } + + tr, ok := a.client.Transport.(*http.Transport) + if !ok { + return + } + + dial := tr.DialContext + if dial == nil { + dial = a.dialer.DialContext + } + + type roundRobin struct { + addrs []string + n int + } + + connectTo := make(map[string]*roundRobin, len(addrMap)) + for k, v := range addrMap { + connectTo[k] = &roundRobin{addrs: v} + } + + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + if cm, ok := connectTo[addr]; ok { + cm.n = (cm.n + 1) % len(cm.addrs) + addr = cm.addrs[cm.n] + } + return dial(ctx, network, addr) + } + } +} + // DNSCaching returns a functional option that enables DNS caching for // the given ttl. When ttl is zero cached entries will never expire. // When ttl is non-zero, this will start a refresh go-routine that updates diff --git a/lib/attack_test.go b/lib/attack_test.go index 4ae0145f..e847050b 100644 --- a/lib/attack_test.go +++ b/lib/attack_test.go @@ -15,6 +15,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -498,3 +499,66 @@ func TestFirstOfEachIPFamily(t *testing.T) { }) } } + +func TestAttackConnectTo(t *testing.T) { + t.Parallel() + var mu sync.Mutex + hits := make(map[string]int) + srvs := make(map[string]int) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + hits[r.Host]++ + mu.Unlock() + }) + + addrs := make([]string, 3) + for i := range addrs { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addrs[i] = ln.Addr().String() + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + srvs[ln.Addr().String()]++ + mu.Unlock() + handler.ServeHTTP(w, r) + })) + + srv.Listener = ln + srv.Start() + t.Cleanup(srv.Close) + } + + tr := NewStaticTargeter( + Target{Method: "GET", URL: "http://sapo.pt:80"}, + Target{Method: "GET", URL: "http://sapo.pt:80"}, + Target{Method: "GET", URL: "http://sapo.pt:80"}, + Target{Method: "GET", URL: "http://" + addrs[0]}, + ) + + atk := NewAttacker( + KeepAlive(false), + ConnectTo(map[string][]string{"sapo.pt:80": addrs}), + ) + + a := &attack{name: "TEST", began: time.Now()} + for i := 0; i < 4; i++ { + resp := atk.hit(tr, a) + if resp.Error != "" { + t.Fatal(resp.Error) + } + } + + want := map[string]int{"sapo.pt:80": 3, addrs[0]: 1} + if diff := cmp.Diff(want, hits); diff != "" { + t.Errorf("unexpected hits (-want +got):\n%s", diff) + } + + want = map[string]int{addrs[0]: 2, addrs[1]: 1, addrs[2]: 1} + if diff := cmp.Diff(want, srvs); diff != "" { + t.Errorf("unexpected hits (-want +got):\n%s", diff) + } +}