Skip to content

Commit c4ae026

Browse files
committed
Allow intercepting HTLCs based on the source channel
It may be useful in some situations to select HTLCs for interception based on the source channel in addition to the sink. Here we add the ability to do so by adding new flags to `HTLCInterceptionFlags`.
1 parent c0606a7 commit c4ae026

File tree

3 files changed

+212
-57
lines changed

3 files changed

+212
-57
lines changed

lightning/src/ln/channelmanager.rs

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4767,7 +4767,9 @@ where
47674767
}
47684768
}
47694769

4770-
fn forward_needs_intercept_to_known_chan(&self, outbound_chan: &FundedChannel<SP>) -> bool {
4770+
fn forward_needs_intercept_to_known_chan(
4771+
&self, prev_chan_public: bool, outbound_chan: &FundedChannel<SP>,
4772+
) -> bool {
47714773
let intercept_flags = self.config.read().unwrap().htlc_interception_flags;
47724774
if !outbound_chan.context.should_announce() {
47734775
if outbound_chan.context.is_connected() {
@@ -4784,6 +4786,23 @@ where
47844786
return true;
47854787
}
47864788
}
4789+
if prev_chan_public {
4790+
if outbound_chan.context.should_announce() {
4791+
if intercept_flags & (HTLCInterceptionFlags::FromPublicToPublicChannels as u8) != 0
4792+
{
4793+
return true;
4794+
}
4795+
} else {
4796+
if intercept_flags & (HTLCInterceptionFlags::FromPublicToPrivateChannels as u8) != 0
4797+
{
4798+
return true;
4799+
}
4800+
}
4801+
} else {
4802+
if intercept_flags & (HTLCInterceptionFlags::FromPrivateChannels as u8) != 0 {
4803+
return true;
4804+
}
4805+
}
47874806
false
47884807
}
47894808

@@ -4877,7 +4896,7 @@ where
48774896
}
48784897

48794898
fn can_forward_htlc_should_intercept(
4880-
&self, msg: &msgs::UpdateAddHTLC, next_hop: &NextPacketDetails,
4899+
&self, msg: &msgs::UpdateAddHTLC, prev_chan_public: bool, next_hop: &NextPacketDetails,
48814900
) -> Result<bool, LocalHTLCFailureReason> {
48824901
let outgoing_scid = match next_hop.outgoing_connector {
48834902
HopConnector::ShortChannelId(scid) => scid,
@@ -4896,7 +4915,7 @@ where
48964915
// times we do it.
48974916
let intercept =
48984917
match self.do_funded_channel_callback(outgoing_scid, |chan: &mut FundedChannel<SP>| {
4899-
let intercept = self.forward_needs_intercept_to_known_chan(chan);
4918+
let intercept = self.forward_needs_intercept_to_known_chan(prev_chan_public, chan);
49004919
self.can_forward_htlc_to_outgoing_channel(chan, msg, next_hop, intercept)?;
49014920
Ok(intercept)
49024921
}) {
@@ -6842,34 +6861,29 @@ where
68426861
'outer_loop: for (incoming_scid_alias, update_add_htlcs) in decode_update_add_htlcs {
68436862
// If any decoded update_add_htlcs were processed, we need to persist.
68446863
should_persist = true;
6845-
let incoming_channel_details_opt = self.do_funded_channel_callback(
6846-
incoming_scid_alias,
6847-
|chan: &mut FundedChannel<SP>| {
6848-
let counterparty_node_id = chan.context.get_counterparty_node_id();
6849-
let channel_id = chan.context.channel_id();
6850-
let funding_txo = chan.funding.get_funding_txo().unwrap();
6851-
let user_channel_id = chan.context.get_user_id();
6852-
let accept_underpaying_htlcs = chan.context.config().accept_underpaying_htlcs;
6853-
(
6854-
counterparty_node_id,
6855-
channel_id,
6856-
funding_txo,
6857-
user_channel_id,
6858-
accept_underpaying_htlcs,
6859-
)
6860-
},
6861-
);
68626864
let (
68636865
incoming_counterparty_node_id,
68646866
incoming_channel_id,
68656867
incoming_funding_txo,
68666868
incoming_user_channel_id,
68676869
incoming_accept_underpaying_htlcs,
6868-
) = if let Some(incoming_channel_details) = incoming_channel_details_opt {
6869-
incoming_channel_details
6870-
} else {
6870+
incoming_chan_is_public,
6871+
) = match self.do_funded_channel_callback(
6872+
incoming_scid_alias,
6873+
|chan: &mut FundedChannel<SP>| {
6874+
(
6875+
chan.context.get_counterparty_node_id(),
6876+
chan.context.channel_id(),
6877+
chan.funding.get_funding_txo().unwrap(),
6878+
chan.context.get_user_id(),
6879+
chan.context.config().accept_underpaying_htlcs,
6880+
chan.context.should_announce(),
6881+
)
6882+
},
6883+
) {
6884+
Some(incoming_channel_details) => incoming_channel_details,
68716885
// The incoming channel no longer exists, HTLCs should be resolved onchain instead.
6872-
continue;
6886+
None => continue,
68736887
};
68746888

68756889
let mut htlc_forwards = Vec::new();
@@ -6989,9 +7003,11 @@ where
69897003
// Now process the HTLC on the outgoing channel if it's a forward.
69907004
let mut intercept_forward = false;
69917005
if let Some(next_packet_details) = next_packet_details_opt.as_ref() {
6992-
match self
6993-
.can_forward_htlc_should_intercept(&update_add_htlc, next_packet_details)
6994-
{
7006+
match self.can_forward_htlc_should_intercept(
7007+
&update_add_htlc,
7008+
incoming_chan_is_public,
7009+
next_packet_details,
7010+
) {
69957011
Err(reason) => {
69967012
fail_htlc_continue_to_next!(reason);
69977013
},
@@ -16317,9 +16333,29 @@ where
1631716333
);
1631816334
log_trace!(logger, "Releasing held htlc with intercept_id {}", intercept_id);
1631916335

16336+
let prev_chan_public = {
16337+
let per_peer_state = self.per_peer_state.read().unwrap();
16338+
let peer_state = per_peer_state
16339+
.get(&htlc.prev_counterparty_node_id)
16340+
.map(|mtx| mtx.lock().unwrap());
16341+
let chan_state = peer_state
16342+
.as_ref()
16343+
.map(|state| state.channel_by_id.get(&htlc.prev_channel_id))
16344+
.flatten();
16345+
if let Some(chan_state) = chan_state {
16346+
chan_state.context().should_announce()
16347+
} else {
16348+
// If the inbound channel has closed since the HTLC was held, we really
16349+
// shouldn't forward it - forwarding it now would result in, at best,
16350+
// having to claim the HTLC on chain. Instead, drop the HTLC and let the
16351+
// counterparty claim their money on chain.
16352+
return;
16353+
}
16354+
};
16355+
1632016356
let should_intercept = self
1632116357
.do_funded_channel_callback(next_hop_scid, |chan| {
16322-
self.forward_needs_intercept_to_known_chan(chan)
16358+
self.forward_needs_intercept_to_known_chan(prev_chan_public, chan)
1632316359
})
1632416360
.unwrap_or_else(|| self.forward_needs_intercept_to_unknown_chan(next_hop_scid));
1632516361

lightning/src/ln/interception_tests.rs

Lines changed: 101 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,37 +50,48 @@ fn do_test_htlc_interception_flags(
5050
let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, Some(intercept_config), None]);
5151
let nodes = create_network(3, &node_cfgs, &node_chanmgrs);
5252

53-
create_announced_chan_between_nodes(&nodes, 0, 1);
53+
let inbound_private = match flag {
54+
Flag::FromPrivateChannels => {
55+
create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 0);
56+
true
57+
},
58+
_ => {
59+
create_announced_chan_between_nodes(&nodes, 0, 1);
60+
false
61+
},
62+
};
5463

5564
let node_0_id = nodes[0].node.get_our_node_id();
5665
let node_1_id = nodes[1].node.get_our_node_id();
5766
let node_2_id = nodes[2].node.get_our_node_id();
5867

5968
// First open the right type of channel (and get it in the right state) for the bit we're
6069
// testing.
61-
let (target_scid, target_chan_id) = match flag {
62-
Flag::ToOfflinePrivateChannels | Flag::ToOnlinePrivateChannels => {
70+
let (target_scid, target_chan_id, outbound_private_for_known_scids) = match flag {
71+
Flag::ToOfflinePrivateChannels
72+
| Flag::ToOnlinePrivateChannels
73+
| Flag::FromPublicToPrivateChannels => {
6374
create_unannounced_chan_between_nodes_with_value(&nodes, 1, 2, 100000, 0);
6475
let chan_id = nodes[2].node.list_channels()[0].channel_id;
6576
let scid = nodes[2].node.list_channels()[0].short_channel_id.unwrap();
6677
if flag == Flag::ToOfflinePrivateChannels {
6778
nodes[1].node.peer_disconnected(node_2_id);
6879
nodes[2].node.peer_disconnected(node_1_id);
69-
} else {
70-
assert_eq!(flag, Flag::ToOnlinePrivateChannels);
7180
}
72-
(scid, chan_id)
81+
(scid, chan_id, Some(true))
7382
},
74-
Flag::ToInterceptSCIDs | Flag::ToPublicChannels | Flag::ToUnknownSCIDs => {
83+
Flag::ToInterceptSCIDs
84+
| Flag::ToPublicChannels
85+
| Flag::FromPrivateChannels
86+
| Flag::FromPublicToPublicChannels
87+
| Flag::ToUnknownSCIDs => {
7588
let (chan_upd, _, chan_id, _) = create_announced_chan_between_nodes(&nodes, 1, 2);
7689
if flag == Flag::ToInterceptSCIDs {
77-
(nodes[1].node.get_intercept_scid(), chan_id)
78-
} else if flag == Flag::ToPublicChannels {
79-
(chan_upd.contents.short_channel_id, chan_id)
90+
(nodes[1].node.get_intercept_scid(), chan_id, None)
8091
} else if flag == Flag::ToUnknownSCIDs {
81-
(42424242, chan_id)
92+
(42424242, chan_id, None)
8293
} else {
83-
panic!();
94+
(chan_upd.contents.short_channel_id, chan_id, Some(false))
8495
}
8596
},
8697
_ => panic!("Combined flags aren't allowed"),
@@ -100,21 +111,50 @@ fn do_test_htlc_interception_flags(
100111
get_route_and_payment_hash!(nodes[0], nodes[2], pay_params, amt_msat);
101112
route.paths[0].hops[1].short_channel_id = target_scid;
102113

103-
let interception_bit_match = (flags_bitmask & (flag as u8)) != 0;
114+
let mut should_intercept = false;
115+
for a_flag in ALL_FLAGS {
116+
if flags_bitmask & (a_flag as u8) != 0 {
117+
match a_flag {
118+
Flag::ToInterceptSCIDs => {
119+
should_intercept |= flag == Flag::ToInterceptSCIDs;
120+
},
121+
Flag::ToOfflinePrivateChannels => {
122+
should_intercept |= flag == Flag::ToOfflinePrivateChannels;
123+
},
124+
Flag::ToOnlinePrivateChannels => {
125+
should_intercept |= flag != Flag::ToOfflinePrivateChannels
126+
&& outbound_private_for_known_scids == Some(true);
127+
},
128+
Flag::ToPublicChannels => {
129+
should_intercept |= outbound_private_for_known_scids == Some(false);
130+
},
131+
Flag::ToUnknownSCIDs => {
132+
should_intercept |= flag == Flag::ToUnknownSCIDs;
133+
},
134+
Flag::FromPrivateChannels => {
135+
should_intercept |= inbound_private;
136+
},
137+
Flag::FromPublicToPrivateChannels => {
138+
should_intercept |=
139+
!inbound_private && outbound_private_for_known_scids == Some(true);
140+
},
141+
Flag::FromPublicToPublicChannels => {
142+
should_intercept |=
143+
!inbound_private && outbound_private_for_known_scids == Some(false);
144+
},
145+
_ => panic!("Combined flags aren't allowed"),
146+
}
147+
}
148+
}
149+
104150
match modification {
105151
Some(ForwardingMod::FeeTooLow) => {
106-
assert!(
107-
interception_bit_match,
108-
"No reason to test failing if we aren't trying to intercept",
109-
);
152+
assert!(should_intercept, "No reason to test failing if we aren't trying to intercept");
110153
route.paths[0].hops[0].fee_msat = 500;
111154
},
112155
Some(ForwardingMod::CLTVBelowConfig) => {
113156
route.paths[0].hops[0].cltv_expiry_delta = 6 * 12;
114-
assert!(
115-
interception_bit_match,
116-
"No reason to test failing if we aren't trying to intercept",
117-
);
157+
assert!(should_intercept, "No reason to test failing if we aren't trying to intercept");
118158
},
119159
Some(ForwardingMod::CLTVBelowMin) => {
120160
route.paths[0].hops[0].cltv_expiry_delta = 6;
@@ -132,7 +172,7 @@ fn do_test_htlc_interception_flags(
132172
do_commitment_signed_dance(&nodes[1], &nodes[0], &payment_event.commitment_msg, false, true);
133173
expect_and_process_pending_htlcs(&nodes[1], false);
134174

135-
if interception_bit_match && modification.is_none() {
175+
if should_intercept && modification.is_none() {
136176
// If we were set to intercept, check that we got an interception event then
137177
// forward the HTLC on to nodes[2] and claim the payment.
138178
let intercept_id;
@@ -171,7 +211,14 @@ fn do_test_htlc_interception_flags(
171211
// If we were not set to intercept, check that the HTLC either failed or was
172212
// automatically forwarded as appropriate.
173213
match (modification, flag) {
174-
(None, Flag::ToOnlinePrivateChannels | Flag::ToPublicChannels) => {
214+
(
215+
None,
216+
Flag::ToOnlinePrivateChannels
217+
| Flag::ToPublicChannels
218+
| Flag::FromPrivateChannels
219+
| Flag::FromPublicToPrivateChannels
220+
| Flag::FromPublicToPublicChannels,
221+
) => {
175222
check_added_monitors(&nodes[1], 1);
176223

177224
let forward_ev = SendEvent::from_node(&nodes[1]);
@@ -240,31 +287,55 @@ fn do_test_htlc_interception_flags(
240287
}
241288

242289
const MAX_BITMASK: u8 = HTLCInterceptionFlags::AllValidHTLCs as u8;
243-
const ALL_FLAGS: [HTLCInterceptionFlags; 5] = [
290+
const ALL_FLAGS: [HTLCInterceptionFlags; 8] = [
244291
HTLCInterceptionFlags::ToInterceptSCIDs,
245292
HTLCInterceptionFlags::ToOfflinePrivateChannels,
246293
HTLCInterceptionFlags::ToOnlinePrivateChannels,
247294
HTLCInterceptionFlags::ToPublicChannels,
248295
HTLCInterceptionFlags::ToUnknownSCIDs,
296+
HTLCInterceptionFlags::FromPrivateChannels,
297+
HTLCInterceptionFlags::FromPublicToPrivateChannels,
298+
HTLCInterceptionFlags::FromPublicToPublicChannels,
249299
];
250-
251300
#[test]
252-
fn test_htlc_interception_flags() {
301+
fn check_all_flags() {
253302
let mut all_flag_bits = 0;
254303
for flag in ALL_FLAGS {
255304
all_flag_bits |= flag as isize;
256305
}
257306
assert_eq!(all_flag_bits, MAX_BITMASK as isize, "all flags must test all bits");
307+
}
258308

309+
fn test_htlc_interception_flags_subrange<I: Iterator<Item = u8>>(r: I) {
259310
// Test all 2^5 = 32 combinations of the HTLCInterceptionFlags bitmask
260311
// For each combination, test 5 different HTLC forwards and verify correct interception behavior
261-
for flags_bitmask in 0..=MAX_BITMASK {
312+
for flags_bitmask in r {
262313
for flag in ALL_FLAGS {
263314
do_test_htlc_interception_flags(flags_bitmask, flag, None);
264315
}
265316
}
266317
}
267318

319+
#[test]
320+
fn test_htlc_interception_flags_a() {
321+
test_htlc_interception_flags_subrange(0..MAX_BITMASK / 4);
322+
}
323+
324+
#[test]
325+
fn test_htlc_interception_flags_b() {
326+
test_htlc_interception_flags_subrange(MAX_BITMASK / 4..MAX_BITMASK / 2);
327+
}
328+
329+
#[test]
330+
fn test_htlc_interception_flags_c() {
331+
test_htlc_interception_flags_subrange(MAX_BITMASK / 2..MAX_BITMASK / 4 * 3);
332+
}
333+
334+
#[test]
335+
fn test_htlc_interception_flags_d() {
336+
test_htlc_interception_flags_subrange(MAX_BITMASK / 4 * 3..=MAX_BITMASK);
337+
}
338+
268339
#[test]
269340
fn test_htlc_bad_for_chan_config() {
270341
// Test that interception won't be done if an HTLC fails to meet the target channel's channel
@@ -273,6 +344,9 @@ fn test_htlc_bad_for_chan_config() {
273344
HTLCInterceptionFlags::ToOfflinePrivateChannels,
274345
HTLCInterceptionFlags::ToOnlinePrivateChannels,
275346
HTLCInterceptionFlags::ToPublicChannels,
347+
HTLCInterceptionFlags::FromPrivateChannels,
348+
HTLCInterceptionFlags::FromPublicToPrivateChannels,
349+
HTLCInterceptionFlags::FromPublicToPublicChannels,
276350
];
277351
for flag in have_chan_flags {
278352
do_test_htlc_interception_flags(flag as u8, flag, Some(ForwardingMod::FeeTooLow));

0 commit comments

Comments
 (0)