Skip to content

Commit

Permalink
Spawn provider updates into separate tokio tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
jakewmeyer committed Oct 27, 2024
1 parent 794118a commit 17f0cf9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 72 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ categories = ["command-line-utilities", "network-programming"]
readme = "README.md"

[dependencies]
anyhow = "1.0.90"
anyhow = "1"
async-trait = "0.1.83"
clap = { version = "4.5.20", features = ["derive"] }
dyn-clone = "1.0.17"
hickory-resolver = "0.24.1"
humantime-serde = "1.1.1"
local-ip-address = "0.6.3"
Expand All @@ -26,11 +27,11 @@ reqwest = { version = "0.12.8", features = [
"brotli",
"rustls-tls",
], default-features = false }
serde = { version = "1.0.210", features = ["serde_derive"] }
serde_json = "1.0.129"
serde = { version = "1", features = ["serde_derive"] }
serde_json = "1"
smallvec = { version = "1.13.2", features = ["serde"] }
stun = "0.6.0"
tokio = { version = "1.40.0", features = [
tokio = { version = "1.41.0", features = [
"rt",
"rt-multi-thread",
"time",
Expand Down
36 changes: 28 additions & 8 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use core::fmt;
use dyn_clone::DynClone;
use hickory_resolver::config::{LookupIpStrategy, ResolverConfig, ResolverOpts};
use hickory_resolver::TokioAsyncResolver;
use local_ip_address::list_afinet_netifas;
Expand All @@ -10,7 +11,7 @@ use std::fmt::{Debug, Display, Formatter};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio::task::{JoinHandle, JoinSet};
use tokio::time;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info};
Expand Down Expand Up @@ -47,7 +48,7 @@ pub enum IpSource {
}

/// Update sent to each provider
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IpUpdate {
v4: Option<IpAddr>,
v6: Option<IpAddr>,
Expand All @@ -73,10 +74,12 @@ impl IpUpdate {
/// Provider trait for updating DNS records or DDNS services
#[async_trait]
#[typetag::serde(tag = "type")]
pub trait Provider: Debug + Send + Sync {
async fn update(&self, update: &IpUpdate, request: &HttpClient) -> Result<bool>;
pub trait Provider: Debug + DynClone + Send + Sync {
async fn update(&self, update: IpUpdate, request: HttpClient) -> Result<bool>;
}

dyn_clone::clone_trait_object!(Provider);

/// DDRS client
#[derive(Debug)]
pub struct Client {
Expand Down Expand Up @@ -222,11 +225,28 @@ impl Client {
continue;
}
info!("IP address update detected, updating providers...");
let mut failed = false;

let mut set = JoinSet::new();
for provider in &self.config.providers {
if let Err(error) = provider.update(&update, &self.request).await {
error!("Failed to update provider: {error}");
failed = true;
let provider = provider.clone();
let update = update.clone();
let request = self.request.clone();
set.spawn(async move {
provider.update(update, request).await
});
}
let mut failed = false;
while let Some(result) = set.join_next().await {
match result {
Ok(result) => {
if let Err(error) = result {
error!("Failed to update provider: {error}");
failed = true;
}
},
Err(error) => {
error!("Provider task failed to complete: {error}");
}
}
}
if !failed {
Expand Down
108 changes: 48 additions & 60 deletions src/providers/cloudflare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};
use smallvec::SmallVec;
use tracing::{debug, error};

use crate::client::{IpUpdate, IpVersion, Provider};

const CLOUDFLARE_API: &str = "https://api.cloudflare.com/client/v4";

/// Cloudflare DNS update provider
#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Cloudflare {
zone: String,
api_token: String,
domains: SmallVec<[Domains; 2]>,
}

#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
struct Domains {
name: String,
Expand All @@ -30,7 +29,9 @@ struct Domains {
/// Zone lookup response
#[derive(Debug, Deserialize)]
struct ZoneList {
result: Vec<ZoneResult>,
result: Option<Vec<ZoneResult>>,
#[serde(rename = "errors")]
_errors: Vec<Value>,
}

#[derive(Debug, Deserialize)]
Expand All @@ -41,7 +42,9 @@ struct ZoneResult {
/// Records lookup response
#[derive(Debug, Deserialize)]
struct RecordsList {
result: Vec<RecordResult>,
result: Option<Vec<RecordResult>>,
#[serde(rename = "errors")]
_errors: Vec<Value>,
}

#[derive(Debug, Deserialize)]
Expand All @@ -52,25 +55,24 @@ struct RecordResult {
#[derive(Debug, Deserialize)]
struct UpdatedResult {
#[serde(rename = "errors")]
_errors: Vec<Option<serde_json::Value>>,
_errors: Vec<Option<Value>>,
#[serde(rename = "messages")]
_messages: Vec<Option<serde_json::Value>>,
_messages: Vec<Option<Value>>,
success: bool,
}

#[derive(Debug, Deserialize)]
struct CreatedResult {
#[serde(rename = "errors")]
_errors: Vec<Option<serde_json::Value>>,
#[serde(rename = "messages")]
_messages: Vec<Option<serde_json::Value>>,
_errors: Vec<Option<Value>>,
_messages: Vec<Option<Value>>,
success: bool,
}

#[async_trait]
#[typetag::serde(name = "cloudflare")]
impl Provider for Cloudflare {
async fn update(&self, update: &IpUpdate, request: &Client) -> Result<bool> {
async fn update(&self, update: IpUpdate, request: Client) -> Result<bool> {
let zones = request
.get(format!("{CLOUDFLARE_API}/zones"))
.query(&[("name", &self.zone)])
Expand All @@ -79,8 +81,13 @@ impl Provider for Cloudflare {
.await?
.json::<ZoneList>()
.await?;
let zone_id = &zones.result.first().ok_or(anyhow!("No zone found"))?.id;
debug!("Found zone ID: {}", zone_id);
let zone_result = zones
.result
.ok_or(anyhow!("Failed to list Cloudflare zones"))?;
let zone_id = &zone_result
.first()
.ok_or(anyhow!("Failed to find a matching Cloudflare zone"))?
.id;
for domain in &self.domains {
for (version, address) in update.as_array() {
if let Some(address) = address {
Expand All @@ -97,11 +104,7 @@ impl Provider for Cloudflare {
.await?
.json::<RecordsList>()
.await?;
if let Some(record) = records.result.first() {
debug!(
"Updating {:?} record for {} to {}",
version, domain.name, address
);
if let Some(record) = records.result.and_then(|vec| vec.into_iter().next()) {
let updated = request
.put(format!(
"{CLOUDFLARE_API}/zones/{zone_id}/dns_records/{0}",
Expand All @@ -120,51 +123,36 @@ impl Provider for Cloudflare {
.await?
.json::<UpdatedResult>()
.await?;
if updated.success {
debug!("Record updated: {:#?}", updated);
} else {
error!(
"Failed to update domain ({}) record: {:#?}",
domain.name, updated
);
return Err(anyhow!(
"Failed to update domain ({}) record",
domain.name
));
}
} else {
debug!(
"Creating {:?} record for {} to {}",
version, domain.name, address
);
let created = request
.post(format!("{CLOUDFLARE_API}/zones/{zone_id}/dns_records"))
.json(&json!({
"type": record_type,
"name": domain.name,
"content": address,
"ttl": domain.ttl,
"proxied": domain.proxied,
"comment": domain.comment,
}))
.bearer_auth(&self.api_token)
.send()
.await?
.json::<CreatedResult>()
.await?;
if created.success {
debug!("Record created: {:#?}", created);
} else {
error!(
"Failed to create domain ({}) record: {:#?}",
domain.name, created
);
if !updated.success {
return Err(anyhow!(
"Failed to create domain ({}) record",
"Failed to update Cloudflare domain ({}) record",
domain.name
));
}
};
return Ok(true);
}
let created = request
.post(format!("{CLOUDFLARE_API}/zones/{zone_id}/dns_records"))
.json(&json!({
"type": record_type,
"name": domain.name,
"content": address,
"ttl": domain.ttl,
"proxied": domain.proxied,
"comment": domain.comment,
}))
.bearer_auth(&self.api_token)
.send()
.await?
.json::<CreatedResult>()
.await?;
if !created.success {
return Err(anyhow!(
"Failed to create Cloudflare domain ({}) record",
domain.name
));
}
return Ok(true);
}
}
}
Expand Down

0 comments on commit 17f0cf9

Please sign in to comment.