Skip to content

Commit

Permalink
Support for factory method for outbound connections (#759)
Browse files Browse the repository at this point in the history
New optional factory method on ChannelOptions to allow passing in custom dialer
for outbound connections. This could be used for things like TLS handshake
considering tchannel already has support for custom TLS listener.
  • Loading branch information
samarabbas authored and prashantv committed Sep 30, 2019
1 parent bff57bb commit 74b0fff
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
16 changes: 15 additions & 1 deletion channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ type ChannelOptions struct {
// Handler is an alternate handler for all inbound requests, overriding the
// default handler that delegates to a subchannel.
Handler Handler

// Dialer is optional factory method which can be used for overriding
// outbound connections for things like TLS handshake
Dialer func(ctx context.Context, network, hostPort string) (net.Conn, error)
}

// ChannelState is the state of a channel.
Expand Down Expand Up @@ -158,6 +162,7 @@ type Channel struct {
internalHandlers *handlerMap
handler Handler
onPeerStatusChanged func(*Peer)
dialer func(ctx context.Context, hostPort string) (net.Conn, error)
closed chan struct{}

// mutable contains all the members of Channel which are mutable.
Expand Down Expand Up @@ -244,6 +249,14 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) {
return nil, err
}

// Default to dialContext if dialer is not passed in as an option
dialCtx := dialContext
if opts.Dialer != nil {
dialCtx = func (ctx context.Context, hostPort string) (net.Conn, error) {
return opts.Dialer(ctx, "tcp", hostPort)
}
}

ch := &Channel{
channelConnectionCommon: channelConnectionCommon{
log: logger,
Expand All @@ -259,6 +272,7 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) {
relayHost: opts.RelayHost,
relayMaxTimeout: validateRelayMaxTimeout(opts.RelayMaxTimeout, logger),
relayTimerVerify: opts.RelayTimerVerification,
dialer: dialCtx,
closed: make(chan struct{}),
}
ch.peers = newRootPeerList(ch, opts.OnPeerStatusChanged).newChild()
Expand Down Expand Up @@ -563,7 +577,7 @@ func (ch *Channel) Connect(ctx context.Context, hostPort string) (*Connection, e
}

timeout := getTimeout(ctx)
tcpConn, err := dialContext(ctx, hostPort)
tcpConn, err := ch.dialer(ctx, hostPort)
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
ch.log.WithFields(
Expand Down
24 changes: 24 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1322,3 +1322,27 @@ func TestInvalidTransportHeaders(t *testing.T) {
})
}
}

func TestCustomDialer(t *testing.T) {
sopts := testutils.NewOpts()
testutils.WithTestServer(t, sopts, func(t testing.TB, ts *testutils.TestServer) {
server := ts.Server()
testutils.RegisterEcho(server, nil)
customDialerCalledCount := 0

copts := testutils.NewOpts().SetDialer(func(ctx context.Context, network, hostPort string) (net.Conn, error) {
customDialerCalledCount++
d := net.Dialer{}
return d.DialContext(ctx, network, hostPort)
})

// Induce the creation of a connection from client to server.
client := ts.NewClient(copts)
testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())
assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection")

// Re-use
testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())
assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection")
})
}
8 changes: 8 additions & 0 deletions testutils/channel_opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ package testutils

import (
"flag"
"net"
"testing"
"time"

"github.com/uber/tchannel-go"
"github.com/uber/tchannel-go/tos"

"go.uber.org/atomic"
"golang.org/x/net/context"
)

var connectionLog = flag.Bool("connectionLog", false, "Enables connection logging in tests")
Expand Down Expand Up @@ -233,6 +235,12 @@ func (o *ChannelOpts) SetIdleCheckInterval(d time.Duration) *ChannelOpts {
return o
}

// SetDialer sets the dialer used for outbound connections
func (o *ChannelOpts) SetDialer(f func(context.Context, string, string) (net.Conn, error)) *ChannelOpts {
o.ChannelOptions.Dialer = f
return o
}

func defaultString(v string, defaultValue string) string {
if v == "" {
return defaultValue
Expand Down

0 comments on commit 74b0fff

Please sign in to comment.