Skip to content

Commit

Permalink
g3-hickory-client: use g3-socket and support bind to interface
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-jq-b committed Jan 15, 2025
1 parent d669fbc commit 3f66f5e
Show file tree
Hide file tree
Showing 29 changed files with 290 additions and 136 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions g3bench/src/module/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use anyhow::anyhow;
use clap::{value_parser, Arg, ArgMatches, Command};
use tokio::net::TcpStream;

use g3_hickory_client::{TcpConnectInfo, UdpConnectInfo};
use g3_socket::BindAddr;
#[cfg(any(target_os = "linux", target_os = "android"))]
use g3_types::net::InterfaceName;
Expand Down Expand Up @@ -61,6 +62,24 @@ impl SocketArgs {
.map_err(|e| anyhow!("failed to setup local udp socket: {e}"))
}

pub(crate) fn hickory_udp_connect_info(&self, server: SocketAddr) -> UdpConnectInfo {
UdpConnectInfo {
server,
bind: self.bind,
buf_conf: Default::default(),
misc_opts: Default::default(),
}
}

pub(crate) fn hickory_tcp_connect_info(&self, server: SocketAddr) -> TcpConnectInfo {
TcpConnectInfo {
server,
bind: self.bind,
keepalive: Default::default(),
misc_opts: Default::default(),
}
}

pub(crate) fn parse_args(&mut self, args: &ArgMatches) -> anyhow::Result<()> {
if let Some(ip) = args.get_one::<IpAddr>(SOCKET_ARG_LOCAL_ADDRESS) {
self.bind = BindAddr::Ip(*ip);
Expand Down
49 changes: 22 additions & 27 deletions g3bench/src/target/dns/opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ use g3_types::net::{DnsEncryptionProtocol, RustlsClientConfigBuilder};

use super::{DnsRequest, DnsRequestPickState};
use crate::module::rustls::{AppendRustlsArgs, RustlsTlsClientArgs};
use crate::module::socket::{AppendSocketArgs, SocketArgs};

const DNS_ARG_TARGET: &str = "target";
const DNS_ARG_LOCAL_ADDRESS: &str = "local-address";
const DNS_ARG_TIMEOUT: &str = "timeout";
const DNS_ARG_CONNECT_TIMEOUT: &str = "connect-timeout";
const DNS_ARG_ENCRYPTION: &str = "encryption";
Expand Down Expand Up @@ -77,12 +77,14 @@ impl DnsRequestPickState for GlobalRequestPicker {

pub(super) struct BenchDnsArgs {
target: SocketAddr,
bind: Option<SocketAddr>,
encryption: Option<DnsEncryptionProtocol>,
use_tcp: bool,
pub(super) timeout: Duration,
pub(super) connect_timeout: Duration,

socket: SocketArgs,
tls: RustlsTlsClientArgs,

requests: Vec<DnsRequest>,
pub(super) dump_result: bool,
pub(super) iter_global: bool,
Expand All @@ -97,11 +99,11 @@ impl BenchDnsArgs {
};
BenchDnsArgs {
target,
bind: None,
encryption: None,
use_tcp: false,
timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(10),
socket: SocketArgs::default(),
tls,
requests: Vec::new(),
dump_result: false,
Expand Down Expand Up @@ -158,8 +160,8 @@ impl BenchDnsArgs {

async fn new_dns_over_udp_client(&self) -> anyhow::Result<Client> {
// FIXME should we use random port?
let client_connect =
g3_hickory_client::io::udp::connect(self.target, self.bind, self.timeout);
let connect_info = self.socket.hickory_udp_connect_info(self.target);
let client_connect = g3_hickory_client::io::udp::connect(connect_info, self.timeout);

let (client, bg) = Client::connect(Box::pin(client_connect))
.await
Expand All @@ -171,9 +173,9 @@ impl BenchDnsArgs {
async fn new_dns_over_tcp_client(&self) -> anyhow::Result<Client> {
let (message_sender, outbound_messages) = BufDnsStreamHandle::new(self.target);

let connect_info = self.socket.hickory_tcp_connect_info(self.target);
let tcp_connect = g3_hickory_client::io::tcp::connect(
self.target,
self.bind,
connect_info,
outbound_messages,
self.connect_timeout,
);
Expand All @@ -194,9 +196,9 @@ impl BenchDnsArgs {
.tls_name
.clone()
.unwrap_or_else(|| ServerName::IpAddress(self.target.ip().into()));
let connect_info = self.socket.hickory_tcp_connect_info(self.target);
let tls_connect = g3_hickory_client::io::tls::connect(
self.target,
self.bind,
connect_info,
tls_client,
tls_name,
outbound_messages,
Expand All @@ -218,9 +220,9 @@ impl BenchDnsArgs {
.clone()
.unwrap_or_else(|| ServerName::IpAddress(self.target.ip().into()));

let connect_info = self.socket.hickory_tcp_connect_info(self.target);
let client_connect = g3_hickory_client::io::h2::connect(
self.target,
self.bind,
connect_info,
tls_client,
tls_name,
self.connect_timeout,
Expand All @@ -243,9 +245,9 @@ impl BenchDnsArgs {
None => self.target.ip().to_string(),
};

let connect_info = self.socket.hickory_udp_connect_info(self.target);
let client_connect = g3_hickory_client::io::h3::connect(
self.target,
self.bind,
connect_info,
tls_client,
tls_name,
self.connect_timeout,
Expand All @@ -268,9 +270,9 @@ impl BenchDnsArgs {
None => self.target.ip().to_string(),
};

let connect_info = self.socket.hickory_udp_connect_info(self.target);
let client_connect = g3_hickory_client::io::quic::connect(
self.target,
self.bind,
connect_info,
tls_client,
tls_name,
self.connect_timeout,
Expand All @@ -292,14 +294,6 @@ pub(super) fn add_dns_args(app: Command) -> Command {
.required(true)
.num_args(1),
)
.arg(
Arg::new(DNS_ARG_LOCAL_ADDRESS)
.value_name("LOCAL SOCKET ADDRESS")
.short('B')
.long(DNS_ARG_LOCAL_ADDRESS)
.num_args(1)
.value_parser(value_parser!(IpAddr)),
)
.arg(
Arg::new(DNS_ARG_TIMEOUT)
.value_name("TIMEOUT DURATION")
Expand Down Expand Up @@ -363,6 +357,7 @@ pub(super) fn add_dns_args(app: Command) -> Command {
.action(ArgAction::SetTrue)
.long(DNS_ARG_ITER_GLOBAL),
)
.append_socket_args()
.append_rustls_args()
}

Expand All @@ -378,10 +373,6 @@ pub(super) fn parse_dns_args(args: &ArgMatches) -> anyhow::Result<BenchDnsArgs>
return Err(anyhow!("invalid dns server address {target}"));
};

if let Some(ip) = args.get_one::<SocketAddr>(DNS_ARG_LOCAL_ADDRESS) {
dns_args.bind = Some(*ip);
}

if let Some(timeout) = g3_clap::humanize::get_duration(args, DNS_ARG_TIMEOUT)? {
dns_args.timeout = timeout;
}
Expand Down Expand Up @@ -429,6 +420,10 @@ pub(super) fn parse_dns_args(args: &ArgMatches) -> anyhow::Result<BenchDnsArgs>
dns_args.iter_global = true;
}

dns_args
.socket
.parse_args(args)
.context("invalid socket config")?;
dns_args
.tls
.parse_tls_args(args)
Expand Down
2 changes: 1 addition & 1 deletion g3proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ lua53 = ["lua", "mlua/lua53"]
lua54 = ["lua", "mlua/lua54"]
python = ["pyo3"]
c-ares = ["g3-resolver/c-ares"]
hickory = ["g3-resolver/hickory"]
hickory = ["g3-resolver/hickory", "g3-slog-types/socket"]
quic = ["g3-daemon/quic", "g3-resolver/quic", "g3-yaml/quinn", "g3-types/quinn", "g3-dpi/quic", "dep:quinn"]
rustls-ring = ["g3-types/rustls-ring", "rustls/ring", "quinn?/rustls-ring"]
vendored-openssl = ["openssl/vendored", "openssl-probe"]
Expand Down
12 changes: 10 additions & 2 deletions g3proxy/src/config/resolver/hickory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use yaml_rust::{yaml, Yaml};

use g3_resolver::driver::hickory::HickoryDriverConfig;
use g3_resolver::{AnyResolveDriverConfig, ResolverRuntimeConfig};
use g3_socket::BindAddr;
use g3_types::metrics::NodeName;
use g3_yaml::YamlDocPosition;

Expand Down Expand Up @@ -59,8 +60,8 @@ impl HickoryResolverConfig {
}

#[inline]
pub(crate) fn get_bind_ip(&self) -> Option<IpAddr> {
self.driver.get_bind_ip()
pub(crate) fn get_bind_addr(&self) -> BindAddr {
self.driver.get_bind_addr()
}

#[inline]
Expand Down Expand Up @@ -142,6 +143,13 @@ impl HickoryResolverConfig {
self.driver.set_bind_ip(ip);
Ok(())
}
#[cfg(any(target_os = "linux", target_os = "android"))]
"bind_interface" => {
let interface = g3_yaml::value::as_interface_name(v)
.context(format!("invalid interface name value for key {k}"))?;
self.driver.set_bind_interface(interface);
Ok(())
}
"positive_min_ttl" => {
let ttl = g3_yaml::value::as_u32(v)?;
self.driver.set_positive_min_ttl(ttl);
Expand Down
4 changes: 2 additions & 2 deletions g3proxy/src/resolve/hickory/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use slog::{slog_info, Logger};
use tokio::time::Instant;

use g3_resolver::{ResolveError, ResolveQueryType, ResolvedRecordSource};
use g3_slog_types::{LtDuration, LtIpAddr};
use g3_slog_types::{LtBindAddr, LtDuration};
use g3_types::metrics::NodeName;

use crate::config::resolver::hickory::HickoryResolverConfig;
Expand Down Expand Up @@ -106,7 +106,7 @@ impl LoggedResolveJob for HickoryResolverJob {
.collect::<Vec<_>>()
.join(" ");
slog_info!(&self.logger, "{}", e; // TODO add encryption info
"bind_ip" => self.config.get_bind_ip().map(LtIpAddr),
"bind_addr" => LtBindAddr(self.config.get_bind_addr()),
"server" => servers,
"server_port" => self.config.get_server_port(),
"encryption" => self.config.get_encryption_summary(),
Expand Down
2 changes: 2 additions & 0 deletions lib/g3-hickory-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ tokio-rustls.workspace = true
quinn = { workspace = true, optional = true, features = ["rustls"] }
h3 = { workspace = true, optional = true }
h3-quinn = { workspace = true, optional = true }
g3-socket.workspace = true
g3-types.workspace = true

[features]
default = []
Expand Down
3 changes: 3 additions & 0 deletions lib/g3-hickory-client/src/connect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
*/

pub(crate) mod tcp;
pub use tcp::TcpConnectInfo;

pub(crate) mod udp;
pub use udp::UdpConnectInfo;

pub(crate) mod rustls;

Expand Down
10 changes: 5 additions & 5 deletions lib/g3-hickory-client/src/connect/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@
* limitations under the License.
*/

use std::net::SocketAddr;
use std::sync::Arc;

use hickory_proto::ProtoError;
use quinn::crypto::rustls::QuicClientConfig;
use quinn::{Connection, Endpoint, EndpointConfig, TokioRuntime};
use rustls::ClientConfig;

use super::UdpConnectInfo;

pub(crate) async fn quic_connect(
name_server: SocketAddr,
bind_addr: Option<SocketAddr>,
connect_info: UdpConnectInfo,
mut tls_config: ClientConfig,
tls_name: &str,
alpn_protocol: &'static [u8],
) -> Result<Connection, ProtoError> {
let sock = super::udp::udp_connect(name_server, bind_addr)?;
let sock = connect_info.udp_connect()?;

let endpoint_config = EndpointConfig::default(); // TODO set max payload size
let mut endpoint = Endpoint::new(endpoint_config, None, sock, Arc::new(TokioRuntime))?;
Expand All @@ -44,7 +44,7 @@ pub(crate) async fn quic_connect(
endpoint.set_default_client_config(client_config);

let connection = endpoint
.connect(name_server, tls_name)
.connect(connect_info.server, tls_name)
.map_err(|e| format!("quinn endpoint create error: {e}"))?
.await
.map_err(|e| format!("quinn endpoint connect error: {e}"))?;
Expand Down
8 changes: 4 additions & 4 deletions lib/g3-hickory-client/src/connect/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/

use std::net::SocketAddr;
use std::sync::Arc;

use hickory_proto::ProtoError;
Expand All @@ -24,14 +23,15 @@ use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;

use super::TcpConnectInfo;

pub(crate) async fn tls_connect(
name_server: SocketAddr,
bind_addr: Option<SocketAddr>,
connect_info: &TcpConnectInfo,
mut tls_config: ClientConfig,
tls_name: ServerName<'static>,
alpn_protocol: &'static [u8],
) -> Result<TlsStream<TcpStream>, ProtoError> {
let tcp_stream = super::tcp::tcp_connect(name_server, bind_addr).await?;
let tcp_stream = connect_info.tcp_connect().await?;

if tls_config.alpn_protocols.is_empty() {
tls_config.alpn_protocols = vec![alpn_protocol.to_vec()];
Expand Down
36 changes: 22 additions & 14 deletions lib/g3-hickory-client/src/connect/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,28 @@
use std::net::SocketAddr;

use hickory_proto::ProtoError;
use tokio::net::{TcpSocket, TcpStream};
use tokio::net::TcpStream;

pub(crate) async fn tcp_connect(
name_server: SocketAddr,
bind_addr: Option<SocketAddr>,
) -> Result<TcpStream, ProtoError> {
let socket = match name_server {
SocketAddr::V4(_) => TcpSocket::new_v4(),
SocketAddr::V6(_) => TcpSocket::new_v6(),
}?;
if let Some(addr) = bind_addr {
socket.bind(addr)?;
}
use g3_socket::BindAddr;
use g3_types::net::{TcpKeepAliveConfig, TcpMiscSockOpts};

pub struct TcpConnectInfo {
pub server: SocketAddr,
pub bind: BindAddr,
pub keepalive: TcpKeepAliveConfig,
pub misc_opts: TcpMiscSockOpts,
}

let tcp_stream = socket.connect(name_server).await?;
Ok(tcp_stream)
impl TcpConnectInfo {
pub(crate) async fn tcp_connect(&self) -> Result<TcpStream, ProtoError> {
let socket = g3_socket::tcp::new_socket_to(
self.server.ip(),
&self.bind,
&self.keepalive,
&self.misc_opts,
true,
)?;
let tcp_stream = socket.connect(self.server).await?;
Ok(tcp_stream)
}
}
Loading

0 comments on commit 3f66f5e

Please sign in to comment.