@@ -50,37 +50,48 @@ fn do_test_htlc_interception_flags(
5050 let node_chanmgrs = create_node_chanmgrs ( 3 , & node_cfgs, & [ None , Some ( intercept_config) , None ] ) ;
5151 let nodes = create_network ( 3 , & node_cfgs, & node_chanmgrs) ;
5252
53- create_announced_chan_between_nodes ( & nodes, 0 , 1 ) ;
53+ let inbound_private = match flag {
54+ Flag :: FromPrivateChannels => {
55+ create_unannounced_chan_between_nodes_with_value ( & nodes, 0 , 1 , 100000 , 0 ) ;
56+ true
57+ } ,
58+ _ => {
59+ create_announced_chan_between_nodes ( & nodes, 0 , 1 ) ;
60+ false
61+ } ,
62+ } ;
5463
5564 let node_0_id = nodes[ 0 ] . node . get_our_node_id ( ) ;
5665 let node_1_id = nodes[ 1 ] . node . get_our_node_id ( ) ;
5766 let node_2_id = nodes[ 2 ] . node . get_our_node_id ( ) ;
5867
5968 // First open the right type of channel (and get it in the right state) for the bit we're
6069 // testing.
61- let ( target_scid, target_chan_id) = match flag {
62- Flag :: ToOfflinePrivateChannels | Flag :: ToOnlinePrivateChannels => {
70+ let ( target_scid, target_chan_id, outbound_private_for_known_scids) = match flag {
71+ Flag :: ToOfflinePrivateChannels
72+ | Flag :: ToOnlinePrivateChannels
73+ | Flag :: FromPublicToPrivateChannels => {
6374 create_unannounced_chan_between_nodes_with_value ( & nodes, 1 , 2 , 100000 , 0 ) ;
6475 let chan_id = nodes[ 2 ] . node . list_channels ( ) [ 0 ] . channel_id ;
6576 let scid = nodes[ 2 ] . node . list_channels ( ) [ 0 ] . short_channel_id . unwrap ( ) ;
6677 if flag == Flag :: ToOfflinePrivateChannels {
6778 nodes[ 1 ] . node . peer_disconnected ( node_2_id) ;
6879 nodes[ 2 ] . node . peer_disconnected ( node_1_id) ;
69- } else {
70- assert_eq ! ( flag, Flag :: ToOnlinePrivateChannels ) ;
7180 }
72- ( scid, chan_id)
81+ ( scid, chan_id, Some ( true ) )
7382 } ,
74- Flag :: ToInterceptSCIDs | Flag :: ToPublicChannels | Flag :: ToUnknownSCIDs => {
83+ Flag :: ToInterceptSCIDs
84+ | Flag :: ToPublicChannels
85+ | Flag :: FromPrivateChannels
86+ | Flag :: FromPublicToPublicChannels
87+ | Flag :: ToUnknownSCIDs => {
7588 let ( chan_upd, _, chan_id, _) = create_announced_chan_between_nodes ( & nodes, 1 , 2 ) ;
7689 if flag == Flag :: ToInterceptSCIDs {
77- ( nodes[ 1 ] . node . get_intercept_scid ( ) , chan_id)
78- } else if flag == Flag :: ToPublicChannels {
79- ( chan_upd. contents . short_channel_id , chan_id)
90+ ( nodes[ 1 ] . node . get_intercept_scid ( ) , chan_id, None )
8091 } else if flag == Flag :: ToUnknownSCIDs {
81- ( 42424242 , chan_id)
92+ ( 42424242 , chan_id, None )
8293 } else {
83- panic ! ( ) ;
94+ ( chan_upd . contents . short_channel_id , chan_id , Some ( false ) )
8495 }
8596 } ,
8697 _ => panic ! ( "Combined flags aren't allowed" ) ,
@@ -100,21 +111,50 @@ fn do_test_htlc_interception_flags(
100111 get_route_and_payment_hash ! ( nodes[ 0 ] , nodes[ 2 ] , pay_params, amt_msat) ;
101112 route. paths [ 0 ] . hops [ 1 ] . short_channel_id = target_scid;
102113
103- let interception_bit_match = ( flags_bitmask & ( flag as u8 ) ) != 0 ;
114+ let mut should_intercept = false ;
115+ for a_flag in ALL_FLAGS {
116+ if flags_bitmask & ( a_flag as u8 ) != 0 {
117+ match a_flag {
118+ Flag :: ToInterceptSCIDs => {
119+ should_intercept |= flag == Flag :: ToInterceptSCIDs ;
120+ } ,
121+ Flag :: ToOfflinePrivateChannels => {
122+ should_intercept |= flag == Flag :: ToOfflinePrivateChannels ;
123+ } ,
124+ Flag :: ToOnlinePrivateChannels => {
125+ should_intercept |= flag != Flag :: ToOfflinePrivateChannels
126+ && outbound_private_for_known_scids == Some ( true ) ;
127+ } ,
128+ Flag :: ToPublicChannels => {
129+ should_intercept |= outbound_private_for_known_scids == Some ( false ) ;
130+ } ,
131+ Flag :: ToUnknownSCIDs => {
132+ should_intercept |= flag == Flag :: ToUnknownSCIDs ;
133+ } ,
134+ Flag :: FromPrivateChannels => {
135+ should_intercept |= inbound_private;
136+ } ,
137+ Flag :: FromPublicToPrivateChannels => {
138+ should_intercept |=
139+ !inbound_private && outbound_private_for_known_scids == Some ( true ) ;
140+ } ,
141+ Flag :: FromPublicToPublicChannels => {
142+ should_intercept |=
143+ !inbound_private && outbound_private_for_known_scids == Some ( false ) ;
144+ } ,
145+ _ => panic ! ( "Combined flags aren't allowed" ) ,
146+ }
147+ }
148+ }
149+
104150 match modification {
105151 Some ( ForwardingMod :: FeeTooLow ) => {
106- assert ! (
107- interception_bit_match,
108- "No reason to test failing if we aren't trying to intercept" ,
109- ) ;
152+ assert ! ( should_intercept, "No reason to test failing if we aren't trying to intercept" ) ;
110153 route. paths [ 0 ] . hops [ 0 ] . fee_msat = 500 ;
111154 } ,
112155 Some ( ForwardingMod :: CLTVBelowConfig ) => {
113156 route. paths [ 0 ] . hops [ 0 ] . cltv_expiry_delta = 6 * 12 ;
114- assert ! (
115- interception_bit_match,
116- "No reason to test failing if we aren't trying to intercept" ,
117- ) ;
157+ assert ! ( should_intercept, "No reason to test failing if we aren't trying to intercept" ) ;
118158 } ,
119159 Some ( ForwardingMod :: CLTVBelowMin ) => {
120160 route. paths [ 0 ] . hops [ 0 ] . cltv_expiry_delta = 6 ;
@@ -132,7 +172,7 @@ fn do_test_htlc_interception_flags(
132172 do_commitment_signed_dance ( & nodes[ 1 ] , & nodes[ 0 ] , & payment_event. commitment_msg , false , true ) ;
133173 expect_and_process_pending_htlcs ( & nodes[ 1 ] , false ) ;
134174
135- if interception_bit_match && modification. is_none ( ) {
175+ if should_intercept && modification. is_none ( ) {
136176 // If we were set to intercept, check that we got an interception event then
137177 // forward the HTLC on to nodes[2] and claim the payment.
138178 let intercept_id;
@@ -171,7 +211,14 @@ fn do_test_htlc_interception_flags(
171211 // If we were not set to intercept, check that the HTLC either failed or was
172212 // automatically forwarded as appropriate.
173213 match ( modification, flag) {
174- ( None , Flag :: ToOnlinePrivateChannels | Flag :: ToPublicChannels ) => {
214+ (
215+ None ,
216+ Flag :: ToOnlinePrivateChannels
217+ | Flag :: ToPublicChannels
218+ | Flag :: FromPrivateChannels
219+ | Flag :: FromPublicToPrivateChannels
220+ | Flag :: FromPublicToPublicChannels ,
221+ ) => {
175222 check_added_monitors ( & nodes[ 1 ] , 1 ) ;
176223
177224 let forward_ev = SendEvent :: from_node ( & nodes[ 1 ] ) ;
@@ -240,31 +287,55 @@ fn do_test_htlc_interception_flags(
240287}
241288
242289const MAX_BITMASK : u8 = HTLCInterceptionFlags :: AllValidHTLCs as u8 ;
243- const ALL_FLAGS : [ HTLCInterceptionFlags ; 5 ] = [
290+ const ALL_FLAGS : [ HTLCInterceptionFlags ; 8 ] = [
244291 HTLCInterceptionFlags :: ToInterceptSCIDs ,
245292 HTLCInterceptionFlags :: ToOfflinePrivateChannels ,
246293 HTLCInterceptionFlags :: ToOnlinePrivateChannels ,
247294 HTLCInterceptionFlags :: ToPublicChannels ,
248295 HTLCInterceptionFlags :: ToUnknownSCIDs ,
296+ HTLCInterceptionFlags :: FromPrivateChannels ,
297+ HTLCInterceptionFlags :: FromPublicToPrivateChannels ,
298+ HTLCInterceptionFlags :: FromPublicToPublicChannels ,
249299] ;
250-
251300#[ test]
252- fn test_htlc_interception_flags ( ) {
301+ fn check_all_flags ( ) {
253302 let mut all_flag_bits = 0 ;
254303 for flag in ALL_FLAGS {
255304 all_flag_bits |= flag as isize ;
256305 }
257306 assert_eq ! ( all_flag_bits, MAX_BITMASK as isize , "all flags must test all bits" ) ;
307+ }
258308
309+ fn test_htlc_interception_flags_subrange < I : Iterator < Item = u8 > > ( r : I ) {
259310 // Test all 2^5 = 32 combinations of the HTLCInterceptionFlags bitmask
260311 // For each combination, test 5 different HTLC forwards and verify correct interception behavior
261- for flags_bitmask in 0 ..= MAX_BITMASK {
312+ for flags_bitmask in r {
262313 for flag in ALL_FLAGS {
263314 do_test_htlc_interception_flags ( flags_bitmask, flag, None ) ;
264315 }
265316 }
266317}
267318
319+ #[ test]
320+ fn test_htlc_interception_flags_a ( ) {
321+ test_htlc_interception_flags_subrange ( 0 ..MAX_BITMASK / 4 ) ;
322+ }
323+
324+ #[ test]
325+ fn test_htlc_interception_flags_b ( ) {
326+ test_htlc_interception_flags_subrange ( MAX_BITMASK / 4 ..MAX_BITMASK / 2 ) ;
327+ }
328+
329+ #[ test]
330+ fn test_htlc_interception_flags_c ( ) {
331+ test_htlc_interception_flags_subrange ( MAX_BITMASK / 2 ..MAX_BITMASK / 4 * 3 ) ;
332+ }
333+
334+ #[ test]
335+ fn test_htlc_interception_flags_d ( ) {
336+ test_htlc_interception_flags_subrange ( MAX_BITMASK / 4 * 3 ..=MAX_BITMASK ) ;
337+ }
338+
268339#[ test]
269340fn test_htlc_bad_for_chan_config ( ) {
270341 // Test that interception won't be done if an HTLC fails to meet the target channel's channel
@@ -273,6 +344,9 @@ fn test_htlc_bad_for_chan_config() {
273344 HTLCInterceptionFlags :: ToOfflinePrivateChannels ,
274345 HTLCInterceptionFlags :: ToOnlinePrivateChannels ,
275346 HTLCInterceptionFlags :: ToPublicChannels ,
347+ HTLCInterceptionFlags :: FromPrivateChannels ,
348+ HTLCInterceptionFlags :: FromPublicToPrivateChannels ,
349+ HTLCInterceptionFlags :: FromPublicToPublicChannels ,
276350 ] ;
277351 for flag in have_chan_flags {
278352 do_test_htlc_interception_flags ( flag as u8 , flag, Some ( ForwardingMod :: FeeTooLow ) ) ;
0 commit comments