diff --git a/libs/gl-client/src/lsps/lsps0/common_schemas.rs b/libs/gl-client/src/lsps/lsps0/common_schemas.rs index 7eb5efde1..e5d4bfe54 100644 --- a/libs/gl-client/src/lsps/lsps0/common_schemas.rs +++ b/libs/gl-client/src/lsps/lsps0/common_schemas.rs @@ -1,6 +1,12 @@ +use core::str::FromStr; + +use std::fmt::{Display, Formatter}; + +use anyhow::{anyhow, Context}; + use serde::de::Error as SeError; use serde::ser::Error as DeError; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize, Deserializer, Serializer}; use time::format_description::FormatItem; use time::macros::format_description; @@ -144,91 +150,60 @@ impl<'de> Deserialize<'de> for MsatAmount { } } -#[derive(Debug, PartialEq)] -pub struct ShortChannelId { - scid: u64, -} - -// constants for parsing of short_channel_id in bits -const SCID_BLOCK_HEIGHT_BITSHIFT: u64 = 24 + 16; -const SCID_TXID_BITSHIFT: u64 = 16; - -impl ShortChannelId { - pub fn new_from_u64(scid: u64) -> Self { - Self { scid } - } - - // The scid or short channel id consits out of 8 bytes - // - // It is - // - 3 bytes for block_height - // - 3 bytes for transaction index in the block - // - 2 bytes for output_index paying to that channel - // - // The string representation 812x10x2 refers to the - // channel that was funded by the 2nd output-index - // of the 10th transaction in block 812. - pub fn new_from_str(scid: &str) -> Option { - // TODO: Come up with a better error type - let splits: Vec = scid - .split('x') - .map(|x| x.parse::()) - .collect::, std::num::ParseIntError>>() - .ok()?; - - if splits.len() != 3 { - return None; - }; - - const MAX_VALUE_3_BYTES: u64 = 0xFFFFFF; - const MAX_VALUE_2_BYTES: u64 = 0xFFFF; - - let block_height = splits[0]; - let txid = splits[1]; - let v_out = splits[2]; - - if block_height > MAX_VALUE_3_BYTES || txid > MAX_VALUE_3_BYTES || v_out > MAX_VALUE_2_BYTES - { - return None; - } - - let result: u64 = - (block_height << SCID_BLOCK_HEIGHT_BITSHIFT) | (txid << SCID_TXID_BITSHIFT) | (v_out); - - Some(Self::new_from_u64(result)) - } - - pub fn value_as_u64(&self) -> u64 { - self.scid - } - - pub fn value_as_string(&self) -> String { - let block_height = self.scid >> SCID_BLOCK_HEIGHT_BITSHIFT & 0xFFFFFF; - let txid = self.scid >> SCID_TXID_BITSHIFT & 0xFFFFFF; - let v_out = self.scid & 0xFFFF; - format!("{block_height}x{txid}x{v_out}") - } -} +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ShortChannelId(u64); impl Serialize for ShortChannelId { fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer, + S: Serializer, { - let str_repr = self.value_as_string(); - serializer.serialize_str(&str_repr) + serializer.serialize_str(&self.to_string()) } } impl<'de> Deserialize<'de> for ShortChannelId { fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de>, + D: Deserializer<'de>, { - let scid_str: String = String::deserialize(deserializer)?; - ShortChannelId::new_from_str(&scid_str) - .ok_or_else(|| D::Error::custom(format!("Invalid scid: {}", scid_str))) + use serde::de::Error; + let s: String = Deserialize::deserialize(deserializer)?; + Ok(Self::from_str(&s).map_err(|e| Error::custom(e.to_string()))?) + } +} + +impl FromStr for ShortChannelId { + type Err = anyhow::Error; + fn from_str(s: &str) -> Result { + let parts: Result, _> = s.split('x').map(|p| p.parse()).collect(); + let parts = parts.with_context(|| format!("Malformed short_channel_id: {}", s))?; + if parts.len() != 3 { + return Err(anyhow!( + "Malformed short_channel_id: element count mismatch" + )); + } + + Ok(ShortChannelId( + (parts[0] << 40) | (parts[1] << 16) | (parts[2] << 0), + )) + } +} +impl Display for ShortChannelId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}x{}x{}", self.block(), self.txindex(), self.outnum()) + } +} +impl ShortChannelId { + pub fn block(&self) -> u32 { + (self.0 >> 40) as u32 & 0xFFFFFF + } + pub fn txindex(&self) -> u32 { + (self.0 >> 16) as u32 & 0xFFFFFF + } + pub fn outnum(&self) -> u16 { + self.0 as u16 & 0xFFFF } } @@ -301,37 +276,30 @@ mod test { // The string representation are the same numbers separated by the letter x // Test the largest possible value - let scid_u64 = 0xFFFFFF_FFFFFF_FFFF; let scid_str = "16777215x16777215x65535"; - let scid = ShortChannelId::new_from_str(scid_str).expect("The scid is parseable"); - assert_eq!(scid.value_as_string(), scid_str); - assert_eq!(scid.value_as_u64(), scid_u64); + let scid = ShortChannelId::from_str(scid_str).expect("The scid is parseable"); + assert_eq!(scid.to_string(), scid_str); // Test the smallest possible value - let scid_u64 = 0x000000_000000_0000; let scid_str = "0x0x0"; - let scid = ShortChannelId::new_from_str(scid_str).expect("The scid is parseable"); - assert_eq!(scid.value_as_string(), scid_str); - assert_eq!(scid.value_as_u64(), scid_u64); + let scid = ShortChannelId::from_str(scid_str).expect("The scid is parseable"); + assert_eq!(scid.to_string(), scid_str); - // A sorted value to check the ordering of the fields - let scid_u64 = 0x000001_000002_0003; let scid_str = "1x2x3"; - let scid = ShortChannelId::new_from_str(scid_str).expect("The scid is parseable"); - assert_eq!(scid.value_as_string(), scid_str); - assert_eq!(scid.value_as_u64(), scid_u64); + let scid = ShortChannelId::from_str(scid_str).expect("The scid is parseable"); + assert_eq!(scid.to_string(), scid_str); // A couple of unparseable scids - assert!(ShortChannelId::new_from_str("xx").is_none()); - assert!(ShortChannelId::new_from_str("0x0").is_none()); - assert!(ShortChannelId::new_from_str("-2x-12x14").is_none()); + assert!(ShortChannelId::from_str("xx").is_err()); + assert!(ShortChannelId::from_str("0x0").is_err()); + assert!(ShortChannelId::from_str("-2x-12x14").is_err()); } #[test] fn short_channel_id_is_serialized_as_str() { - let scid: ShortChannelId = ShortChannelId::new_from_str("10x5x8").unwrap(); + let scid: ShortChannelId = ShortChannelId::from_str("10x5x8").unwrap(); let scid_json_obj = serde_json::to_string(&scid).expect("Can be serialized"); assert_eq!("\"10x5x8\"", scid_json_obj); } @@ -342,6 +310,6 @@ mod test { let scid = serde_json::from_str::(scid_json).expect("scid can be parsed"); - assert_eq!(scid, ShortChannelId::new_from_str("11x12x13").unwrap()); + assert_eq!(scid, ShortChannelId::from_str("11x12x13").unwrap()); } }