From febf9b5924081f8089e15be846374cb91ca28ac6 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Thu, 9 Jan 2025 15:36:49 +0100 Subject: [PATCH] dnsdist: Gather Server Name Indication on QUIC (DoQ, DoH3) connections --- pdns/dnsdistdist/doh3.cc | 19 +++++- pdns/dnsdistdist/doh3.hh | 1 + pdns/dnsdistdist/doq-common.cc | 14 ++++- pdns/dnsdistdist/doq-common.hh | 3 +- pdns/dnsdistdist/doq.cc | 17 ++++- pdns/dnsdistdist/doq.hh | 1 + pdns/dnsdistdist/m4/pdns_with_quiche.m4 | 10 +++ regression-tests.dnsdist/doh3client.py | 2 +- regression-tests.dnsdist/doqclient.py | 4 +- regression-tests.dnsdist/test_SNI.py | 82 +++++++++++++++++++++++++ 10 files changed, 143 insertions(+), 10 deletions(-) create mode 100644 regression-tests.dnsdist/test_SNI.py diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index 6ee0db58b994..edd08934d1d1 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -62,6 +62,14 @@ class H3Connection H3Connection& operator=(H3Connection&&) = default; ~H3Connection() = default; + std::shared_ptr getSNI() + { + if (!d_sni) { + d_sni = std::make_shared(getSNIFromQuicheConnection(d_conn)); + } + return d_sni; + } + ComboAddress d_peer; ComboAddress d_localAddr; QuicheConnection d_conn; @@ -71,6 +79,7 @@ class H3Connection std::unordered_map d_headersBuffers; std::unordered_map d_streamBuffers; std::unordered_map d_streamOutBuffers; + std::shared_ptr d_sni{nullptr}; }; static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description); @@ -566,6 +575,9 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit) ids.origFlags = *flags; return true; }); + if (unit->sni) { + dnsQuestion.sni = *unit->sni; + } unit->ids.cs = &clientState; auto result = processQuery(dnsQuestion, downstream); @@ -640,7 +652,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit) } } -static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, dnsdist::doh3::h3_headers_t&& headers) +static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr& sni, dnsdist::doh3::h3_headers_t&& headers) { try { auto unit = std::make_unique(std::move(query)); @@ -650,6 +662,7 @@ static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, con unit->ids.protocol = dnsdist::Protocol::DoH3; unit->serverConnID = serverConnID; unit->streamID = streamID; + unit->sni = sni; unit->headers = std::move(headers); processDOH3Query(std::move(unit)); @@ -751,7 +764,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten return; } DEBUGLOG("Dispatching GET query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, std::move(headers)); + doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers)); conn.d_streamBuffers.erase(streamID); conn.d_headersBuffers.erase(streamID); return; @@ -816,7 +829,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, } DEBUGLOG("Dispatching POST query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, std::move(headers)); + doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers)); conn.d_headersBuffers.erase(streamID); conn.d_streamBuffers.erase(streamID); } diff --git a/pdns/dnsdistdist/doh3.hh b/pdns/dnsdistdist/doh3.hh index 97a52a2c4202..9e021c6114db 100644 --- a/pdns/dnsdistdist/doh3.hh +++ b/pdns/dnsdistdist/doh3.hh @@ -101,6 +101,7 @@ struct DOH3Unit PacketBuffer serverConnID; dnsdist::doh3::h3_headers_t headers; std::shared_ptr downstream{nullptr}; + std::shared_ptr sni{nullptr}; std::string d_contentTypeOut; DOH3ServerConfig* dsc{nullptr}; uint64_t streamID{0}; diff --git a/pdns/dnsdistdist/doq-common.cc b/pdns/dnsdistdist/doq-common.cc index bb79ddc21849..ea6145476131 100644 --- a/pdns/dnsdistdist/doq-common.cc +++ b/pdns/dnsdistdist/doq-common.cc @@ -325,6 +325,18 @@ bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, C return !buffer.empty(); } -}; +std::string getSNIFromQuicheConnection(const QuicheConnection& conn) +{ +#if defined(HAVE_QUICHE_CONN_SERVER_NAME) + const uint8_t* sniPtr = nullptr; + size_t sniPtrSize = 0; + quiche_conn_server_name(conn.get(), &sniPtr, &sniPtrSize); + if (sniPtrSize > 0) { + return std::string(reinterpret_cast(sniPtr), sniPtrSize); + } +#endif /* HAVE_QUICHE_CONN_SERVER_NAME */ + return {}; +} +} #endif diff --git a/pdns/dnsdistdist/doq-common.hh b/pdns/dnsdistdist/doq-common.hh index 9b04e4c83581..43c6bd9c55db 100644 --- a/pdns/dnsdistdist/doq-common.hh +++ b/pdns/dnsdistdist/doq-common.hh @@ -23,6 +23,7 @@ #include #include +#include #include "config.h" @@ -97,7 +98,7 @@ void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, co void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer); void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHTTP); bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr); - +std::string getSNIFromQuicheConnection(const QuicheConnection& conn); }; #endif diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 661fe5c2b5f9..73536cffec45 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -61,6 +61,14 @@ class Connection Connection& operator=(Connection&&) = default; ~Connection() = default; + std::shared_ptr getSNI() + { + if (!d_sni) { + d_sni = std::make_shared(getSNIFromQuicheConnection(d_conn)); + } + return d_sni; + } + ComboAddress d_peer; ComboAddress d_localAddr; QuicheConnection d_conn; @@ -68,6 +76,7 @@ class Connection std::unordered_map d_streamBuffers; std::unordered_map d_streamOutBuffers; + std::shared_ptr d_sni{nullptr}; }; static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description); @@ -472,6 +481,9 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) ids.origFlags = *flags; return true; }); + if (unit->sni) { + dnsQuestion.sni = *unit->sni; + } unit->ids.cs = &clientState; auto result = processQuery(dnsQuestion, downstream); @@ -541,7 +553,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) } } -static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID) +static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr& sni) { try { auto unit = std::make_unique(std::move(query)); @@ -551,6 +563,7 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const unit->ids.protocol = dnsdist::Protocol::DoQ; unit->serverConnID = serverConnID; unit->streamID = streamID; + unit->sni = sni; processDOQQuery(std::move(unit)); } @@ -649,7 +662,7 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState return; } DEBUGLOG("Dispatching query"); - doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID); + doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI()); conn.d_streamBuffers.erase(streamID); } diff --git a/pdns/dnsdistdist/doq.hh b/pdns/dnsdistdist/doq.hh index 258194177a6e..4b0d2dc550a6 100644 --- a/pdns/dnsdistdist/doq.hh +++ b/pdns/dnsdistdist/doq.hh @@ -84,6 +84,7 @@ struct DOQUnit PacketBuffer response; PacketBuffer serverConnID; std::shared_ptr downstream{nullptr}; + std::shared_ptr sni{nullptr}; DOQServerConfig* dsc{nullptr}; uint64_t streamID{0}; size_t proxyProtocolPayloadSize{0}; diff --git a/pdns/dnsdistdist/m4/pdns_with_quiche.m4 b/pdns/dnsdistdist/m4/pdns_with_quiche.m4 index 672fe0f79f22..784c9cb869ee 100644 --- a/pdns/dnsdistdist/m4/pdns_with_quiche.m4 +++ b/pdns/dnsdistdist/m4/pdns_with_quiche.m4 @@ -21,6 +21,16 @@ AC_DEFUN([PDNS_WITH_QUICHE], [ AC_DEFINE([HAVE_QUICHE], [1], [Define to 1 if you have quiche]) ], [ : ]) ]) + AS_IF([test "x$HAVE_QUICHE" = "x1"], [ + save_CFLAGS=$CFLAGS + save_LIBS=$LIBS + CFLAGS="$QUICHE_CFLAGS $CFLAGS" + LIBS="$QUICHE_LIBS $LIBS" + AC_CHECK_FUNCS([quiche_conn_server_name]) + CFLAGS=$save_CFLAGS + LIBS=$save_LIBS + + ]) ]) ]) AM_CONDITIONAL([HAVE_QUICHE], [test "x$QUICHE_LIBS" != "x"]) diff --git a/regression-tests.dnsdist/doh3client.py b/regression-tests.dnsdist/doh3client.py index 953f5befa0ab..e2e2bc107ad6 100644 --- a/regression-tests.dnsdist/doh3client.py +++ b/regression-tests.dnsdist/doh3client.py @@ -228,7 +228,7 @@ async def async_h3_query( def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False): - configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True) + configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True, server_name=server_hostname) if verify: configuration.load_verify_locations(verify) diff --git a/regression-tests.dnsdist/doqclient.py b/regression-tests.dnsdist/doqclient.py index 2f0272630f3b..7fa416d237c5 100644 --- a/regression-tests.dnsdist/doqclient.py +++ b/regression-tests.dnsdist/doqclient.py @@ -88,7 +88,7 @@ def __init__(self, error, message="Stream reset by peer"): super().__init__(message) def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None): - configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True) + configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname) if verify: configuration.load_verify_locations(verify) (result, serial) = asyncio.run( @@ -108,7 +108,7 @@ def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server return (result, serial) def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None): - configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True) + configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname) if verify: configuration.load_verify_locations(verify) (result, _) = asyncio.run( diff --git a/regression-tests.dnsdist/test_SNI.py b/regression-tests.dnsdist/test_SNI.py new file mode 100644 index 000000000000..ac2f221040d9 --- /dev/null +++ b/regression-tests.dnsdist/test_SNI.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +import base64 +import dns +import os +import unittest +import pycurl + +from dnsdisttests import DNSDistTest, pickAvailablePort + +class TestSNI(DNSDistTest): + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _tlsServerPort = pickAvailablePort() + _dohWithNGHTTP2ServerPort = pickAvailablePort() + _doqServerPort = pickAvailablePort() + _doh3ServerPort = pickAvailablePort() + _dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (_serverName, _dohWithNGHTTP2ServerPort)) + _dohBaseURL = ("https://%s:%d/" % (_serverName, _doh3ServerPort)) + + _config_template = """ + newServer{address="127.0.0.1:%d"} + + addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" }) + addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"}) + addDOQLocal("127.0.0.1:%d", "%s", "%s") + addDOH3Local("127.0.0.1:%d", "%s", "%s") + + function displaySNI(dq) + local sni = dq:getServerNameIndication() + if sni ~= '%s' then + return DNSAction.Spoof, '1.2.3.4' + end + return DNSAction.Allow + end + addAction(AllRule(), LuaAction(displaySNI)) + """ + _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName'] + + # enable these once Quiche > 0.22 is available, including https://github.com/cloudflare/quiche/pull/1895 + @unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quicheare disabled") + def testServerNameIndicationWithQuiche(self): + name = 'simple.sni.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + for method in ["sendDOQQueryWrapper", "sendDOH3QueryWrapper"]: + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response, timeout=1) + self.assertTrue(receivedQuery) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertTrue(receivedResponse) + if method == 'sendDOQQueryWrapper': + # dnspython sets the ID to 0 + receivedResponse.id = response.id + self.assertEqual(response, receivedResponse) + + def testServerNameIndication(self): + name = 'simple.sni.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]: + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response, timeout=1) + self.assertTrue(receivedQuery) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertTrue(receivedResponse) + self.assertEqual(response, receivedResponse)