Skip to content

Commit ff629d2

Browse files
authored
Merge pull request #15 from simple-crypto/develop
Progress bars and non-blocking
2 parents 31b0e32 + 9a3344c commit ff629d2

File tree

6 files changed

+60
-28
lines changed

6 files changed

+60
-28
lines changed

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
# Ensure these are present (in case we are not using PEP-518 compatible build
88
# system).
99
import setuptools_scm
10-
import toml
1110

1211
scalib_features = ["pyo3/abi3"]
1312

src/scalib/attacks/sascagraph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,15 @@ def set_table(self, table, values):
223223
)
224224
self.tables_[table] = values
225225

226-
def run_bp(self, it):
226+
def run_bp(self, it, progress=False):
227227
r"""Runs belief propagation algorithm on the current state of the graph.
228228
229229
Parameters
230230
----------
231231
it : int
232232
Number of iterations of belief propagation.
233+
progress: bool
234+
Show a progress bar (default: False).
233235
"""
234236
if self.solved_:
235237
raise Exception("Cannot run bp twice on a graph.")
@@ -241,6 +243,7 @@ def run_bp(self, it):
241243
self.edge_,
242244
self.nc_,
243245
self.n_,
246+
progress,
244247
)
245248
self.solved_ = True
246249

src/scalib_ext/scalib-py/src/belief_propagation.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ pub fn run_bp(
8989
nc: usize,
9090
// number of copies in the graph (n_runs)
9191
n: usize,
92+
// show a progress bar
93+
progress: bool,
9294
) -> PyResult<()> {
9395
// map all python functions to rust ones + generate the mapping in vec_functs_id
9496
let functions_rust: Vec<Func> = functions
@@ -104,8 +106,18 @@ pub fn run_bp(
104106
.map(|x| to_var(x.downcast::<PyDict>().unwrap()))
105107
.collect();
106108

107-
scalib::belief_propagation::run_bp(&functions_rust, &mut variables_rust, it, edge, nc, n)
109+
py.allow_threads(|| {
110+
scalib::belief_propagation::run_bp(
111+
&functions_rust,
112+
&mut variables_rust,
113+
it,
114+
edge,
115+
nc,
116+
n,
117+
progress,
118+
)
108119
.unwrap();
120+
});
109121

110122
variables_rust
111123
.iter()

src/scalib_ext/scalib-py/src/lda.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ impl LDA {
9696
/// x : traces with shape (n,ns)
9797
/// return prs with shape (n,nc). Every row corresponds to one probability distribution
9898
fn predict_proba<'py>(
99-
&mut self,
99+
&self,
100100
py: Python<'py>,
101101
x: PyReadonlyArray2<i16>,
102102
) -> PyResult<&'py PyArray2<f64>> {

src/scalib_ext/scalib-py/src/lib.rs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> {
4343
vertex: usize,
4444
nc: usize,
4545
n: usize,
46+
progress: bool,
4647
) -> PyResult<()> {
47-
belief_propagation::run_bp(py, functions, variables, it, vertex, nc, n)
48+
belief_propagation::run_bp(py, functions, variables, it, vertex, nc, n, progress)
4849
}
4950

5051
#[pyfn(m, "partial_cp")]
@@ -69,39 +70,45 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> {
6970

7071
#[pyfn(m, "rank_accuracy")]
7172
fn rank_accuracy(
73+
py: Python,
7274
costs: Vec<Vec<f64>>,
7375
key: Vec<usize>,
7476
acc: f64,
7577
merge: Option<usize>,
7678
method: String,
7779
max_nb_bin: usize,
7880
) -> PyResult<(f64, f64, f64)> {
79-
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
80-
let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin);
81-
match res {
82-
Ok(res) => Ok((res.min, res.est, res.max)),
83-
Err(s) => {
84-
panic!("{}", s);
81+
py.allow_threads(|| {
82+
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
83+
let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin);
84+
match res {
85+
Ok(res) => Ok((res.min, res.est, res.max)),
86+
Err(s) => {
87+
panic!("{}", s);
88+
}
8589
}
86-
}
90+
})
8791
}
8892

8993
#[pyfn(m, "rank_nbin")]
9094
fn rank_nbin(
95+
py: Python,
9196
costs: Vec<Vec<f64>>,
9297
key: Vec<usize>,
9398
nb_bin: usize,
9499
merge: Option<usize>,
95100
method: String,
96101
) -> PyResult<(f64, f64, f64)> {
97-
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
98-
let res = res.rank_nbin(&costs, &key, nb_bin, merge);
99-
match res {
100-
Ok(res) => Ok((res.min, res.est, res.max)),
101-
Err(s) => {
102-
panic!("{}", s);
102+
py.allow_threads(|| {
103+
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
104+
let res = res.rank_nbin(&costs, &key, nb_bin, merge);
105+
match res {
106+
Ok(res) => Ok((res.min, res.est, res.max)),
107+
Err(s) => {
108+
panic!("{}", s);
109+
}
103110
}
104-
}
111+
})
105112
}
106113

107114
Ok(())

src/scalib_ext/scalib/src/belief_propagation.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ pub fn run_bp(
351351
nc: usize,
352352
// number of copies in the graph (n_runs)
353353
n: usize,
354+
// show a progress bar
355+
progress: bool,
354356
) -> Result<(), ()> {
355357
// Scratch array containing all the edge's messages.
356358
let mut edges: Vec<Array2<f64>> = vec![Array2::<f64>::ones((n, nc)); edge];
@@ -386,15 +388,7 @@ pub fn run_bp(
386388
}
387389
}
388390

389-
// loading bar
390-
let pb = ProgressBar::new(it as u64);
391-
pb.set_style(ProgressStyle::default_spinner().template(
392-
"{msg} {spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] ({pos}/{len}, ETA {eta})",
393-
)
394-
.on_finish(ProgressFinish::AndClear));
395-
pb.set_message("Calculating BP...");
396-
397-
for _ in (0..it).progress_with(pb) {
391+
let mut bp_iter = || {
398392
// This is a technique for runtime borrow-checking: we take reference on all the edges
399393
// at once, put them into options, then extract the references out of the options, one
400394
// at a time and out-of-order.
@@ -422,6 +416,23 @@ pub fn run_bp(
422416
})
423417
.collect();
424418
update_variables(&mut edge_for_var, variables);
419+
};
420+
421+
if progress {
422+
// loading bar
423+
let pb = ProgressBar::new(it as u64);
424+
pb.set_style(ProgressStyle::default_spinner().template(
425+
"{msg} {spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] ({pos}/{len}, ETA {eta})",
426+
)
427+
.on_finish(ProgressFinish::AndClear));
428+
pb.set_message("Calculating BP...");
429+
for _ in (0..it).progress_with(pb) {
430+
bp_iter();
431+
}
432+
} else {
433+
for _ in 0..it {
434+
bp_iter();
435+
}
425436
}
426437

427438
Ok(())

0 commit comments

Comments
 (0)