Skip to content

Commit

Permalink
Merge pull request #153 from sr-gi/20231025_payment_range
Browse files Browse the repository at this point in the history
sim-lib: Allows amount and interval ranges in manual activity definition
  • Loading branch information
carlaKC authored May 20, 2024
2 parents c6ee6ad + f1a3275 commit a9aba5d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 16 deletions.
24 changes: 13 additions & 11 deletions sim-lib/src/defined_activity.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{DestinationGenerator, NodeInfo, PaymentGenerationError, PaymentGenerator};
use crate::{
DestinationGenerator, NodeInfo, PaymentGenerationError, PaymentGenerator, ValueOrRange,
};
use std::fmt;
use tokio::time::Duration;

Expand All @@ -7,17 +9,17 @@ pub struct DefinedPaymentActivity {
destination: NodeInfo,
start: Duration,
count: Option<u64>,
wait: Duration,
amount: u64,
wait: ValueOrRange<u16>,
amount: ValueOrRange<u64>,
}

impl DefinedPaymentActivity {
pub fn new(
destination: NodeInfo,
start: Duration,
count: Option<u64>,
wait: Duration,
amount: u64,
wait: ValueOrRange<u16>,
amount: ValueOrRange<u64>,
) -> Self {
DefinedPaymentActivity {
destination,
Expand All @@ -33,7 +35,7 @@ impl fmt::Display for DefinedPaymentActivity {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"static payment of {} to {} every {:?}",
"static payment of {} to {} every {}s",
self.amount, self.destination, self.wait
)
}
Expand All @@ -55,7 +57,7 @@ impl PaymentGenerator for DefinedPaymentActivity {
}

fn next_payment_wait(&self) -> Duration {
self.wait
Duration::from_secs(self.wait.value() as u64)
}

fn payment_amount(
Expand All @@ -67,17 +69,17 @@ impl PaymentGenerator for DefinedPaymentActivity {
"destination amount must not be set for defined activity generator".to_string(),
))
} else {
Ok(self.amount)
Ok(self.amount.value())
}
}
}

#[cfg(test)]
mod tests {
use super::DefinedPaymentActivity;
use super::*;
use crate::test_utils::{create_nodes, get_random_keypair};
use crate::{DestinationGenerator, PaymentGenerationError, PaymentGenerator};
use std::time::Duration;

#[test]
fn test_defined_activity_generator() {
Expand All @@ -91,8 +93,8 @@ mod tests {
node.clone(),
Duration::from_secs(0),
None,
Duration::from_secs(60),
payment_amt,
crate::ValueOrRange::Value(60),
crate::ValueOrRange::Value(payment_amt),
);

let (dest, dest_capacity) = generator.choose_destination(source.1);
Expand Down
54 changes: 49 additions & 5 deletions sim-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use bitcoin::Network;
use csv::WriterBuilder;
use lightning::ln::features::NodeFeatures;
use lightning::ln::PaymentHash;
use rand::Rng;
use random_activity::RandomActivityError;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
Expand Down Expand Up @@ -129,6 +130,47 @@ pub struct SimParams {
pub activity: Vec<ActivityParser>,
}

/// Either a value or a range parsed from the simulation file.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ValueOrRange<T> {
Value(T),
Range(T, T),
}

impl<T> ValueOrRange<T>
where
T: std::cmp::PartialOrd + rand_distr::uniform::SampleUniform + Copy,
{
/// Get the enclosed value. If value is defined as a range, sample from it uniformly at random.
pub fn value(&self) -> T {
match self {
ValueOrRange::Value(x) => *x,
ValueOrRange::Range(x, y) => {
let mut rng = rand::thread_rng();
rng.gen_range(*x..*y)
},
}
}
}

impl<T> Display for ValueOrRange<T>
where
T: Display,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ValueOrRange::Value(x) => write!(f, "{x}"),
ValueOrRange::Range(x, y) => write!(f, "({x}-{y})"),
}
}
}

/// The payment amount in msat. Either a value or a range.
type Amount = ValueOrRange<u64>;
/// The interval of seconds between payments. Either a value or a range.
type Interval = ValueOrRange<u16>;

/// Data structure used to parse information from the simulation file. It allows source and destination to be
/// [NodeId], which enables the use of public keys and aliases in the simulation description.
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -146,9 +188,11 @@ pub struct ActivityParser {
#[serde(default)]
pub count: Option<u64>,
/// The interval of the event, as in every how many seconds the payment is performed.
pub interval_secs: u16,
#[serde(with = "serializers::serde_value_or_range")]
pub interval_secs: Interval,
/// The amount of m_sat to used in this payment.
pub amount_msat: u64,
#[serde(with = "serializers::serde_value_or_range")]
pub amount_msat: Amount,
}

/// Data structure used internally by the simulator. Both source and destination are represented as [PublicKey] here.
Expand All @@ -164,9 +208,9 @@ pub struct ActivityDefinition {
/// The number of payments to send over the course of the simulation.
pub count: Option<u64>,
/// The interval of the event, as in every how many seconds the payment is performed.
pub interval_secs: u16,
pub interval_secs: Interval,
/// The amount of m_sat to used in this payment.
pub amount_msat: u64,
pub amount_msat: Amount,
}

#[derive(Debug, Error)]
Expand Down Expand Up @@ -731,7 +775,7 @@ impl Simulation {
description.destination.clone(),
Duration::from_secs(description.start_secs.into()),
description.count,
Duration::from_secs(description.interval_secs.into()),
description.interval_secs,
description.amount_msat,
);

Expand Down
36 changes: 36 additions & 0 deletions sim-lib/src/serializers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,42 @@ pub mod serde_node_id {
}
}

pub mod serde_value_or_range {
use super::*;
use serde::de::Error;

use crate::ValueOrRange;

pub fn serialize<S, T>(x: &ValueOrRange<T>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
T: std::fmt::Display,
{
serializer.serialize_str(&match x {
ValueOrRange::Value(p) => p.to_string(),
ValueOrRange::Range(x, y) => format!("[{}, {}]", x, y),
})
}

pub fn deserialize<'de, D, T>(deserializer: D) -> Result<ValueOrRange<T>, D::Error>
where
D: serde::Deserializer<'de>,
T: serde::Deserialize<'de> + std::cmp::PartialOrd + std::fmt::Display + Copy,
{
let a = ValueOrRange::deserialize(deserializer)?;
if let ValueOrRange::Range(x, y) = a {
if x >= y {
return Err(D::Error::custom(format!(
"Cannot parse range. Ranges must be strictly increasing (i.e. [x, y] with x > y). Received [{}, {}]",
x, y
)));
}
}

Ok(a)
}
}

pub fn deserialize_path<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
Expand Down

0 comments on commit a9aba5d

Please sign in to comment.