Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making LookupClient disposable and improved TCP client pool #220

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/DnsClient/DnsMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal enum DnsMessageHandleType
TCP
}

internal abstract class DnsMessageHandler
internal abstract class DnsMessageHandler : IDisposable
{
public abstract DnsMessageHandleType Type { get; }

Expand Down Expand Up @@ -170,5 +170,16 @@ public virtual DnsResponseMessage GetResponseMessage(ArraySegment<byte> response

return response;
}

protected virtual void Dispose(bool disposing)
{
// Nothing to do in base class.
}

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
}
}
149 changes: 116 additions & 33 deletions src/DnsClient/DnsTcpMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ namespace DnsClient
{
internal class DnsTcpMessageHandler : DnsMessageHandler
{
private bool _disposedValue = false;
private readonly ConcurrentDictionary<IPEndPoint, ClientPool> _pools = new ConcurrentDictionary<IPEndPoint, ClientPool>();

public override DnsMessageHandleType Type { get; } = DnsMessageHandleType.TCP;

public override DnsResponseMessage Query(IPEndPoint server, DnsRequestMessage request, TimeSpan timeout)
{
CancellationToken cancellationToken = default;
if (_disposedValue)
{
throw new ObjectDisposedException(nameof(DnsTcpMessageHandler));
}

using var cts = timeout.TotalMilliseconds != Timeout.Infinite && timeout.TotalMilliseconds < int.MaxValue ?
new CancellationTokenSource(timeout) : null;

cancellationToken = cts?.Token ?? default;
var cancellationToken = cts?.Token ?? default;

ClientPool pool;
while (!_pools.TryGetValue(server, out pool))
Expand All @@ -32,7 +36,7 @@ public override DnsResponseMessage Query(IPEndPoint server, DnsRequestMessage re

cancellationToken.ThrowIfCancellationRequested();

var entry = pool.GetNextClient();
var entry = pool.GetNextClient(cancellationToken);

using var cancelCallback = cancellationToken.Register(() =>
{
Expand Down Expand Up @@ -69,6 +73,11 @@ public override async Task<DnsResponseMessage> QueryAsync(
DnsRequestMessage request,
CancellationToken cancellationToken)
{
if (_disposedValue)
{
throw new ObjectDisposedException(nameof(DnsTcpMessageHandler));
}

cancellationToken.ThrowIfCancellationRequested();

ClientPool pool;
Expand All @@ -77,7 +86,7 @@ public override async Task<DnsResponseMessage> QueryAsync(
_pools.TryAdd(server, new ClientPool(true, server));
}

var entry = await pool.GetNextClientAsync().ConfigureAwait(false);
var entry = await pool.GetNextClientAsync(cancellationToken).ConfigureAwait(false);

using var cancelCallback = cancellationToken.Register(() =>
{
Expand Down Expand Up @@ -281,7 +290,22 @@ private async Task<DnsResponseMessage> QueryAsyncInternal(TcpClient client, DnsR
return DnsResponseMessage.Combine(responses);
}

private class ClientPool : IDisposable
protected override void Dispose(bool disposing)
{
if (disposing && !_disposedValue)
{
_disposedValue = true;

foreach (var entry in _pools)
{
entry.Value.Dispose();
}
}

base.Dispose(disposing);
}

private sealed class ClientPool : IDisposable
{
private bool _disposedValue = false;
private readonly bool _enablePool;
Expand All @@ -294,7 +318,7 @@ public ClientPool(bool enablePool, IPEndPoint endpoint)
_endpoint = endpoint;
}

public ClientEntry GetNextClient()
public ClientEntry GetNextClient(CancellationToken cancellationToken)
{
if (_disposedValue)
{
Expand All @@ -306,20 +330,54 @@ public ClientEntry GetNextClient()
{
while (entry == null && !TryDequeue(out entry))
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily) { LingerState = new LingerOption(true, 0) }, _endpoint);
entry.Client.Connect(_endpoint.Address, _endpoint.Port);
entry = ConnectNew(cancellationToken);
}
}
else
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily), _endpoint);
entry.Client.Connect(_endpoint.Address, _endpoint.Port);
entry = ConnectNew(cancellationToken);
}

return entry;
}

public async Task<ClientEntry> GetNextClientAsync()
private ClientEntry ConnectNew(CancellationToken cancellationToken)
{
var newClient = new TcpClient(_endpoint.AddressFamily)
{
LingerState = new LingerOption(true, 0)
};

bool gotCanceled = false;
cancellationToken.Register(() =>
{
gotCanceled = true;
newClient.Dispose();
});

try
{
newClient.Connect(_endpoint.Address, _endpoint.Port);
}
catch (Exception) when (gotCanceled)
{
throw new OperationCanceledException("Connection timed out.", cancellationToken);
}
catch (Exception)
{
try
{
newClient.Dispose();
}
catch { }

throw;
}

return new ClientEntry(newClient, _endpoint);
}

public async Task<ClientEntry> GetNextClientAsync(CancellationToken cancellationToken)
{
if (_disposedValue)
{
Expand All @@ -331,19 +389,57 @@ public async Task<ClientEntry> GetNextClientAsync()
{
while (entry == null && !TryDequeue(out entry))
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily) { LingerState = new LingerOption(true, 0) }, _endpoint);
await entry.Client.ConnectAsync(_endpoint.Address, _endpoint.Port).ConfigureAwait(false);
entry = await ConnectNewAsync(cancellationToken).ConfigureAwait(false);
}
}
else
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily), _endpoint);
await entry.Client.ConnectAsync(_endpoint.Address, _endpoint.Port).ConfigureAwait(false);
entry = await ConnectNewAsync(cancellationToken).ConfigureAwait(false);
}

return entry;
}

private async Task<ClientEntry> ConnectNewAsync(CancellationToken cancellationToken)
{
var newClient = new TcpClient(_endpoint.AddressFamily)
{
LingerState = new LingerOption(true, 0)
};

#if NET6_0_OR_GREATER
await newClient.ConnectAsync(_endpoint.Address, _endpoint.Port, cancellationToken).ConfigureAwait(false);
#else

bool gotCanceled = false;
cancellationToken.Register(() =>
{
gotCanceled = true;
newClient.Dispose();
});

try
{
await newClient.ConnectAsync(_endpoint.Address, _endpoint.Port).ConfigureAwait(false);
}
catch (Exception) when (gotCanceled)
{
throw new OperationCanceledException("Connection timed out.", cancellationToken);
}
catch (Exception)
{
try
{
newClient.Dispose();
}
catch { }

throw;
}
#endif
return new ClientEntry(newClient, _endpoint);
}

public void Enqueue(ClientEntry entry)
{
if (_disposedValue)
Expand Down Expand Up @@ -397,29 +493,20 @@ public bool TryDequeue(out ClientEntry entry)
return result;
}

protected virtual void Dispose(bool disposing)
public void Dispose()
{
if (!_disposedValue)
{
if (disposing)
_disposedValue = true;
foreach (var entry in _clients)
{
foreach (var entry in _clients)
{
entry.DisposeClient();
}

_clients = new ConcurrentQueue<ClientEntry>();
entry.DisposeClient();
}

_disposedValue = true;
_clients = new ConcurrentQueue<ClientEntry>();
}
}

public void Dispose()
{
Dispose(true);
}

public class ClientEntry
{
public ClientEntry(TcpClient client, IPEndPoint endpoint)
Expand All @@ -432,11 +519,7 @@ public void DisposeClient()
{
try
{
#if !NET45
Client.Dispose();
#else
Client.Close();
#endif
}
catch { }
}
Expand Down
1 change: 0 additions & 1 deletion src/DnsClient/DnsUdpMessageHandler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using System.Threading;
Expand Down
2 changes: 1 addition & 1 deletion src/DnsClient/ILookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ public interface ILookupClient : IDnsQuery

#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
}
}
}
20 changes: 17 additions & 3 deletions src/DnsClient/LookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace DnsClient
/// ]]>
/// </code>
/// </example>
public class LookupClient : ILookupClient, IDnsQuery
public sealed class LookupClient : ILookupClient, IDnsQuery, IDisposable
{
private const int LogEventStartQuery = 1;
private const int LogEventQuery = 2;
Expand All @@ -62,6 +62,7 @@ public class LookupClient : ILookupClient, IDnsQuery
private readonly SkipWorker _skipper = null;

private IReadOnlyCollection<NameServer> _resolvedNameServers;
private bool _disposedValue;

/// <inheritdoc/>
public IReadOnlyCollection<NameServer> NameServers => Settings.NameServers;
Expand Down Expand Up @@ -370,7 +371,7 @@ internal LookupClient(LookupClientOptions options, DnsMessageHandler udpHandler

// Setting up name servers.
// Using manually configured ones and/or auto resolved ones.
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers?.ToArray() ?? new NameServer[0];
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers?.ToArray() ?? Array.Empty<NameServer>();

if (options.AutoResolveNameServers)
{
Expand Down Expand Up @@ -427,7 +428,9 @@ private void CheckResolvedNameservers()
}

_resolvedNameServers = newServers;
var servers = _originalOptions.NameServers.Concat(_resolvedNameServers).ToArray();
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers.Concat(_resolvedNameServers).ToArray();
servers = NameServer.ValidateNameServers(servers, _logger);

Settings = new LookupClientSettings(_originalOptions, servers);
}
catch (Exception ex)
Expand Down Expand Up @@ -1787,6 +1790,17 @@ public void MaybeDoWork()
}
}
}

/// <inheritdoc/>
public void Dispose()
{
if (!_disposedValue)
{
_disposedValue = true;
_tcpFallbackHandler?.Dispose();
_messageHandler?.Dispose();
}
}
}

internal class LookupClientAudit
Expand Down
4 changes: 2 additions & 2 deletions test-other/OldReference/TestLookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public LookupClient SetNonDefaults()

public void TestQuery_1_1()
{
var client = new LookupClient();
var client = new LookupClient(NameServer.GooglePublicDns.Address);
client.Query("domain", QueryType.A);
client.Query("domain", QueryType.A, QueryClass.IN);
client.QueryReverse(IPAddress.Loopback);
Expand All @@ -71,7 +71,7 @@ public void TestQuery_1_1()

public async Task TestQueryAsync_1_1()
{
var client = new LookupClient();
var client = new LookupClient(NameServer.GooglePublicDns.Address);
await client.QueryAsync("domain", QueryType.A).ConfigureAwait(false);
await client.QueryAsync("domain", QueryType.A, QueryClass.IN).ConfigureAwait(false);
await client.QueryAsync("domain", QueryType.A, cancellationToken: default).ConfigureAwait(false);
Expand Down
Loading