Skip to content

Commit

Permalink
Use ShortChannelId from cnl-rpc
Browse files Browse the repository at this point in the history
The cln-rpc crate has a ShortChannelId.
We've decided to use that version here. If we ever separate
out the primitives from cln_rpc we might be able to use them
without breaking changes.
  • Loading branch information
ErikDeSmedt committed Oct 27, 2023
1 parent e6ea869 commit 6f3d6b8
Showing 1 changed file with 59 additions and 91 deletions.
150 changes: 59 additions & 91 deletions libs/gl-client/src/lsps/lsps0/common_schemas.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use core::str::FromStr;

use std::fmt::{Display, Formatter};

use anyhow::{anyhow, Context};

use serde::de::Error as SeError;
use serde::ser::Error as DeError;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize, Deserializer, Serializer};

use time::format_description::FormatItem;
use time::macros::format_description;
Expand Down Expand Up @@ -144,91 +150,60 @@ impl<'de> Deserialize<'de> for MsatAmount {
}
}

#[derive(Debug, PartialEq)]
pub struct ShortChannelId {
scid: u64,
}

// constants for parsing of short_channel_id in bits
const SCID_BLOCK_HEIGHT_BITSHIFT: u64 = 24 + 16;
const SCID_TXID_BITSHIFT: u64 = 16;

impl ShortChannelId {
pub fn new_from_u64(scid: u64) -> Self {
Self { scid }
}

// The scid or short channel id consits out of 8 bytes
//
// It is
// - 3 bytes for block_height
// - 3 bytes for transaction index in the block
// - 2 bytes for output_index paying to that channel
//
// The string representation 812x10x2 refers to the
// channel that was funded by the 2nd output-index
// of the 10th transaction in block 812.
pub fn new_from_str(scid: &str) -> Option<Self> {
// TODO: Come up with a better error type
let splits: Vec<u64> = scid
.split('x')
.map(|x| x.parse::<u64>())
.collect::<Result<Vec<u64>, std::num::ParseIntError>>()
.ok()?;

if splits.len() != 3 {
return None;
};

const MAX_VALUE_3_BYTES: u64 = 0xFFFFFF;
const MAX_VALUE_2_BYTES: u64 = 0xFFFF;

let block_height = splits[0];
let txid = splits[1];
let v_out = splits[2];

if block_height > MAX_VALUE_3_BYTES || txid > MAX_VALUE_3_BYTES || v_out > MAX_VALUE_2_BYTES
{
return None;
}

let result: u64 =
(block_height << SCID_BLOCK_HEIGHT_BITSHIFT) | (txid << SCID_TXID_BITSHIFT) | (v_out);

Some(Self::new_from_u64(result))
}

pub fn value_as_u64(&self) -> u64 {
self.scid
}

pub fn value_as_string(&self) -> String {
let block_height = self.scid >> SCID_BLOCK_HEIGHT_BITSHIFT & 0xFFFFFF;
let txid = self.scid >> SCID_TXID_BITSHIFT & 0xFFFFFF;
let v_out = self.scid & 0xFFFF;

format!("{block_height}x{txid}x{v_out}")
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ShortChannelId(u64);

impl Serialize for ShortChannelId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
S: Serializer,
{
let str_repr = self.value_as_string();
serializer.serialize_str(&str_repr)
serializer.serialize_str(&self.to_string())
}
}

impl<'de> Deserialize<'de> for ShortChannelId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
D: Deserializer<'de>,
{
let scid_str: String = String::deserialize(deserializer)?;
ShortChannelId::new_from_str(&scid_str)
.ok_or_else(|| D::Error::custom(format!("Invalid scid: {}", scid_str)))
use serde::de::Error;
let s: String = Deserialize::deserialize(deserializer)?;
Ok(Self::from_str(&s).map_err(|e| Error::custom(e.to_string()))?)
}
}

impl FromStr for ShortChannelId {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let parts: Result<Vec<u64>, _> = s.split('x').map(|p| p.parse()).collect();
let parts = parts.with_context(|| format!("Malformed short_channel_id: {}", s))?;
if parts.len() != 3 {
return Err(anyhow!(
"Malformed short_channel_id: element count mismatch"
));
}

Ok(ShortChannelId(
(parts[0] << 40) | (parts[1] << 16) | (parts[2] << 0),
))
}
}
impl Display for ShortChannelId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}x{}x{}", self.block(), self.txindex(), self.outnum())
}
}
impl ShortChannelId {
pub fn block(&self) -> u32 {
(self.0 >> 40) as u32 & 0xFFFFFF
}
pub fn txindex(&self) -> u32 {
(self.0 >> 16) as u32 & 0xFFFFFF
}
pub fn outnum(&self) -> u16 {
self.0 as u16 & 0xFFFF
}
}

Expand Down Expand Up @@ -301,37 +276,30 @@ mod test {
// The string representation are the same numbers separated by the letter x

// Test the largest possible value
let scid_u64 = 0xFFFFFF_FFFFFF_FFFF;
let scid_str = "16777215x16777215x65535";

let scid = ShortChannelId::new_from_str(scid_str).expect("The scid is parseable");
assert_eq!(scid.value_as_string(), scid_str);
assert_eq!(scid.value_as_u64(), scid_u64);
let scid = ShortChannelId::from_str(scid_str).expect("The scid is parseable");
assert_eq!(scid.to_string(), scid_str);

// Test the smallest possible value
let scid_u64 = 0x000000_000000_0000;
let scid_str = "0x0x0";

let scid = ShortChannelId::new_from_str(scid_str).expect("The scid is parseable");
assert_eq!(scid.value_as_string(), scid_str);
assert_eq!(scid.value_as_u64(), scid_u64);
let scid = ShortChannelId::from_str(scid_str).expect("The scid is parseable");
assert_eq!(scid.to_string(), scid_str);

// A sorted value to check the ordering of the fields
let scid_u64 = 0x000001_000002_0003;
let scid_str = "1x2x3";

let scid = ShortChannelId::new_from_str(scid_str).expect("The scid is parseable");
assert_eq!(scid.value_as_string(), scid_str);
assert_eq!(scid.value_as_u64(), scid_u64);
let scid = ShortChannelId::from_str(scid_str).expect("The scid is parseable");
assert_eq!(scid.to_string(), scid_str);
// A couple of unparseable scids
assert!(ShortChannelId::new_from_str("xx").is_none());
assert!(ShortChannelId::new_from_str("0x0").is_none());
assert!(ShortChannelId::new_from_str("-2x-12x14").is_none());
assert!(ShortChannelId::from_str("xx").is_err());
assert!(ShortChannelId::from_str("0x0").is_err());
assert!(ShortChannelId::from_str("-2x-12x14").is_err());
}

#[test]
fn short_channel_id_is_serialized_as_str() {
let scid: ShortChannelId = ShortChannelId::new_from_str("10x5x8").unwrap();
let scid: ShortChannelId = ShortChannelId::from_str("10x5x8").unwrap();
let scid_json_obj = serde_json::to_string(&scid).expect("Can be serialized");
assert_eq!("\"10x5x8\"", scid_json_obj);
}
Expand All @@ -342,6 +310,6 @@ mod test {

let scid = serde_json::from_str::<ShortChannelId>(scid_json).expect("scid can be parsed");

assert_eq!(scid, ShortChannelId::new_from_str("11x12x13").unwrap());
assert_eq!(scid, ShortChannelId::from_str("11x12x13").unwrap());
}
}

0 comments on commit 6f3d6b8

Please sign in to comment.