Skip to content

Commit

Permalink
Handle TCP, UDP and wildcard addrs
Browse files Browse the repository at this point in the history
Signed-off-by: Max Riveiro <[email protected]>
  • Loading branch information
kavu committed Jan 14, 2017
1 parent 4b554a6 commit e102a91
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 60 deletions.
34 changes: 22 additions & 12 deletions tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ var (
)

func getTCPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) {
var (
addr4 [4]byte
addr6 [16]byte
tcp *net.TCPAddr
)
var tcp *net.TCPAddr

tcp, err = net.ResolveTCPAddr(proto, addr)
if err != nil && tcp.IP != nil {
Expand All @@ -36,24 +32,33 @@ func getTCPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err er
}

switch tcpVersion {
case "tcp":
return &syscall.SockaddrInet4{Port: tcp.Port}, syscall.AF_INET, nil
case "tcp4":
copy(addr4[:], tcp.IP[12:16]) // copy last 4 bytes of slice to array
sa := &syscall.SockaddrInet4{Port: tcp.Port}

return &syscall.SockaddrInet4{Port: tcp.Port, Addr: addr4}, syscall.AF_INET, nil
if tcp.IP != nil {
copy(sa.Addr[:], tcp.IP[12:16]) // copy last 4 bytes of slice to array
}

return sa, syscall.AF_INET, nil
case "tcp6":
copy(addr6[:], tcp.IP) // copy all bytes of slice to array
sa := &syscall.SockaddrInet6{Port: tcp.Port}

if tcp.IP != nil {
copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array
}

return &syscall.SockaddrInet6{Port: tcp.Port, Addr: addr6}, syscall.AF_INET6, nil
return sa, syscall.AF_INET6, nil
}

return nil, -1, errUnsupportedProtocol
}

func determineTCPProto(proto string, ip *net.TCPAddr) (string, error) {
// If the protocol is set to "tcp", we determine the actual protocol
// version from the size of the IP address. Otherwise, we use the
// protcol given to us by the caller.
// If the protocol is set to "tcp", we try to determine the actual protocol
// version from the size of the resolved IP address. Otherwise, we simple use
// the protcol given to us by the caller.

if ip.IP.To4() != nil {
return "tcp4", nil
Expand All @@ -63,6 +68,11 @@ func determineTCPProto(proto string, ip *net.TCPAddr) (string, error) {
return "tcp6", nil
}

switch proto {
case "tcp", "tcp4", "tcp6":
return proto, nil
}

return "", errUnsupportedTCPProtocol
}

Expand Down
57 changes: 24 additions & 33 deletions tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ import (
)

const (
httpServerOneResponse = "1"
httpServerTwoResponse = "2"
httpServerThreeResponse = "3"
httpServerOneResponse = "1"
httpServerTwoResponse = "2"
)

var (
httpServerOne = NewHTTPServer(httpServerOneResponse)
httpServerTwo = NewHTTPServer(httpServerTwoResponse)
httpServerThree = NewHTTPServer(httpServerThreeResponse)
httpServerOne = NewHTTPServer(httpServerOneResponse)
httpServerTwo = NewHTTPServer(httpServerTwoResponse)
)

func NewHTTPServer(resp string) *httptest.Server {
Expand All @@ -51,6 +49,24 @@ func TestNewReusablePortListener(t *testing.T) {
t.Error(err)
}
defer listenerThree.Close()

listenerFour, err := NewReusablePortListener("tcp6", ":10081")
if err != nil {
t.Error(err)
}
defer listenerFour.Close()

listenerFive, err := NewReusablePortListener("tcp4", ":10081")
if err != nil {
t.Error(err)
}
defer listenerFive.Close()

listenerSix, err := NewReusablePortListener("tcp", ":10081")
if err != nil {
t.Error(err)
}
defer listenerSix.Close()
}

func TestNewReusablePortServers(t *testing.T) {
Expand All @@ -60,21 +76,14 @@ func TestNewReusablePortServers(t *testing.T) {
}
defer listenerOne.Close()

listenerTwo, err := NewReusablePortListener("tcp", "127.0.0.1:10081")
listenerTwo, err := NewReusablePortListener("tcp6", ":10081")
if err != nil {
t.Error(err)
}
defer listenerTwo.Close()

// listenerThree, err := NewReusablePortListener("tcp6", "[::1]:10081")
// if err != nil {
// t.Error(err)
// }
// defer listenerThree.Close()

httpServerOne.Listener = listenerOne
httpServerTwo.Listener = listenerTwo
// httpServerThree.Listener = listenerThree

httpServerOne.Start()
httpServerTwo.Start()
Expand Down Expand Up @@ -123,24 +132,6 @@ func TestNewReusablePortServers(t *testing.T) {
t.Errorf("Expected %#v, got %#v.", httpServerOneResponse, string(body3))
}

httpServerThree.Start()

// Server Three — First Response
resp4, err := http.Get(httpServerThree.URL)
if err != nil {
t.Error(err)
}
body4, err := ioutil.ReadAll(resp4.Body)
resp1.Body.Close()
if err != nil {
t.Error(err)
}
if string(body4) != httpServerThreeResponse {
t.Errorf("Expected %#v, got %#v.", httpServerThreeResponse, string(body4))
}

httpServerThree.Close()

// Server One — Third Response
resp5, err := http.Get(httpServerOne.URL)
if err != nil {
Expand All @@ -160,7 +151,7 @@ func TestNewReusablePortServers(t *testing.T) {

func BenchmarkNewReusablePortListener(b *testing.B) {
for i := 0; i < b.N; i++ {
listener, err := NewReusablePortListener("tcp4", "localhost:10081")
listener, err := NewReusablePortListener("tcp", ":10081")

if err != nil {
b.Error(err)
Expand Down
34 changes: 22 additions & 12 deletions udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ import (
var errUnsupportedUDPProtocol = errors.New("only udp, udp4, udp6 are supported")

func getUDPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) {
var (
addr4 [4]byte
addr6 [16]byte
udp *net.UDPAddr
)
var udp *net.UDPAddr

udp, err = net.ResolveUDPAddr(proto, addr)
if err != nil && udp.IP != nil {
Expand All @@ -33,24 +29,33 @@ func getUDPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err er
}

switch udpVersion {
case "udp":
return &syscall.SockaddrInet4{Port: udp.Port}, syscall.AF_INET, nil
case "udp4":
copy(addr4[:], udp.IP[12:16]) // copy last 4 bytes of slice to array
sa := &syscall.SockaddrInet4{Port: udp.Port}

return &syscall.SockaddrInet4{Port: udp.Port, Addr: addr4}, syscall.AF_INET, nil
if udp.IP != nil {
copy(sa.Addr[:], udp.IP[12:16]) // copy last 4 bytes of slice to array
}

return sa, syscall.AF_INET, nil
case "udp6":
copy(addr6[:], udp.IP) // copy all bytes of slice to array
sa := &syscall.SockaddrInet6{Port: udp.Port}

if udp.IP != nil {
copy(sa.Addr[:], udp.IP) // copy all bytes of slice to array
}

return &syscall.SockaddrInet6{Port: udp.Port, Addr: addr6}, syscall.AF_INET6, nil
return sa, syscall.AF_INET6, nil
}

return nil, -1, errUnsupportedProtocol
}

func determineUDPProto(proto string, ip *net.UDPAddr) (string, error) {
// If the protocol is set to "udp", we determine the actual protocol
// version from the size of the IP address. Otherwise, we use the
// protcol given to us by the caller.
// If the protocol is set to "udp", we try to determine the actual protocol
// version from the size of the resolved IP address. Otherwise, we simple use
// the protcol given to us by the caller.

if ip.IP.To4() != nil {
return "udp4", nil
Expand All @@ -60,6 +65,11 @@ func determineUDPProto(proto string, ip *net.UDPAddr) (string, error) {
return "udp6", nil
}

switch proto {
case "udp", "udp4", "udp6":
return proto, nil
}

return "", errUnsupportedUDPProtocol
}

Expand Down
24 changes: 21 additions & 3 deletions udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,41 @@ package reuseport
import "testing"

func TestNewReusablePortUDPListener(t *testing.T) {
listenerOne, err := NewReusablePortPacketConn("udp4", "localhost:10081")
listenerOne, err := NewReusablePortPacketConn("udp4", "localhost:10082")
if err != nil {
t.Error(err)
}
defer listenerOne.Close()

listenerTwo, err := NewReusablePortPacketConn("udp", "127.0.0.1:10081")
listenerTwo, err := NewReusablePortPacketConn("udp", "127.0.0.1:10082")
if err != nil {
t.Error(err)
}
defer listenerTwo.Close()

listenerThree, err := NewReusablePortPacketConn("udp6", "[::1]:10081")
listenerThree, err := NewReusablePortPacketConn("udp6", "[::1]:10082")
if err != nil {
t.Error(err)
}
defer listenerThree.Close()

listenerFour, err := NewReusablePortListener("udp6", ":10081")
if err != nil {
t.Error(err)
}
defer listenerFour.Close()

listenerFive, err := NewReusablePortListener("udp4", ":10081")
if err != nil {
t.Error(err)
}
defer listenerFive.Close()

listenerSix, err := NewReusablePortListener("udp", ":10081")
if err != nil {
t.Error(err)
}
defer listenerSix.Close()
}

func BenchmarkNewReusableUDPPortListener(b *testing.B) {
Expand Down

0 comments on commit e102a91

Please sign in to comment.