Skip to content

Commit

Permalink
Run cargo fmt, fix clippy (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattxwang committed Jul 13, 2023
1 parent a3ad51b commit 9c59632
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 31 deletions.
76 changes: 51 additions & 25 deletions src/bayesian_network.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! A graphical representation of a Bayesian network

use std::collections::{HashMap, BTreeMap};
use std::collections::{BTreeMap, HashMap};

use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};

/// maps each variable name to a CPT table
/// - rows are indexed by the current variable's possible values
Expand All @@ -22,6 +22,7 @@ use serde::{Serialize, Deserialize};
/// says Pr(c=c1 | a=a1, b=b1) = 0.1
/// Pr(c=c1 | a=a2, b=b1) = 0.4
/// Pr(c=c1 | a=a1, b=b3) = 0.3
#[allow(clippy::upper_case_acronyms)]
type CPT = HashMap<String, Vec<Vec<f64>>>;
/// maps each variable name to a list of that variable's possible values
type States = HashMap<String, Vec<String>>;
Expand All @@ -34,21 +35,32 @@ pub struct BayesianNetwork {
variables: Vec<String>,
cpts: CPT,
states: States,
parents: Parents
parents: Parents,
}

impl BayesianNetwork {
pub fn from_string(fname: &str) -> BayesianNetwork {
serde_json::from_str(&fname).unwrap()
serde_json::from_str(fname).unwrap()
}

fn get_state_index(&self, variable: &String, assignment: &String) -> usize {
let cur_s = self.states.get(variable).unwrap_or_else(|| panic!("could not find variable {variable}"));
cur_s.into_iter().position(|x| *x == *assignment).unwrap_or_else(|| panic!("could not find assignment {assignment} for variable {variable}"))
let cur_s = self
.states
.get(variable)
.unwrap_or_else(|| panic!("could not find variable {variable}"));
cur_s
.iter()
.position(|x| *x == *assignment)
.unwrap_or_else(|| {
panic!("could not find assignment {assignment} for variable {variable}")
})
}

fn get_num_states(&self, variable: &String) -> usize {
let cur_s = self.states.get(variable).unwrap_or_else(|| panic!("could not find variable {variable}"));
let cur_s = self
.states
.get(variable)
.unwrap_or_else(|| panic!("could not find variable {variable}"));
cur_s.len()
}

Expand All @@ -62,15 +74,15 @@ impl BayesianNetwork {
return vec![HashMap::new()];
}

let mut r : Vec<HashMap<String, String>> = Vec::new();
let mut r: Vec<HashMap<String, String>> = Vec::new();
let p = cur_parents.pop().unwrap();
let cur_values = self.get_all_assignments(&p);
let sub = self.parent_h(cur_parents);

// add each assignment onto
for v in cur_values {
for i in 0..(sub.len()) {
let mut new_s = sub[i].clone();
for sub_i in sub.iter() {
let mut new_s = sub_i.clone();
new_s.insert(p.clone(), v.clone());
r.push(new_s);
}
Expand All @@ -94,36 +106,46 @@ impl BayesianNetwork {
}

/// Get the conditional probability Pr(variable = variable_value | parent_assignment)
pub fn get_conditional_prob(&self, variable: &String, variable_value: &String, parent_assignment: &HashMap<String, String>) -> f64 {
pub fn get_conditional_prob(
&self,
variable: &String,
variable_value: &String,
parent_assignment: &HashMap<String, String>,
) -> f64 {
let var_idx = self.get_state_index(variable, variable_value);
let row = &self.cpts[variable][var_idx];
// compute the index into the row
let parents = self.parents.get(variable).unwrap();
let mut cur_stride = 1;
let mut idx = 0;
for parent in parents.into_iter().rev() {
for parent in parents.iter().rev() {
let parent_assgn = &parent_assignment[parent];
let parent_asggn_idx = self.get_state_index(parent, parent_assgn);
idx += cur_stride * parent_asggn_idx;
let parent_sz = self.get_num_states(parent);
cur_stride = cur_stride * parent_sz;
cur_stride *= parent_sz;
}
row[idx]
}

/// Produces a list of variables in topological order
pub fn topological_sort(&self) -> Vec<String> {
// super naive toposort
let mut result : Vec<String> = Vec::new();
let mut cur_vars: BTreeMap<String, Vec<String>> = self.variables.iter().map(|v| {
(v.clone(), self.parents[v].clone())
}).collect();
let mut result: Vec<String> = Vec::new();
let mut cur_vars: BTreeMap<String, Vec<String>> = self
.variables
.iter()
.map(|v| (v.clone(), self.parents[v].clone()))
.collect();

while !cur_vars.is_empty() {
// find a variable with no parents, remove it, add it to the result
// list, and remove it as a parent from all other nodes
let topvar : String = {
let (topvar, _) = cur_vars.iter().find(|(_, value)| { value.is_empty() }).unwrap_or_else(|| panic!("graph not topologically sortable"));
let topvar: String = {
let (topvar, _) = cur_vars
.iter()
.find(|(_, value)| value.is_empty())
.unwrap_or_else(|| panic!("graph not topologically sortable"));
topvar.clone()
};
result.push(topvar.clone());
Expand All @@ -136,19 +158,23 @@ impl BayesianNetwork {
}
}


#[test]
fn test_conditional() {
let sachs = include_str!("../bayesian_networks/sachs.json");
let network = BayesianNetwork::from_string(&sachs);
let parent_assgn = HashMap::from([ (String::from("Erk"), String::from("HIGH")),
(String::from("PKA"), String::from("AVG")) ]);
assert_eq!(network.get_conditional_prob(&String::from("Akt"), &String::from("LOW"), &parent_assgn),0.177105936);
let network = BayesianNetwork::from_string(sachs);
let parent_assgn = HashMap::from([
(String::from("Erk"), String::from("HIGH")),
(String::from("PKA"), String::from("AVG")),
]);
assert_eq!(
network.get_conditional_prob(&String::from("Akt"), &String::from("LOW"), &parent_assgn),
0.177105936
);
}

#[test]
fn test_parent() {
let sachs = include_str!("../bayesian_networks/sachs.json");
let network = BayesianNetwork::from_string(&sachs);
let network = BayesianNetwork::from_string(sachs);
println!("{:?}", network.parent_assignments(&String::from("Erk")));
}
11 changes: 5 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
extern crate serde;
extern crate clap;
extern crate serde;
extern crate serde_json;
use std::{fs};
use std::fs;

use clap::Parser;

/// Processes a JSON Bayesian network generated by `converter`
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
/// A JSON Bayesian network file
#[clap(short, long, value_parser)]
file: String,
/// A JSON Bayesian network file
#[clap(short, long, value_parser)]
file: String,
}


pub mod bayesian_network;
fn main() {
let args = Args::parse();
Expand Down

0 comments on commit 9c59632

Please sign in to comment.