diff --git a/libs/gl-client-py/tests/plugins/trmp_htlc_hook.py b/libs/gl-client-py/tests/plugins/trmp_htlc_hook.py index 17d6addc9..066cda377 100755 --- a/libs/gl-client-py/tests/plugins/trmp_htlc_hook.py +++ b/libs/gl-client-py/tests/plugins/trmp_htlc_hook.py @@ -1,20 +1,90 @@ #!/usr/bin/env python3 from pyln.client import Plugin +from binascii import unhexlify +from pyln.proto.primitives import varint_decode +from io import BytesIO + +INVOICE_TYPE = 33001 +AMOUNT_TYPE = 33003 plugin = Plugin( dynamic=False, init_features=1 << 427, ) +plugin.check_invoice = None +plugin.check_amount = None +plugin.payment_key = None + + @plugin.hook("htlc_accepted") def on_htlc_accepted(htlc, onion, plugin, **kwargs): - plugin.log(f"Got onion {onion}") + plugin.log(f"got onion {onion}") + + payment_metadata = unhexlify(onion["payment_metadata"]) + payment_metadata = BytesIO(payment_metadata) + + invoice_type = varint_decode(payment_metadata) + invoice_length = varint_decode(payment_metadata) + invoice_value = payment_metadata.read(invoice_length) + assert invoice_type == INVOICE_TYPE + + if plugin.check_invoice is not None: + plugin.log( + f"check invoice {invoice_value.decode('utf-8')} matches {plugin.check_invoice}" + ) + assert invoice_value.decode("utf-8") == plugin.check_invoice + + amount_msat_type = varint_decode(payment_metadata) + amount_msat_length = varint_decode(payment_metadata) + amount_msat_value = payment_metadata.read(amount_msat_length) + assert amount_msat_type == AMOUNT_TYPE + + if plugin.check_amount is not None: + plugin.log( + f"check amount_msat {int.from_bytes(amount_msat_value, 'big')} matches {plugin.check_amount}" + ) + assert int.from_bytes(amount_msat_value, "big") == plugin.check_amount + + plugin.log(f"got invoice={invoice_value.decode('utf-8')}") + plugin.log(f"got amount_msat={int.from_bytes(amount_msat_value, 'big')}") + + if plugin.payment_key is not None: + plugin.log(f"resolving with payment_key {plugin.payment_key}") + return { + "result": "resolve", + "payment_key": plugin.payment_key, + } + + replacement = onion["payload"][6:102] + plugin.log(f"replace onion payload with {replacement}") + return {"result": "continue", "payload": replacement} + + +@plugin.method("setpaymentkey") +def setpaymentkey(plugin, payment_key): + """Sets the payment_key to resolve an htlc""" + plugin.payment_key = payment_key + + +@plugin.method("setcheckinvoice") +def setcheckinvoice(plugin, invoice): + """Sets an invoice check""" + plugin.check_invoice = invoice + + +@plugin.method("setcheckamount") +def setcheckamount(plugin, amount_msat): + """Sets an amount check""" + plugin.check_amount = amount_msat + - # Strip off custom payload as we are the last hop. - new_payload = onion["payload"][6:102] - plugin.log(f"Replace onion payload with {new_payload}") - return {"result": "continue", "payload": new_payload} +@plugin.method("unsetchecks") +def unsetchecks(plugin): + """Unsets all checks""" + plugin.check_invoice = None + plugin.check_amount = None plugin.run() diff --git a/libs/gl-client-py/tests/test_plugin.py b/libs/gl-client-py/tests/test_plugin.py index f56d628b5..e68c64ade 100644 --- a/libs/gl-client-py/tests/test_plugin.py +++ b/libs/gl-client-py/tests/test_plugin.py @@ -75,14 +75,24 @@ def test_trampoline_pay(bitcoind, clients, node_factory): # create invoice and pay via trampoline. Trampoline is actually the # same node as the destination but we don't care as we just want to # test the business logic. + invoice_preimage = ( + "17b08f669513b7379728fc1abcea5eaf3448bc1eba55a68ca2cd1843409cdc04" + ) inv = l2.rpc.invoice( amount_msat=50000000, label="trampoline-pay-test", description="trampoline-pay-test", + preimage=invoice_preimage, ) + l2.rpc.setpaymentkey(invoice_preimage) + l2.rpc.setcheckinvoice(inv["bolt11"]) + l2.rpc.setcheckamount(50000000) + res = n1.trampoline_pay(inv["bolt11"], bytes.fromhex(l2.info["id"])) assert res + l2.rpc.unsetchecks() + # settle channel htlcs bitcoind.generate_block(10) wait_for( diff --git a/libs/gl-plugin/src/tlv.rs b/libs/gl-plugin/src/tlv.rs index 3324d7c83..ce4b0ef3c 100644 --- a/libs/gl-plugin/src/tlv.rs +++ b/libs/gl-plugin/src/tlv.rs @@ -190,3 +190,20 @@ impl<'de> Deserialize<'de> for SerializedTlvStream { Self::from_bytes(b.into_inner()).map_err(|e| serde::de::Error::custom(e.to_string())) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tlv_stream() { + let raw_hex = "fd80e9fd01046c6e62633130306e31706e6670757a677070356a73376e653571727465326d707874666d7a703638667838676a7376776b3034366c76357a76707a7832766d687478683763777364717163717a7a737871797a3576717370357a6d6b78726539686d6864617378336b75357070336a38366472337778656b78336437383363706c6161363068783870357564733971787071797367717132646c68656177796c677534346567393363766e78666e64747a646a6e647465666b726861727763746b3368783766656e67346179746e6a3277686d74716665636a7930776777396c6665727072386b686d64667771736e386d6d7a3776643565776a34756370656439787076fd80eb022742"; + let raw = hex::decode(&raw_hex).unwrap(); + let tlv_stream = SerializedTlvStream::from_bytes(raw).unwrap(); + let invoice = tlv_stream.get(33001); + let amount_msat = tlv_stream.get(33003); + assert!(invoice.is_some()); + assert!(amount_msat + .is_some_and(|v| u16::from_be_bytes(v.value[..].try_into().unwrap()) == 10050)); + } +} diff --git a/libs/gl-plugin/src/tramp.rs b/libs/gl-plugin/src/tramp.rs index b1a4b6e4b..f648d0d13 100644 --- a/libs/gl-plugin/src/tramp.rs +++ b/libs/gl-plugin/src/tramp.rs @@ -114,6 +114,7 @@ pub async fn trampolinepay( // We need to add some sats to the htlcs to allow the trampoline node // to pay fees on routing. + let tlv_amount_msat = amount_msat; let overpay = amount_msat as f64 * (as_option(req.maxfeepercent).unwrap_or(DEFAULT_OVERPAY_PERCENT) as f64 / 100 as f64); let amount_msat = amount_msat + overpay as u64; @@ -198,7 +199,7 @@ pub async fn trampolinepay( use crate::tlv::{SerializedTlvStream, ToBytes}; let mut payload: SerializedTlvStream = SerializedTlvStream::new(); payload.set_bytes(TLV_BOLT11, req.bolt11.as_bytes()); - payload.set_tu64(TLV_AMT_MSAT, amount_msat); + payload.set_tu64(TLV_AMT_MSAT, tlv_amount_msat); let payload_hex = hex::encode(SerializedTlvStream::to_bytes(payload)); let mut part_id = if choosen.len() == 1 { 0 } else { 1 };