Skip to content

Commit

Permalink
Refactor Cloudflare provider into separate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakewmeyer committed Dec 16, 2024
1 parent 588ea4e commit 29f6337
Showing 1 changed file with 157 additions and 98 deletions.
255 changes: 157 additions & 98 deletions src/providers/cloudflare.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::net::IpAddr;

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use serde_json::json;
use smallvec::SmallVec;

use crate::client::{IpUpdate, IpVersion, Provider};
Expand All @@ -12,7 +14,7 @@ use crate::client::{IpUpdate, IpVersion, Provider};
pub struct Cloudflare {
zone: String,
api_token: String,
domains: SmallVec<[Domains; 2]>,
domains: SmallVec<[Domain; 2]>,
#[serde(default = "default_api_url")]
api_url: String,
}
Expand All @@ -23,32 +25,38 @@ fn default_api_url() -> String {

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
struct Domains {
struct Domain {
name: String,
#[serde(default = "default_ttl")]
ttl: u32,
#[serde(default)]
proxied: bool,
comment: Option<String>,
#[serde(default = "default_comment")]
comment: String,
}

// TTL of 1 is Cloudflare's auto setting
fn default_ttl() -> u32 {
1
}

fn default_comment() -> String {
String::from("Created by DDRS")
}

/// Zone lookup response
#[derive(Debug, Deserialize)]
struct ZoneList {
result: Option<Vec<ZoneResult>>,
#[serde(rename = "errors")]
_errors: Vec<Value>,
}

#[derive(Debug, Deserialize)]
struct ZoneResult {
id: String,
}

/// Records lookup response
#[derive(Debug, Deserialize)]
struct RecordsList {
result: Option<Vec<RecordResult>>,
#[serde(rename = "errors")]
_errors: Vec<Value>,
}

#[derive(Debug, Deserialize)]
Expand All @@ -58,26 +66,16 @@ struct RecordResult {

#[derive(Debug, Deserialize)]
struct UpdatedResult {
#[serde(rename = "errors")]
_errors: Vec<Value>,
#[serde(rename = "messages")]
_messages: Vec<Value>,
success: bool,
}

#[derive(Debug, Deserialize)]
struct CreatedResult {
#[serde(rename = "errors")]
_errors: Vec<Value>,
#[serde(rename = "messages")]
_messages: Vec<Value>,
success: bool,
}

#[async_trait]
#[typetag::serde(name = "cloudflare")]
impl Provider for Cloudflare {
async fn update(&self, update: IpUpdate, request: Client) -> Result<bool> {
impl Cloudflare {
async fn fetch_zone_id(&self, request: &Client) -> Result<String> {
let zones = request
.get(format!("{}/zones", self.api_url))
.query(&[("name", &self.zone)])
Expand All @@ -89,75 +87,136 @@ impl Provider for Cloudflare {
let zone_result = zones.result.ok_or(anyhow!(
"Failed to list Cloudflare zones, is your token valid?"
))?;
let zone_id = &zone_result
Ok(zone_result
.first()
.ok_or(anyhow!("Failed to find a matching Cloudflare zone"))?
.id;
.id
.clone())
}

async fn fetch_dns_records(
&self,
request: &Client,
zone_id: &str,
record_type: &str,
domain: &Domain,
) -> Result<Vec<RecordResult>> {
let records = request
.get(format!("{}/zones/{}/dns_records", self.api_url, zone_id))
.query(&[("name", &domain.name)])
.query(&[("type", record_type)])
.bearer_auth(&self.api_token)
.send()
.await?
.json::<RecordsList>()
.await?;
records.result.ok_or(anyhow!(
"Failed to list Cloudflare DNS records for {}",
domain.name
))
}

async fn update_dns_record(
&self,
request: &Client,
zone_id: &str,
record_id: &str,
record_type: &str,
domain: &Domain,
address: &IpAddr,
) -> Result<()> {
let updated = request
.put(format!(
"{}/zones/{}/dns_records/{}",
self.api_url, zone_id, record_id,
))
.json(&json!({
"type": record_type,
"name": domain,
"content": address,
"ttl": domain.ttl,
"proxied": domain.proxied,
"comment": domain.comment,
}))
.bearer_auth(&self.api_token)
.send()
.await?
.json::<UpdatedResult>()
.await?;
if !updated.success {
return Err(anyhow!(
"Failed to update Cloudflare domain ({}) record",
domain.name
));
}
Ok(())
}

async fn create_dns_record(
&self,
request: &Client,
zone_id: &str,
record_type: &str,
domain: &Domain,
address: &IpAddr,
) -> Result<()> {
let created = request
.post(format!("{}/zones/{}/dns_records", self.api_url, zone_id))
.json(&json!({
"type": record_type,
"name": domain,
"content": address,
"ttl": domain.ttl,
"proxied": domain.proxied,
"comment": domain.comment,
}))
.bearer_auth(&self.api_token)
.send()
.await?
.json::<CreatedResult>()
.await?;
if !created.success {
return Err(anyhow!(
"Failed to create Cloudflare domain ({}) record",
domain.name
));
}
Ok(())
}
}

#[async_trait]
#[typetag::serde(name = "cloudflare")]
impl Provider for Cloudflare {
async fn update(&self, update: IpUpdate, request: Client) -> Result<bool> {
let zone_id = self.fetch_zone_id(&request).await?;
for domain in &self.domains {
for (version, address) in update.as_array() {
if address.is_none() {
continue;
}
let record_type = match version {
IpVersion::V4 => "A",
IpVersion::V6 => "AAAA",
};
let records = request
.get(format!("{}/zones/{zone_id}/dns_records", self.api_url))
.query(&[("name", &domain.name)])
.query(&[("type", record_type)])
.bearer_auth(&self.api_token)
.send()
.await?
.json::<RecordsList>()
.await?;
if let Some(record) = records.result.and_then(|vec| vec.into_iter().next()) {
let updated = request
.put(format!(
"{}/zones/{zone_id}/dns_records/{}",
self.api_url, record.id,
))
.json(&json!({
"type": record_type,
"name": domain.name,
"content": address,
"ttl": domain.ttl,
"proxied": domain.proxied,
"comment": domain.comment,
}))
.bearer_auth(&self.api_token)
.send()
if let Some(addr) = address {
let record_type = match version {
IpVersion::V4 => "A",
IpVersion::V6 => "AAAA",
};
if let Some(record) = self
.fetch_dns_records(&request, &zone_id, record_type, domain)
.await?
.json::<UpdatedResult>()
.first()
{
println!("Updating record: {}", record.id);
self.update_dns_record(
&request,
&zone_id,
&record.id,
record_type,
domain,
&addr,
)
.await?;
if !updated.success {
return Err(anyhow!(
"Failed to update Cloudflare domain ({}) record",
domain.name
));
} else {
println!("Creating record");
self.create_dns_record(&request, &zone_id, record_type, domain, &addr)
.await?;
}
continue;
}
let created = request
.post(format!("{}/zones/{zone_id}/dns_records", self.api_url))
.json(&json!({
"type": record_type,
"name": domain.name,
"content": address,
"ttl": domain.ttl,
"proxied": domain.proxied,
"comment": domain.comment,
}))
.bearer_auth(&self.api_token)
.send()
.await?
.json::<CreatedResult>()
.await?;
if !created.success {
return Err(anyhow!(
"Failed to create Cloudflare domain ({}) record",
domain.name
));
}
}
}
Expand Down Expand Up @@ -197,11 +256,11 @@ mod tests {
let provider = Cloudflare {
zone: "example.com".to_string(),
api_token: "bad_token".to_string(),
domains: smallvec![Domains {
domains: smallvec![Domain {
name: "example.com".to_string(),
ttl: 1,
proxied: true,
comment: None,
comment: "Created by DDRS".to_string(),
}],
api_url: mock.uri(),
};
Expand Down Expand Up @@ -241,11 +300,11 @@ mod tests {
let provider = Cloudflare {
zone: "example.com".to_string(),
api_token: "token".to_string(),
domains: smallvec![Domains {
domains: smallvec![Domain {
name: "example.com".to_string(),
ttl: 1,
proxied: true,
comment: None,
comment: "Created by DDRS".to_string(),
}],
api_url: mock.uri(),
};
Expand Down Expand Up @@ -316,11 +375,11 @@ mod tests {
let provider = Cloudflare {
zone: "example.com".to_string(),
api_token: "token".to_string(),
domains: smallvec![Domains {
domains: smallvec![Domain {
name: "example.com".to_string(),
ttl: 1,
proxied: true,
comment: Some("Created by DDRS".to_string()),
comment: "Created by DDRS".to_string(),
}],
api_url: mock.uri(),
};
Expand Down Expand Up @@ -426,11 +485,11 @@ mod tests {
let provider = Cloudflare {
zone: "example.com".to_string(),
api_token: "token".to_string(),
domains: smallvec![Domains {
domains: smallvec![Domain {
name: "example.com".to_string(),
ttl: 1,
proxied: true,
comment: Some("Created by DDRS".to_string()),
comment: "Created by DDRS".to_string(),
}],
api_url: mock.uri(),
};
Expand Down Expand Up @@ -503,11 +562,11 @@ mod tests {
let provider = Cloudflare {
zone: "example.com".to_string(),
api_token: "token".to_string(),
domains: smallvec![Domains {
domains: smallvec![Domain {
name: "example.com".to_string(),
ttl: 1,
proxied: true,
comment: Some("Created by DDRS".to_string()),
comment: "Created by DDRS".to_string(),
}],
api_url: mock.uri(),
};
Expand Down Expand Up @@ -582,11 +641,11 @@ mod tests {
let provider = Cloudflare {
zone: "example.com".to_string(),
api_token: "token".to_string(),
domains: smallvec![Domains {
domains: smallvec![Domain {
name: "example.com".to_string(),
ttl: 1,
proxied: true,
comment: Some("Created by DDRS".to_string()),
comment: "Created by DDRS".to_string(),
}],
api_url: mock.uri(),
};
Expand Down Expand Up @@ -672,11 +731,11 @@ mod tests {
let provider = Cloudflare {
zone: "example.com".to_string(),
api_token: "token".to_string(),
domains: smallvec![Domains {
domains: smallvec![Domain {
name: "example.com".to_string(),
ttl: 1,
proxied: true,
comment: Some("Created by DDRS".to_string()),
comment: "Created by DDRS".to_string(),
}],
api_url: mock.uri(),
};
Expand Down

0 comments on commit 29f6337

Please sign in to comment.