From e102a91536936c6a17c1fd30c8b006ba7751558c Mon Sep 17 00:00:00 2001 From: Max Riveiro Date: Fri, 13 Jan 2017 00:17:45 +0300 Subject: [PATCH] Handle TCP, UDP and wildcard addrs Signed-off-by: Max Riveiro --- tcp.go | 34 +++++++++++++++++++++----------- tcp_test.go | 57 ++++++++++++++++++++++------------------------------- udp.go | 34 +++++++++++++++++++++----------- udp_test.go | 24 +++++++++++++++++++--- 4 files changed, 89 insertions(+), 60 deletions(-) diff --git a/tcp.go b/tcp.go index 9ab7cd0..4960d62 100644 --- a/tcp.go +++ b/tcp.go @@ -19,11 +19,7 @@ var ( ) func getTCPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) { - var ( - addr4 [4]byte - addr6 [16]byte - tcp *net.TCPAddr - ) + var tcp *net.TCPAddr tcp, err = net.ResolveTCPAddr(proto, addr) if err != nil && tcp.IP != nil { @@ -36,24 +32,33 @@ func getTCPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err er } switch tcpVersion { + case "tcp": + return &syscall.SockaddrInet4{Port: tcp.Port}, syscall.AF_INET, nil case "tcp4": - copy(addr4[:], tcp.IP[12:16]) // copy last 4 bytes of slice to array + sa := &syscall.SockaddrInet4{Port: tcp.Port} - return &syscall.SockaddrInet4{Port: tcp.Port, Addr: addr4}, syscall.AF_INET, nil + if tcp.IP != nil { + copy(sa.Addr[:], tcp.IP[12:16]) // copy last 4 bytes of slice to array + } + return sa, syscall.AF_INET, nil case "tcp6": - copy(addr6[:], tcp.IP) // copy all bytes of slice to array + sa := &syscall.SockaddrInet6{Port: tcp.Port} + + if tcp.IP != nil { + copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array + } - return &syscall.SockaddrInet6{Port: tcp.Port, Addr: addr6}, syscall.AF_INET6, nil + return sa, syscall.AF_INET6, nil } return nil, -1, errUnsupportedProtocol } func determineTCPProto(proto string, ip *net.TCPAddr) (string, error) { - // If the protocol is set to "tcp", we determine the actual protocol - // version from the size of the IP address. Otherwise, we use the - // protcol given to us by the caller. + // If the protocol is set to "tcp", we try to determine the actual protocol + // version from the size of the resolved IP address. Otherwise, we simple use + // the protcol given to us by the caller. if ip.IP.To4() != nil { return "tcp4", nil @@ -63,6 +68,11 @@ func determineTCPProto(proto string, ip *net.TCPAddr) (string, error) { return "tcp6", nil } + switch proto { + case "tcp", "tcp4", "tcp6": + return proto, nil + } + return "", errUnsupportedTCPProtocol } diff --git a/tcp_test.go b/tcp_test.go index 6521ca1..79f417f 100644 --- a/tcp_test.go +++ b/tcp_test.go @@ -17,15 +17,13 @@ import ( ) const ( - httpServerOneResponse = "1" - httpServerTwoResponse = "2" - httpServerThreeResponse = "3" + httpServerOneResponse = "1" + httpServerTwoResponse = "2" ) var ( - httpServerOne = NewHTTPServer(httpServerOneResponse) - httpServerTwo = NewHTTPServer(httpServerTwoResponse) - httpServerThree = NewHTTPServer(httpServerThreeResponse) + httpServerOne = NewHTTPServer(httpServerOneResponse) + httpServerTwo = NewHTTPServer(httpServerTwoResponse) ) func NewHTTPServer(resp string) *httptest.Server { @@ -51,6 +49,24 @@ func TestNewReusablePortListener(t *testing.T) { t.Error(err) } defer listenerThree.Close() + + listenerFour, err := NewReusablePortListener("tcp6", ":10081") + if err != nil { + t.Error(err) + } + defer listenerFour.Close() + + listenerFive, err := NewReusablePortListener("tcp4", ":10081") + if err != nil { + t.Error(err) + } + defer listenerFive.Close() + + listenerSix, err := NewReusablePortListener("tcp", ":10081") + if err != nil { + t.Error(err) + } + defer listenerSix.Close() } func TestNewReusablePortServers(t *testing.T) { @@ -60,21 +76,14 @@ func TestNewReusablePortServers(t *testing.T) { } defer listenerOne.Close() - listenerTwo, err := NewReusablePortListener("tcp", "127.0.0.1:10081") + listenerTwo, err := NewReusablePortListener("tcp6", ":10081") if err != nil { t.Error(err) } defer listenerTwo.Close() - // listenerThree, err := NewReusablePortListener("tcp6", "[::1]:10081") - // if err != nil { - // t.Error(err) - // } - // defer listenerThree.Close() - httpServerOne.Listener = listenerOne httpServerTwo.Listener = listenerTwo - // httpServerThree.Listener = listenerThree httpServerOne.Start() httpServerTwo.Start() @@ -123,24 +132,6 @@ func TestNewReusablePortServers(t *testing.T) { t.Errorf("Expected %#v, got %#v.", httpServerOneResponse, string(body3)) } - httpServerThree.Start() - - // Server Three — First Response - resp4, err := http.Get(httpServerThree.URL) - if err != nil { - t.Error(err) - } - body4, err := ioutil.ReadAll(resp4.Body) - resp1.Body.Close() - if err != nil { - t.Error(err) - } - if string(body4) != httpServerThreeResponse { - t.Errorf("Expected %#v, got %#v.", httpServerThreeResponse, string(body4)) - } - - httpServerThree.Close() - // Server One — Third Response resp5, err := http.Get(httpServerOne.URL) if err != nil { @@ -160,7 +151,7 @@ func TestNewReusablePortServers(t *testing.T) { func BenchmarkNewReusablePortListener(b *testing.B) { for i := 0; i < b.N; i++ { - listener, err := NewReusablePortListener("tcp4", "localhost:10081") + listener, err := NewReusablePortListener("tcp", ":10081") if err != nil { b.Error(err) diff --git a/udp.go b/udp.go index 27590c6..d40672c 100644 --- a/udp.go +++ b/udp.go @@ -16,11 +16,7 @@ import ( var errUnsupportedUDPProtocol = errors.New("only udp, udp4, udp6 are supported") func getUDPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) { - var ( - addr4 [4]byte - addr6 [16]byte - udp *net.UDPAddr - ) + var udp *net.UDPAddr udp, err = net.ResolveUDPAddr(proto, addr) if err != nil && udp.IP != nil { @@ -33,24 +29,33 @@ func getUDPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err er } switch udpVersion { + case "udp": + return &syscall.SockaddrInet4{Port: udp.Port}, syscall.AF_INET, nil case "udp4": - copy(addr4[:], udp.IP[12:16]) // copy last 4 bytes of slice to array + sa := &syscall.SockaddrInet4{Port: udp.Port} - return &syscall.SockaddrInet4{Port: udp.Port, Addr: addr4}, syscall.AF_INET, nil + if udp.IP != nil { + copy(sa.Addr[:], udp.IP[12:16]) // copy last 4 bytes of slice to array + } + return sa, syscall.AF_INET, nil case "udp6": - copy(addr6[:], udp.IP) // copy all bytes of slice to array + sa := &syscall.SockaddrInet6{Port: udp.Port} + + if udp.IP != nil { + copy(sa.Addr[:], udp.IP) // copy all bytes of slice to array + } - return &syscall.SockaddrInet6{Port: udp.Port, Addr: addr6}, syscall.AF_INET6, nil + return sa, syscall.AF_INET6, nil } return nil, -1, errUnsupportedProtocol } func determineUDPProto(proto string, ip *net.UDPAddr) (string, error) { - // If the protocol is set to "udp", we determine the actual protocol - // version from the size of the IP address. Otherwise, we use the - // protcol given to us by the caller. + // If the protocol is set to "udp", we try to determine the actual protocol + // version from the size of the resolved IP address. Otherwise, we simple use + // the protcol given to us by the caller. if ip.IP.To4() != nil { return "udp4", nil @@ -60,6 +65,11 @@ func determineUDPProto(proto string, ip *net.UDPAddr) (string, error) { return "udp6", nil } + switch proto { + case "udp", "udp4", "udp6": + return proto, nil + } + return "", errUnsupportedUDPProtocol } diff --git a/udp_test.go b/udp_test.go index 5cc3510..007e324 100644 --- a/udp_test.go +++ b/udp_test.go @@ -9,23 +9,41 @@ package reuseport import "testing" func TestNewReusablePortUDPListener(t *testing.T) { - listenerOne, err := NewReusablePortPacketConn("udp4", "localhost:10081") + listenerOne, err := NewReusablePortPacketConn("udp4", "localhost:10082") if err != nil { t.Error(err) } defer listenerOne.Close() - listenerTwo, err := NewReusablePortPacketConn("udp", "127.0.0.1:10081") + listenerTwo, err := NewReusablePortPacketConn("udp", "127.0.0.1:10082") if err != nil { t.Error(err) } defer listenerTwo.Close() - listenerThree, err := NewReusablePortPacketConn("udp6", "[::1]:10081") + listenerThree, err := NewReusablePortPacketConn("udp6", "[::1]:10082") if err != nil { t.Error(err) } defer listenerThree.Close() + + listenerFour, err := NewReusablePortListener("udp6", ":10081") + if err != nil { + t.Error(err) + } + defer listenerFour.Close() + + listenerFive, err := NewReusablePortListener("udp4", ":10081") + if err != nil { + t.Error(err) + } + defer listenerFive.Close() + + listenerSix, err := NewReusablePortListener("udp", ":10081") + if err != nil { + t.Error(err) + } + defer listenerSix.Close() } func BenchmarkNewReusableUDPPortListener(b *testing.B) {