Skip to content

Commit

Permalink
Merge pull request #15024 from rgacogne/ddist-quic-sni
Browse files Browse the repository at this point in the history
dnsdist: Gather Server Name Indication on QUIC (DoQ, DoH3) connections
  • Loading branch information
rgacogne authored Jan 23, 2025
2 parents abe6d42 + febf9b5 commit 2d3e3ea
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 10 deletions.
19 changes: 16 additions & 3 deletions pdns/dnsdistdist/doh3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ class H3Connection
H3Connection& operator=(H3Connection&&) = default;
~H3Connection() = default;

std::shared_ptr<const std::string> getSNI()
{
if (!d_sni) {
d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
}
return d_sni;
}

ComboAddress d_peer;
ComboAddress d_localAddr;
QuicheConnection d_conn;
Expand All @@ -71,6 +79,7 @@ class H3Connection
std::unordered_map<uint64_t, dnsdist::doh3::h3_headers_t> d_headersBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
std::shared_ptr<const std::string> d_sni{nullptr};
};

static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
Expand Down Expand Up @@ -566,6 +575,9 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
ids.origFlags = *flags;
return true;
});
if (unit->sni) {
dnsQuestion.sni = *unit->sni;
}
unit->ids.cs = &clientState;

auto result = processQuery(dnsQuestion, downstream);
Expand Down Expand Up @@ -640,7 +652,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
}
}

static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, dnsdist::doh3::h3_headers_t&& headers)
static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni, dnsdist::doh3::h3_headers_t&& headers)
{
try {
auto unit = std::make_unique<DOH3Unit>(std::move(query));
Expand All @@ -650,6 +662,7 @@ static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, con
unit->ids.protocol = dnsdist::Protocol::DoH3;
unit->serverConnID = serverConnID;
unit->streamID = streamID;
unit->sni = sni;
unit->headers = std::move(headers);

processDOH3Query(std::move(unit));
Expand Down Expand Up @@ -751,7 +764,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
return;
}
DEBUGLOG("Dispatching GET query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
conn.d_streamBuffers.erase(streamID);
conn.d_headersBuffers.erase(streamID);
return;
Expand Down Expand Up @@ -816,7 +829,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
}

DEBUGLOG("Dispatching POST query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
conn.d_headersBuffers.erase(streamID);
conn.d_streamBuffers.erase(streamID);
}
Expand Down
1 change: 1 addition & 0 deletions pdns/dnsdistdist/doh3.hh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ struct DOH3Unit
PacketBuffer serverConnID;
dnsdist::doh3::h3_headers_t headers;
std::shared_ptr<DownstreamState> downstream{nullptr};
std::shared_ptr<const std::string> sni{nullptr};
std::string d_contentTypeOut;
DOH3ServerConfig* dsc{nullptr};
uint64_t streamID{0};
Expand Down
14 changes: 13 additions & 1 deletion pdns/dnsdistdist/doq-common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,18 @@ bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, C
return !buffer.empty();
}

};
std::string getSNIFromQuicheConnection(const QuicheConnection& conn)
{
#if defined(HAVE_QUICHE_CONN_SERVER_NAME)
const uint8_t* sniPtr = nullptr;
size_t sniPtrSize = 0;
quiche_conn_server_name(conn.get(), &sniPtr, &sniPtrSize);
if (sniPtrSize > 0) {
return std::string(reinterpret_cast<const char*>(sniPtr), sniPtrSize);
}
#endif /* HAVE_QUICHE_CONN_SERVER_NAME */
return {};
}
}

#endif
3 changes: 2 additions & 1 deletion pdns/dnsdistdist/doq-common.hh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <map>
#include <memory>
#include <string>

#include "config.h"

Expand Down Expand Up @@ -97,7 +98,7 @@ void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, co
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer);
void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHTTP);
bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr);

std::string getSNIFromQuicheConnection(const QuicheConnection& conn);
};

#endif
17 changes: 15 additions & 2 deletions pdns/dnsdistdist/doq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,22 @@ class Connection
Connection& operator=(Connection&&) = default;
~Connection() = default;

std::shared_ptr<const std::string> getSNI()
{
if (!d_sni) {
d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
}
return d_sni;
}

ComboAddress d_peer;
ComboAddress d_localAddr;
QuicheConnection d_conn;
QuicheConfig d_config;

std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
std::shared_ptr<const std::string> d_sni{nullptr};
};

static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
Expand Down Expand Up @@ -472,6 +481,9 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
ids.origFlags = *flags;
return true;
});
if (unit->sni) {
dnsQuestion.sni = *unit->sni;
}
unit->ids.cs = &clientState;

auto result = processQuery(dnsQuestion, downstream);
Expand Down Expand Up @@ -541,7 +553,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
}
}

static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni)
{
try {
auto unit = std::make_unique<DOQUnit>(std::move(query));
Expand All @@ -551,6 +563,7 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const
unit->ids.protocol = dnsdist::Protocol::DoQ;
unit->serverConnID = serverConnID;
unit->streamID = streamID;
unit->sni = sni;

processDOQQuery(std::move(unit));
}
Expand Down Expand Up @@ -649,7 +662,7 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState
return;
}
DEBUGLOG("Dispatching query");
doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI());
conn.d_streamBuffers.erase(streamID);
}

Expand Down
1 change: 1 addition & 0 deletions pdns/dnsdistdist/doq.hh
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ struct DOQUnit
PacketBuffer response;
PacketBuffer serverConnID;
std::shared_ptr<DownstreamState> downstream{nullptr};
std::shared_ptr<const std::string> sni{nullptr};
DOQServerConfig* dsc{nullptr};
uint64_t streamID{0};
size_t proxyProtocolPayloadSize{0};
Expand Down
10 changes: 10 additions & 0 deletions pdns/dnsdistdist/m4/pdns_with_quiche.m4
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ AC_DEFUN([PDNS_WITH_QUICHE], [
AC_DEFINE([HAVE_QUICHE], [1], [Define to 1 if you have quiche])
], [ : ])
])
AS_IF([test "x$HAVE_QUICHE" = "x1"], [
save_CFLAGS=$CFLAGS
save_LIBS=$LIBS
CFLAGS="$QUICHE_CFLAGS $CFLAGS"
LIBS="$QUICHE_LIBS $LIBS"
AC_CHECK_FUNCS([quiche_conn_server_name])
CFLAGS=$save_CFLAGS
LIBS=$save_LIBS
])
])
])
AM_CONDITIONAL([HAVE_QUICHE], [test "x$QUICHE_LIBS" != "x"])
Expand Down
2 changes: 1 addition & 1 deletion regression-tests.dnsdist/doh3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def async_h3_query(


def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)

Expand Down
4 changes: 2 additions & 2 deletions regression-tests.dnsdist/doqclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self, error, message="Stream reset by peer"):
super().__init__(message)

def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)
(result, serial) = asyncio.run(
Expand All @@ -108,7 +108,7 @@ def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server
return (result, serial)

def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)
(result, _) = asyncio.run(
Expand Down
82 changes: 82 additions & 0 deletions regression-tests.dnsdist/test_SNI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python
import base64
import dns
import os
import unittest
import pycurl

from dnsdisttests import DNSDistTest, pickAvailablePort

class TestSNI(DNSDistTest):
_serverKey = 'server.key'
_serverCert = 'server.chain'
_serverName = 'tls.tests.dnsdist.org'
_caCert = 'ca.pem'
_tlsServerPort = pickAvailablePort()
_dohWithNGHTTP2ServerPort = pickAvailablePort()
_doqServerPort = pickAvailablePort()
_doh3ServerPort = pickAvailablePort()
_dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (_serverName, _dohWithNGHTTP2ServerPort))
_dohBaseURL = ("https://%s:%d/" % (_serverName, _doh3ServerPort))

_config_template = """
newServer{address="127.0.0.1:%d"}
addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
addDOQLocal("127.0.0.1:%d", "%s", "%s")
addDOH3Local("127.0.0.1:%d", "%s", "%s")
function displaySNI(dq)
local sni = dq:getServerNameIndication()
if sni ~= '%s' then
return DNSAction.Spoof, '1.2.3.4'
end
return DNSAction.Allow
end
addAction(AllRule(), LuaAction(displaySNI))
"""
_config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName']

# enable these once Quiche > 0.22 is available, including https://github.com/cloudflare/quiche/pull/1895
@unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quicheare disabled")
def testServerNameIndicationWithQuiche(self):
name = 'simple.sni.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
response = dns.message.make_response(query)
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
dns.rdatatype.A,
'127.0.0.1')
response.answer.append(rrset)
for method in ["sendDOQQueryWrapper", "sendDOH3QueryWrapper"]:
sender = getattr(self, method)
(receivedQuery, receivedResponse) = sender(query, response, timeout=1)
self.assertTrue(receivedQuery)
receivedQuery.id = query.id
self.assertEqual(query, receivedQuery)
self.assertTrue(receivedResponse)
if method == 'sendDOQQueryWrapper':
# dnspython sets the ID to 0
receivedResponse.id = response.id
self.assertEqual(response, receivedResponse)

def testServerNameIndication(self):
name = 'simple.sni.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
response = dns.message.make_response(query)
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
dns.rdatatype.A,
'127.0.0.1')
response.answer.append(rrset)
for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]:
sender = getattr(self, method)
(receivedQuery, receivedResponse) = sender(query, response, timeout=1)
self.assertTrue(receivedQuery)
receivedQuery.id = query.id
self.assertEqual(query, receivedQuery)
self.assertTrue(receivedResponse)
self.assertEqual(response, receivedResponse)

0 comments on commit 2d3e3ea

Please sign in to comment.