Skip to content

Commit

Permalink
Improve TlsFrameHelper (#2587)
Browse files Browse the repository at this point in the history
Co-authored-by: Tomas Weinfurt <[email protected]>
  • Loading branch information
MihaZupan and wfurt authored Aug 27, 2024
1 parent a9565ba commit 555ca0f
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 16 deletions.
57 changes: 48 additions & 9 deletions src/ReverseProxy/Utilities/TlsFrameHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ public enum ApplicationProtocolInfo
Other = 128
}

public enum ParsingStatus
{
Ok = 0,
IncompleteFrame = 1,
InvalidFrame = 2,
UnsupportedFrame = 3,
}

public struct TlsFrameInfo
{
internal TlsCipherSuite[]? _ciphers;
Expand All @@ -127,6 +135,7 @@ public struct TlsFrameInfo
public string TargetName;
public ApplicationProtocolInfo ApplicationProtocols;
public TlsAlertDescription AlertDescription;
public ParsingStatus ParsingStatus;
public ReadOnlyMemory<TlsCipherSuite> TlsCipherSuites
{
get
Expand Down Expand Up @@ -258,6 +267,7 @@ public static bool TryGetFrameInfo(ReadOnlySpan<byte> frame, ref TlsFrameInfo in
const int HandshakeTypeOffset = 5;
if (frame.Length < HeaderSize)
{
info.ParsingStatus = ParsingStatus.IncompleteFrame;
return false;
}

Expand Down Expand Up @@ -286,28 +296,37 @@ public static bool TryGetFrameInfo(ReadOnlySpan<byte> frame, ref TlsFrameInfo in
if (TryGetAlertInfo(frame, ref level, ref description))
{
info.AlertDescription = description;
info.ParsingStatus = ParsingStatus.Ok;
return true;
}

info.ParsingStatus = ParsingStatus.IncompleteFrame;
return false;
}

if (info.Header.Type != TlsContentType.Handshake || frame.Length <= HandshakeTypeOffset)
if (info.Header.Type != TlsContentType.Handshake)
{
info.ParsingStatus = ParsingStatus.UnsupportedFrame;
return false;
}

info.HandshakeType = (TlsHandshakeType)frame[HandshakeTypeOffset];
if (frame.Length <= HandshakeTypeOffset)
{
info.ParsingStatus = ParsingStatus.IncompleteFrame;
return false;
}

info.HandshakeType = (TlsHandshakeType)frame[HandshakeTypeOffset];
// Check if we have full frame.
var isComplete = frame.Length >= HeaderSize + info.Header.Length;
info.ParsingStatus = isComplete ? ParsingStatus.Ok : ParsingStatus.IncompleteFrame;

#pragma warning disable SYSLIB0039 // TLS 1.0 and 1.1 are obsolete
if (((int)info.Header.Version >= (int)SslProtocols.Tls) &&
#pragma warning restore SYSLIB0039
(info.HandshakeType == TlsHandshakeType.ClientHello || info.HandshakeType == TlsHandshakeType.ServerHello))
{
if (!TryParseHelloFrame(frame.Slice(HeaderSize), ref info, options, callback))
if (!TryParseHelloFrame(frame.Slice(HeaderSize, Math.Min(info.Header.Length, frame.Length - HeaderSize)), ref info, options, callback))
{
isComplete = false;
}
Expand Down Expand Up @@ -404,19 +423,39 @@ private static bool TryParseHelloFrame(ReadOnlySpan<byte> sslHandshake, ref TlsF
const int HandshakeTypeOffset = 0;
const int HelloLengthOffset = HandshakeTypeOffset + sizeof(TlsHandshakeType);
const int HelloOffset = HelloLengthOffset + UInt24Size;
const int HandshakeHeaderLength = 4; // Type and Handshake length
const int MinimalHandshakeLength = 44; // Version, Random, SessionID and Cipher length with at least one cipher

if (sslHandshake.Length < HelloOffset ||
((TlsHandshakeType)sslHandshake[HandshakeTypeOffset] != TlsHandshakeType.ClientHello &&
(TlsHandshakeType)sslHandshake[HandshakeTypeOffset] != TlsHandshakeType.ServerHello))
if (info.Header.Length - HandshakeHeaderLength < MinimalHandshakeLength)
{
info.ParsingStatus = ParsingStatus.InvalidFrame;
return false;
}

if (sslHandshake.Length < HelloOffset + 3)
{
info.ParsingStatus = ParsingStatus.IncompleteFrame;
return false;
}

if ((TlsHandshakeType)sslHandshake[HandshakeTypeOffset] != TlsHandshakeType.ClientHello &&
(TlsHandshakeType)sslHandshake[HandshakeTypeOffset] != TlsHandshakeType.ServerHello)
{
info.ParsingStatus = ParsingStatus.UnsupportedFrame;
return false;
}

var helloLength = ReadUInt24BigEndian(sslHandshake.Slice(HelloLengthOffset));
var helloData = sslHandshake.Slice(HelloOffset);
if (helloLength < MinimalHandshakeLength || helloLength > info.Header.Length - HandshakeHeaderLength)
{
info.ParsingStatus = ParsingStatus.InvalidFrame;
return false;
}

var helloData = sslHandshake.Slice(HelloOffset);
if (helloData.Length < helloLength)
{
info.ParsingStatus = ParsingStatus.IncompleteFrame;
return false;
}

Expand Down Expand Up @@ -490,12 +529,12 @@ private static bool TryParseServerHello(ReadOnlySpan<byte> serverHello, ref TlsF
// }
// ServerHello;
const int CipherSuiteLength = 2;
const int CompressionMethiodLength = 1;
const int CompressionMethodLength = 1;

var p = SkipBytes(serverHello, ProtocolVersionSize + RandomSize);
// Skip SessionID (max size 32 => size fits in 1 byte)
p = SkipOpaqueType1(p);
p = SkipBytes(p, CipherSuiteLength + CompressionMethiodLength);
p = SkipBytes(p, CipherSuiteLength + CompressionMethodLength);

// is invalid structure or no extensions?
if (p.IsEmpty)
Expand Down
99 changes: 92 additions & 7 deletions test/ReverseProxy.Tests/Utilities/TlsFrameHelperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ public void TlsFrameHelper_TlsClientHelloNoExtensions_Ok()
{
TlsFrameHelper.TlsFrameInfo info = default;
Assert.True(TlsFrameHelper.TryGetFrameInfo(s_TlsClientHelloNoExtensions, ref info));
Assert.Equal(TlsFrameHelper.ParsingStatus.Ok, info.ParsingStatus);
Assert.Equal(SslProtocols.Tls12, info.Header.Version);
Assert.Equal(SslProtocols.Tls12, info.SupportedVersions);
Assert.Equal(TlsContentType.Handshake, info.Header.Type);
Expand All @@ -142,12 +143,20 @@ public void TlsFrameHelper_Tls12ServerHello_Ok()
{
TlsFrameHelper.TlsFrameInfo info = default;
Assert.True(TlsFrameHelper.TryGetFrameInfo(s_Tls12ServerHello, ref info));

Assert.Equal(TlsFrameHelper.ParsingStatus.Ok, info.ParsingStatus);
Assert.Equal(SslProtocols.Tls12, info.Header.Version);
Assert.Equal(SslProtocols.Tls12, info.SupportedVersions);
Assert.Equal(TlsFrameHelper.ApplicationProtocolInfo.Http2, info.ApplicationProtocols);
}

[Fact]
public void TlsFrameHelper_FragmentedClientHello_Fails()
{
TlsFrameHelper.TlsFrameInfo info = default;
Assert.False(TlsFrameHelper.TryGetFrameInfo(s_Tls13FragmentedClientHello, ref info));
Assert.Equal(TlsFrameHelper.ParsingStatus.InvalidFrame, info.ParsingStatus);
}

public static IEnumerable<object[]> InvalidClientHelloData()
{
int id = 0;
Expand Down Expand Up @@ -183,7 +192,7 @@ public static IEnumerable<Tuple<int, byte[]>> InvalidClientHelloDataTruncatedByt
}
}

private static byte[] s_validClientHello = new byte[] {
private static readonly byte[] s_validClientHello = new byte[] {
// SslPlainText.(ContentType+ProtocolVersion)
0x16, 0x03, 0x03,
// SslPlainText.length
Expand Down Expand Up @@ -271,7 +280,7 @@ public static IEnumerable<Tuple<int, byte[]>> InvalidClientHelloDataTruncatedByt
0x00, 0x01, 0x00
};

private static byte[] s_Tls12ClientHello = new byte[] {
private static readonly byte[] s_Tls12ClientHello = new byte[] {
// SslPlainText.(ContentType+ProtocolVersion)
0x16, 0x03, 0x01,
// SslPlainText.length
Expand Down Expand Up @@ -329,7 +338,7 @@ public static IEnumerable<Tuple<int, byte[]>> InvalidClientHelloDataTruncatedByt
0x2E, 0x31
};

private static byte[] s_Tls13ClientHello = new byte[] {
private static readonly byte[] s_Tls13ClientHello = new byte[] {
// SslPlainText.(ContentType+ProtocolVersion)
0x16, 0x03, 0x01,
// SslPlainText.length
Expand Down Expand Up @@ -399,7 +408,7 @@ public static IEnumerable<Tuple<int, byte[]>> InvalidClientHelloDataTruncatedByt
0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03
};

private static byte[] s_Tls12ServerHello = new byte[] {
private static readonly byte[] s_Tls12ServerHello = new byte[] {
// SslPlainText.(ContentType+ProtocolVersion)
0x16, 0x03, 0x03,
// SslPlainText.length
Expand Down Expand Up @@ -441,7 +450,7 @@ public static IEnumerable<Tuple<int, byte[]>> InvalidClientHelloDataTruncatedByt
0x00, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, 0x68, 0x32,
};

private static byte[] s_UnifiedHello = new byte[]
private static readonly byte[] s_UnifiedHello = new byte[]
{
// Length
0x80, 0x49,
Expand All @@ -460,7 +469,7 @@ public static IEnumerable<Tuple<int, byte[]>> InvalidClientHelloDataTruncatedByt
0x52, 0x3B, 0x12, 0x9C, 0xF8, 0xD4,
};

private static byte[] s_TlsClientHelloNoExtensions = new byte[] {
private static readonly byte[] s_TlsClientHelloNoExtensions = new byte[] {
0x16, 0x03, 0x03, 0x00, 0x39, 0x01, 0x00, 0x00,
0x35, 0x03, 0x03, 0x62, 0x5d, 0x50, 0x2a, 0x41,
0x2f, 0xd8, 0xc3, 0x65, 0x35, 0xea, 0x01, 0x70,
Expand All @@ -471,6 +480,82 @@ public static IEnumerable<Tuple<int, byte[]>> InvalidClientHelloDataTruncatedByt
0x00, 0x05, 0x00, 0x04, 0x01, 0x00
};

private static readonly byte[] s_Tls13FragmentedClientHello = new byte[] {
// SslPlainText.(ContentType+ProtocolVersion)
0x16, 0x03, 0x01,
// SslPlainText.length
0x00, 0x04, // Fragmented
// Handshake.msg_type (client hello)
0x01,
// Handshake.length
0x00, 0x01, 0x04,
// Extra fragment header
// SslPlainText.(ContentType+ProtocolVersion)
0x16, 0x03, 0x01,
// SslPlainText.length
0x01, 0x04,
// ClientHello.client_version
0x03, 0x03,

// ClientHello.random
0x0C, 0x3C, 0x85, 0x78, 0xCA, 0x67, 0x70, 0xAA,
0x38, 0xCB, 0x28, 0xBC, 0xDC, 0x3E, 0x30, 0xBF,
0x11, 0x96, 0x95, 0x1A, 0xB9, 0xF0, 0x99, 0xA4,
0x91, 0x09, 0x13, 0xB4, 0x89, 0x94, 0x27, 0x2E,
// ClientHello.SessionId_Length
0x20,
// ClientHello.SessionId
0x0C, 0x3C, 0x85, 0x78, 0xCA, 0x67, 0x70, 0xAA,
0x38, 0xCB, 0x28, 0xBC, 0xDC, 0x3E, 0x30, 0xBF,
0x11, 0x96, 0x95, 0x1A, 0xB9, 0xF0, 0x99, 0xA4,
0x91, 0x09, 0x13, 0xB4, 0x89, 0x94, 0x27, 0x2E,
// ClientHello.cipher_suites_length
0x00, 0x0C,
// ClientHello.cipher_suites
0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0xC0, 0x14,
0xc0, 0x30, 0x00, 0xFF,
// ClientHello.compression_methods
0x01, 0x00,
// ClientHello.extension_list_length
0x00, 0xAF,
// Extension.extension_type (server_name) (10.211.55.2)
0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00,
0x0B, 0x31, 0x30, 0x2E, 0x32, 0x31, 0x31, 0x2E,
0x35, 0x35, 0x2E, 0x32,
// Extension.extension_type (ec_point_formats)
0x00, 0x0B, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
// Extension.extension_type (supported_groups)
0x00, 0x0A, 0x00, 0x0C, 0x00, 0x0A, 0x00, 0x1D,
0x00, 0x17, 0x00, 0x1E, 0x00, 0x19, 0x00, 0x18,
// Extension.extension_type (application_level_Protocol) (boo)
0x00, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, 0x62,
0x6f, 0x6f,
// Extension.extension_type (encrypt_then_mac)
0x00, 0x16, 0x00, 0x00,
// Extension.extension_type (extended_master_key_secret)
0x00, 0x17, 0x00, 0x00,
// Extension.extension_type (signature_algorithms)
0x00, 0x0D, 0x00, 0x30, 0x00, 0x2E,
0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03,
0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03,
0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03,
0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED,
0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03,
0x03, 0x01, 0x03, 0x03, 0x02, 0x01,
// Extension.extension_type (supported_versions)
0x00, 0x2B, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03,
0x03, 0x03, 0x02, 0x03, 0x01,
// Extension.extension_type (psk_key_exchange_modes)
0x00, 0x2D, 0x00, 0x02, 0x01, 0x01,
// Extension.extension_type (key_share)
0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1D,
0x00, 0x20,
0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED,
0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03,
0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED,
0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03
};

private static IEnumerable<byte[]> InvalidClientHello()
{
// This test covers following test cases:
Expand Down
5 changes: 5 additions & 0 deletions testassets/ReverseProxy.Direct/TlsFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ private static bool TryReadHello(ReadOnlySequence<byte> buffer, ILogger logger,
TlsFrameHelper.TlsFrameInfo info = default;
if (!TlsFrameHelper.TryGetFrameInfo(data, ref info))
{
if (info.ParsingStatus == TlsFrameHelper.ParsingStatus.InvalidFrame)
{
logger.LogInformation("Invalid TLS frame");
abort = true;
}
return false;
}

Expand Down

0 comments on commit 555ca0f

Please sign in to comment.