Skip to content

Commit

Permalink
tls_available support (#889)
Browse files Browse the repository at this point in the history
  • Loading branch information
scottf authored Apr 30, 2024
1 parent 7470c56 commit f704a94
Showing 10 changed files with 344 additions and 42 deletions.
59 changes: 33 additions & 26 deletions src/NATS.Client/Connection.cs
Original file line number Diff line number Diff line change
@@ -133,7 +133,7 @@ public class Connection : IConnection
bool flusherKicked = false;
bool flusherDone = false;

private ServerInfo info = null;
private ServerInfo serverInfo = null;

private Dictionary<Int64, Subscription> subs =
new Dictionary<Int64, Subscription>();
@@ -1035,7 +1035,7 @@ public IPAddress ClientIP
string clientIp;
lock (mu)
{
clientIp = info.ClientIp;
clientIp = serverInfo.ClientIp;
}

return !string.IsNullOrEmpty(clientIp) ? IPAddress.Parse(clientIp) : null;
@@ -1055,7 +1055,7 @@ public int ClientID
{
lock (mu)
{
return info.ClientId;
return serverInfo.ClientId;
}
}
}
@@ -1073,7 +1073,7 @@ public string ConnectedId
if (status != ConnState.CONNECTED)
return IC._EMPTY_;

return this.info.ServerId;
return this.serverInfo.ServerId;
}
}
}
@@ -1088,7 +1088,7 @@ public ServerInfo ServerInfo
{
lock (mu)
{
return status == ConnState.CONNECTED ? info : null;
return status == ConnState.CONNECTED ? serverInfo : null;
}
}
}
@@ -1281,25 +1281,31 @@ internal void connect(bool reconnectOnConnect)
// only be called after the INIT protocol has been received.
private void checkForSecure(Srv s)
{
if (!Opts.TlsFirst)
bool makeTlsConn = false;
if (Opts.TlsFirst)
{
// Check to see if we need to engage TLS
// Check for mismatch in setups
if (Opts.Secure && !info.TlsRequired)
makeTlsConn = true;
}
else
{
if (Opts.Secure || s.Secure)
{
throw new NATSSecureConnWantedException();
// Check to see if the client wants tls but the server doesn't
if (!serverInfo.TlsRequired && !serverInfo.TlsAvailable)
{
throw new NATSSecureConnWantedException();
}
makeTlsConn = true;
}
else if (info.TlsRequired && !Opts.Secure)
else if (serverInfo.TlsRequired)
{
// If the server asks us to be secure, give it
// a shot.
Opts.Secure = true;
// on some clients we error if the client isn't secure but the server is
// but in this client, we just try it. It will essentially fail slow instead of fast
makeTlsConn = true;
}
}

// Need to rewrap with bufio if options tell us we need
// a secure connection or the tls url scheme was specified.
if (Opts.Secure || s.Secure)
if (makeTlsConn)
{
makeTLSConn();
}
@@ -1433,7 +1439,7 @@ private string connectProto()
throw new NATSConnectionException("User signature event handle has not been been defined.");
}

var args = new UserSignatureEventArgs(Encoding.ASCII.GetBytes(this.info.Nonce));
var args = new UserSignatureEventArgs(Encoding.ASCII.GetBytes(this.serverInfo.Nonce));
try
{
opts.UserSignatureEventHandler(this, args);
@@ -2416,21 +2422,22 @@ internal void processOK()
// processInfo is used to parse the info messages sent
// from the server.
// Caller must lock.
internal void processInfo(string json, bool notify)
// made virtual for unit testing
internal virtual void processInfo(string json, bool notify)
{
if (json == null || IC._EMPTY_.Equals(json))
{
return;
}

info = new ServerInfo(json);
var serverAdded = srvProvider.AcceptDiscoveredServers(info.ConnectURLs);
serverInfo = new ServerInfo(json);
var serverAdded = srvProvider.AcceptDiscoveredServers(serverInfo.ConnectURLs);
if (notify && serverAdded)
{
scheduleConnEvent(opts.ServerDiscoveredEventHandlerOrDefault);
}

if (notify && info.LameDuckMode)
if (notify && serverInfo.LameDuckMode)
{
scheduleConnEvent(opts.LameDuckModeEventHandlerOrDefault);
}
@@ -2633,7 +2640,7 @@ internal void PublishImpl(string subject, string reply, MsgHeader inHeaders, byt
byte[] headers = null;
if (inHeaders != null)
{
if (!info.HeadersSupported)
if (!serverInfo.HeadersSupported)
{
throw new NATSNotSupportedException("Headers are not supported by the server.");
}
@@ -2649,7 +2656,7 @@ internal void PublishImpl(string subject, string reply, MsgHeader inHeaders, byt
throw new NATSConnectionDrainingException();

// Proactively reject payloads over the threshold set by server.
if (opts.ClientSideLimitChecks && count > info.MaxPayload && info.MaxPayload > 0)
if (opts.ClientSideLimitChecks && count > serverInfo.MaxPayload && serverInfo.MaxPayload > 0)
{
throw new NATSMaxPayloadException();
}
@@ -5172,7 +5179,7 @@ public long MaxPayload
{
lock (mu)
{
return info.MaxPayload;
return serverInfo.MaxPayload;
}
}
}
@@ -5187,7 +5194,7 @@ public override string ToString()
StringBuilder sb = new StringBuilder();
sb.Append("{");
sb.AppendFormat("url={0};", url);
sb.AppendFormat("info={0};", info);
sb.AppendFormat("info={0};", serverInfo);
sb.AppendFormat("status={0};", status);
sb.Append("Subscriptions={");
foreach (Subscription s in subs.Values)
2 changes: 2 additions & 0 deletions src/NATS.Client/JetStream/ApiConstants.cs
Original file line number Diff line number Diff line change
@@ -193,6 +193,8 @@ public static class ApiConstants
public const string Time = "time";
public const string Timestamp = "ts";
public const string Tls = "tls_required";
public const string TlsRequired = Tls;
public const string TlsAvailable = "tls_available";
public const string Total = "total";
public const string Type = "type";
public const string Version = "version";
4 changes: 3 additions & 1 deletion src/NATS.Client/JetStream/ServerInfo.cs
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ public sealed class ServerInfo
public bool HeadersSupported { get; }
public bool AuthRequired { get; }
public bool TlsRequired { get; }
public bool TlsAvailable { get; }
public long MaxPayload { get; }
public string[] ConnectURLs { get; }
public int ProtocolVersion { get; }
@@ -54,7 +55,8 @@ public ServerInfo(string json)
HeadersSupported = siNode[ApiConstants.Headers].AsBool;
AuthRequired = siNode[ApiConstants.AuthRequired].AsBool;
Nonce = siNode[ApiConstants.Nonce].Value;
TlsRequired = siNode[ApiConstants.Tls].AsBool;
TlsRequired = siNode[ApiConstants.TlsRequired].AsBool;
TlsAvailable = siNode[ApiConstants.TlsAvailable].AsBool;
LameDuckMode = siNode[ApiConstants.LameDuckMode].AsBool;
JetStreamAvailable = siNode[ApiConstants.Jetstream].AsBool;
Port = siNode[ApiConstants.Port].AsInt;
98 changes: 98 additions & 0 deletions src/NATS.Client/Options.cs
Original file line number Diff line number Diff line change
@@ -49,6 +49,104 @@ public sealed class Options

internal X509Certificate2Collection certificates = null;

private bool Equals(Options other)
{
return url == other.url
&& Equals(servers, other.servers)
&& noRandomize == other.noRandomize
&& name == other.name
&& verbose == other.verbose
&& pedantic == other.pedantic
&& useOldRequestStyle == other.useOldRequestStyle
&& secure == other.secure
&& allowReconnect == other.allowReconnect
&& noEcho == other.noEcho
&& ignoreDiscoveredServers == other.ignoreDiscoveredServers
&& tlsFirst == other.tlsFirst
&& clientSideLimitChecks == other.clientSideLimitChecks
&& maxReconnect == other.maxReconnect
&& reconnectWait == other.reconnectWait
&& pingInterval == other.pingInterval
&& timeout == other.timeout
&& reconnectJitter == other.reconnectJitter
&& reconnectJitterTLS == other.reconnectJitterTLS
&& Equals(certificates, other.certificates)
&& maxPingsOut == other.maxPingsOut
&& pendingMessageLimit == other.pendingMessageLimit
&& pendingBytesLimit == other.pendingBytesLimit
&& subscriberDeliveryTaskCount == other.subscriberDeliveryTaskCount
&& subscriptionBatchSize == other.subscriptionBatchSize
&& reconnectBufSize == other.reconnectBufSize
&& user == other.user
&& password == other.password
&& token == other.token
&& nkey == other.nkey
&& customInboxPrefix == other.customInboxPrefix
&& CheckCertificateRevocation == other.CheckCertificateRevocation;
}

public override bool Equals(object obj)
{
return ReferenceEquals(this, obj) || obj is Options other && Equals(other);
}

public override int GetHashCode()
{
unchecked
{
int hashCode = (url != null ? url.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (servers != null ? servers.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ noRandomize.GetHashCode();
hashCode = (hashCode * 397) ^ (name != null ? name.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ verbose.GetHashCode();
hashCode = (hashCode * 397) ^ pedantic.GetHashCode();
hashCode = (hashCode * 397) ^ useOldRequestStyle.GetHashCode();
hashCode = (hashCode * 397) ^ secure.GetHashCode();
hashCode = (hashCode * 397) ^ allowReconnect.GetHashCode();
hashCode = (hashCode * 397) ^ noEcho.GetHashCode();
hashCode = (hashCode * 397) ^ ignoreDiscoveredServers.GetHashCode();
hashCode = (hashCode * 397) ^ tlsFirst.GetHashCode();
hashCode = (hashCode * 397) ^ clientSideLimitChecks.GetHashCode();
hashCode = (hashCode * 397) ^ (serverProvider != null ? serverProvider.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ maxReconnect;
hashCode = (hashCode * 397) ^ reconnectWait;
hashCode = (hashCode * 397) ^ pingInterval;
hashCode = (hashCode * 397) ^ timeout;
hashCode = (hashCode * 397) ^ reconnectJitter;
hashCode = (hashCode * 397) ^ reconnectJitterTLS;
hashCode = (hashCode * 397) ^ (tcpConnection != null ? tcpConnection.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (certificates != null ? certificates.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (ClosedEventHandler != null ? ClosedEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (ServerDiscoveredEventHandler != null ? ServerDiscoveredEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (DisconnectedEventHandler != null ? DisconnectedEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (ReconnectedEventHandler != null ? ReconnectedEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (AsyncErrorEventHandler != null ? AsyncErrorEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (LameDuckModeEventHandler != null ? LameDuckModeEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (ReconnectDelayHandler != null ? ReconnectDelayHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (HeartbeatAlarmEventHandler != null ? HeartbeatAlarmEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (UnhandledStatusEventHandler != null ? UnhandledStatusEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (PullStatusWarningEventHandler != null ? PullStatusWarningEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (PullStatusErrorEventHandler != null ? PullStatusErrorEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (FlowControlProcessedEventHandler != null ? FlowControlProcessedEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (UserJWTEventHandler != null ? UserJWTEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (UserSignatureEventHandler != null ? UserSignatureEventHandler.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ maxPingsOut;
hashCode = (hashCode * 397) ^ pendingMessageLimit.GetHashCode();
hashCode = (hashCode * 397) ^ pendingBytesLimit.GetHashCode();
hashCode = (hashCode * 397) ^ subscriberDeliveryTaskCount;
hashCode = (hashCode * 397) ^ subscriptionBatchSize;
hashCode = (hashCode * 397) ^ reconnectBufSize;
hashCode = (hashCode * 397) ^ (user != null ? user.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (password != null ? password.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (token != null ? token.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (nkey != null ? nkey.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (customInboxPrefix != null ? customInboxPrefix.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (TLSRemoteCertificationValidationCallback != null ? TLSRemoteCertificationValidationCallback.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ CheckCertificateRevocation.GetHashCode();
return hashCode;
}
}

/// <summary>
/// Represents the method that will handle an event raised
/// when a connection is closed.
16 changes: 16 additions & 0 deletions src/Samples/JetStreamStarter/JetStreamStarter.cs
Original file line number Diff line number Diff line change
@@ -20,6 +20,22 @@

namespace NATSExamples
{
class Virt
{
internal virtual void foo()
{
Console.WriteLine("Virt Foo");
}
}

class VirtEx : Virt
{
internal override void foo()
{
Console.WriteLine("VirtEx Foo");
base.foo();
}
}
internal static class JetStreamStarter
{
static void Main(string[] args)
18 changes: 4 additions & 14 deletions src/Tests/IntegrationTests/TestTLS.cs
Original file line number Diff line number Diff line change
@@ -29,14 +29,7 @@ namespace IntegrationTests
/// </summary>
public class TestTls : TestSuite<TlsSuiteContext>
{
// public TestTls(TlsSuiteContext context) : base(context) { }
private readonly ITestOutputHelper output;

public TestTls(ITestOutputHelper output, TlsSuiteContext context) : base(context)
{
this.output = output;
Console.SetOut(new TestBase.ConsoleWriter(output));
}
public TestTls(TlsSuiteContext context) : base(context) { }

// A hack to avoid issues with our test self signed cert.
// We don't want to require the runner of the test to install the
@@ -58,14 +51,10 @@ private bool verifyServerCert(object sender,

// UNSAFE hack for testing purposes.
#if NET46
var isOK = serverCert.GetRawCertDataString().Equals(certificate.GetRawCertDataString());
return serverCert.GetRawCertDataString().Equals(certificate.GetRawCertDataString());
#else
var isOK = serverCert.Issuer.Equals(certificate.Issuer);
return serverCert.Issuer.Equals(certificate.Issuer);
#endif
if (isOK)
return true;

return false;
}

[Fact]
@@ -148,6 +137,7 @@ public void TestTlsFailWithBadAuth()
using (NATSServer srv = NATSServer.CreateWithConfig(Context.Server1.Port, "tls_user.conf"))
{
Options opts = Context.GetTestOptions(Context.Server1.Port);
opts.Timeout = 10000;
opts.Secure = true;
opts.Url = $"nats://username:BADDPASSOWRD@localhost:{Context.Server1.Port}";
opts.TLSRemoteCertificationValidationCallback = verifyServerCert;
12 changes: 12 additions & 0 deletions src/Tests/IntegrationTestsInternal/IntegrationTestsInternal.csproj
Original file line number Diff line number Diff line change
@@ -25,4 +25,16 @@
<ProjectReference Include="..\IntegrationTests\IntegrationTests.csproj" />
</ItemGroup>

<ItemGroup>
<None Update="config\certs\server-cert.pem">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="config\tls.conf">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="config\tls_first.conf">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
Loading

0 comments on commit f704a94

Please sign in to comment.