diff --git a/cairo_programs/squash_dict.cairo b/cairo_programs/squash_dict.cairo index bfe4ccfb..f8c25f42 100644 --- a/cairo_programs/squash_dict.cairo +++ b/cairo_programs/squash_dict.cairo @@ -64,6 +64,7 @@ func squash_dict{range_check_ptr}( squashed_dict=squashed_dict, big_keys=big_keys) %{ + #TEST vm_exit_scope() %} return (squashed_dict=squashed_dict) @@ -173,7 +174,10 @@ func squash_dict_inner( [ap] = dict_accesses_end_minus1 - cast(last_loop_locals.access_ptr, felt) [ap] = [last_loop_locals.range_check_ptr]; ap++ tempvar n_used_accesses = last_loop_locals.range_check_ptr - range_check_ptr - %{ assert ids.n_used_accesses == len(access_indices[key]) %} + %{ + #TEST + assert ids.n_used_accesses == len(access_indices[key]) + %} # Write last value to dict_diff. last_loop_locals.value = dict_diff.new_value diff --git a/hints_tests.py b/hints_tests.py index 0f9d1868..6a77380a 100644 --- a/hints_tests.py +++ b/hints_tests.py @@ -32,11 +32,11 @@ def test_program(program_name: str): test_program("memcpy") test_program("memset") test_program("dict_new") - # test_program("dict_read") # Waiting on starkware PR - # test_program("dict_write") # ValueError: Custom Hint Error: AttributeError: 'PyTypeId' object has no attribute 'segment_index' - # test_program("dict_update") # Waiting on starkware PR + test_program("dict_read") + test_program("dict_write") + test_program("dict_update") test_program("default_dict_new") - # test_program("squash_dict") # ValueError: Custom Hint Error: ValueError: Failed to get ids value + test_program("squash_dict") # test_program("dict_squash") # Custom Hint Error: AttributeError: 'PyTypeId' object has no attribute 'segment_index' test_program("ids_size") test_program("split_felt") diff --git a/src/cairo_run.rs b/src/cairo_run.rs index 37882435..d97fe0bb 100644 --- a/src/cairo_run.rs +++ b/src/cairo_run.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod test { use crate::cairo_runner::PyCairoRunner; + use pyo3::Python; use std::fs; #[test] @@ -9,9 +10,11 @@ mod test { let program = fs::read_to_string(path).unwrap(); let mut runner = PyCairoRunner::new(program, Some("main".to_string()), None, false).unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .expect("Couldn't run program"); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .expect("Couldn't run program"); + }); } #[test] @@ -25,9 +28,11 @@ mod test { false, ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .expect("Couldn't run program"); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .expect("Couldn't run program"); + }); } #[test] @@ -36,8 +41,10 @@ mod test { let program = fs::read_to_string(path).unwrap(); let mut runner = PyCairoRunner::new(program, Some("main".to_string()), None, false).unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .expect("Couldn't run program"); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .expect("Couldn't run program"); + }); } } diff --git a/src/cairo_runner.rs b/src/cairo_runner.rs index ce4e70ad..044315fe 100644 --- a/src/cairo_runner.rs +++ b/src/cairo_runner.rs @@ -1,4 +1,5 @@ use crate::{ + dict_manager::PyDictManager, relocatable::{PyMaybeRelocatable, PyRelocatable}, utils::to_py_error, vm_core::PyVM, @@ -77,12 +78,14 @@ impl PyCairoRunner { } #[pyo3(name = "cairo_run")] + #[allow(clippy::too_many_arguments)] pub fn cairo_run_py( &mut self, print_output: bool, trace_file: Option<&str>, memory_file: Option<&str>, hint_locals: Option>, + py: Python, static_locals: Option>, entrypoint: Option<&str>, ) -> PyResult<()> { @@ -97,6 +100,10 @@ impl PyCairoRunner { if let Some(locals) = hint_locals { self.hint_locals = locals } + self.hint_locals.insert( + "__dict_manager".to_string(), + PyDictManager::new().into_py(py), + ); self.pyvm.static_locals = static_locals; @@ -651,9 +658,11 @@ mod test { ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); + }); let new_segment = runner.add_segment(); assert_eq!( new_segment, @@ -684,9 +693,11 @@ mod test { ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); + }); let expected_output: Vec = vec![RelocatableValue(PyRelocatable { segment_index: 2, @@ -716,22 +727,22 @@ mod test { ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); - let expected_output: Vec = vec![ - RelocatableValue(PyRelocatable { - segment_index: 2, - offset: 0, - }), - RelocatableValue(PyRelocatable { - segment_index: 3, - offset: 0, - }), - ]; + let expected_output: Vec = vec![ + RelocatableValue(PyRelocatable { + segment_index: 2, + offset: 0, + }), + RelocatableValue(PyRelocatable { + segment_index: 3, + offset: 0, + }), + ]; - Python::with_gil(|py| { assert_eq!( runner .get_program_builtins_initial_stack(py) @@ -754,9 +765,11 @@ mod test { ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); + }); let expected_output = PyRelocatable::from((1, 8)); @@ -778,9 +791,11 @@ mod test { false, ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, Some("main")) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, Some("main")) + .unwrap(); + }); // Make a copy of the builtin in order to insert a second "fake" one // BuiltinRunner api is private, so we can create a new one for this test let fake_builtin = (*runner.pyvm.vm).borrow_mut().get_builtin_runners_as_mut()[0] @@ -821,9 +836,11 @@ mod test { ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); + }); let expected_output = PyRelocatable::from((1, 0)); @@ -845,9 +862,11 @@ mod test { ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); + }); // Make a copy of the builtin in order to insert a second "fake" one // BuiltinRunner api is private, so we can create a new one for this test @@ -883,9 +902,11 @@ mod test { ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); + }); assert_eq!(runner.pyvm.vm.borrow().get_ap(), Relocatable::from((1, 41))); assert_eq!( @@ -924,10 +945,10 @@ mod test { let program = fs::read_to_string(path).unwrap(); let mut runner = PyCairoRunner::new(program, Some("main".to_string()), None, false).unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); assert_eq!( 24, runner @@ -945,10 +966,10 @@ mod test { let program = fs::read_to_string(path).unwrap(); let mut runner = PyCairoRunner::new(program, Some("main".to_string()), None, false).unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); assert_eq!( 0, runner @@ -1161,9 +1182,11 @@ mod test { false, ) .unwrap(); - runner - .cairo_run_py(false, None, None, None, None, None) - .unwrap(); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, None) + .unwrap(); + }); assert_eq! { PyRelocatable::from((1,2)), runner.get_initial_fp().unwrap() @@ -1370,19 +1393,22 @@ mod test { false, ) .unwrap(); - assert!(runner - .cairo_run_py( - false, - None, - None, - None, - Some(HashMap::from([( - "__find_element_max_size".to_string(), - Python::with_gil(|py| -> PyObject { 100.to_object(py) }), - )])), - None, - ) - .is_ok()); + Python::with_gil(|py| { + assert!(runner + .cairo_run_py( + false, + None, + None, + None, + py, + Some(HashMap::from([( + "__find_element_max_size".to_string(), + 100.to_object(py) + ),])), + None + ) + .is_ok()); + }); } #[test] @@ -1396,19 +1422,22 @@ mod test { false, ) .unwrap(); - assert!(runner - .cairo_run_py( - false, - None, - None, - None, - Some(HashMap::from([( - "__find_element_max_size".to_string(), - Python::with_gil(|py| -> PyObject { 1.to_object(py) }), - )])), - None - ) - .is_err()); + Python::with_gil(|py| { + assert!(runner + .cairo_run_py( + false, + None, + None, + None, + py, + Some(HashMap::from([( + "__find_element_max_size".to_string(), + 1.to_object(py) + ),])), + None + ) + .is_err()); + }); } #[test] @@ -1418,9 +1447,11 @@ mod test { let mut runner = PyCairoRunner::new(program, None, Some("small".to_string()), false).unwrap(); - runner - .cairo_run_py(false, None, None, None, None, Some("main")) - .expect("Call to PyCairoRunner::cairo_run_py() failed."); + Python::with_gil(|py| { + runner + .cairo_run_py(false, None, None, None, py, None, Some("main")) + .expect("Call to PyCairoRunner::cairo_run_py() failed."); + }); } /// Test that `PyCairoRunner::get()` works as intended. @@ -1437,7 +1468,7 @@ mod test { .unwrap(); runner - .cairo_run_py(false, None, None, None, None, None) + .cairo_run_py(false, None, None, None, py, None, None) .expect("Call to PyCairoRunner::cairo_run_py"); let mut ap = runner.get_ap().unwrap(); diff --git a/src/dict_manager.rs b/src/dict_manager.rs new file mode 100644 index 00000000..dfd9e5a0 --- /dev/null +++ b/src/dict_manager.rs @@ -0,0 +1,161 @@ +use crate::{ + ids::PyTypedId, + memory_segments::PySegmentManager, + relocatable::{PyMaybeRelocatable, PyRelocatable}, + utils::to_py_error, +}; +use cairo_rs::{ + hint_processor::builtin_hint_processor::dict_manager::DictManager, + types::relocatable::Relocatable, +}; +use num_bigint::BigInt; +use pyo3::{exceptions::PyKeyError, prelude::*}; + +use std::{cell::RefCell, collections::HashMap, rc::Rc}; + +#[pyclass(unsendable)] +pub struct PyDictManager { + manager: Rc>, +} + +#[pyclass(unsendable)] +pub struct PyDictTracker { + manager: Rc>, + key: Relocatable, +} + +impl Default for PyDictManager { + fn default() -> Self { + PyDictManager::new() + } +} + +#[pymethods] +impl PyDictManager { + #[new] + pub fn new() -> Self { + PyDictManager { + manager: Rc::new(RefCell::new(DictManager::new())), + } + } + + pub fn new_dict( + &self, + segments: &mut PySegmentManager, + initial_dict: HashMap, + py: Python, + ) -> PyResult { + Ok(PyMaybeRelocatable::from( + self.manager + .borrow_mut() + .new_dict(&mut segments.vm.borrow_mut(), initial_dict) + .map_err(to_py_error)?, + ) + .to_object(py)) + } + + pub fn new_default_dict( + &mut self, + segments: &mut PySegmentManager, + default_value: BigInt, + initial_dict: Option>, + py: Python, + ) -> PyResult { + Ok(PyMaybeRelocatable::from( + self.manager + .borrow_mut() + .new_default_dict(&mut segments.vm.borrow_mut(), &default_value, initial_dict) + .map_err(to_py_error)?, + ) + .to_object(py)) + } + + pub fn get_tracker(&mut self, dict_ptr: &PyTypedId) -> PyResult { + let ptr_addr = dict_ptr.hint_value.clone(); + self.manager + .borrow() + .get_tracker(&ptr_addr) + .map_err(to_py_error)?; + Ok(PyDictTracker { + manager: self.manager.clone(), + key: ptr_addr, + }) + } +} + +#[pymethods] +impl PyDictTracker { + #[getter] + pub fn get_current_ptr(&self, py: Python) -> PyResult { + Ok(PyRelocatable::from( + self.manager + .borrow_mut() + .get_tracker_mut(&self.key) + .map_err(to_py_error)? + .current_ptr + .clone(), + ) + .into_py(py)) + } + + #[getter] + pub fn get_data(&self, py: Python) -> PyObject { + PyDictTracker { + manager: self.manager.clone(), + key: self.key.clone(), + } + .into_py(py) + } + + #[setter] + pub fn set_current_ptr(&mut self, val: PyRelocatable) -> PyResult<()> { + self.manager + .borrow_mut() + .get_tracker_mut(&self.key) + .map_err(to_py_error)? + .current_ptr + .offset = val.offset; + self.key = Relocatable { + segment_index: val.segment_index, + offset: val.offset, + }; + Ok(()) + } + + #[getter] + pub fn __getitem__(&self, key: PyMaybeRelocatable, py: Python) -> PyResult { + match key { + PyMaybeRelocatable::Int(key) => Ok(PyMaybeRelocatable::from( + self.manager + .borrow_mut() + .get_tracker_mut(&self.key) + .map_err(to_py_error)? + .get_value(&key) + .map_err(to_py_error)?, + ) + .to_object(py)), + PyMaybeRelocatable::RelocatableValue(_) => Err(PyKeyError::new_err(key.to_object(py))), + } + } + + #[setter] + pub fn __setitem__( + &mut self, + key: PyMaybeRelocatable, + val: PyMaybeRelocatable, + py: Python, + ) -> PyResult<()> { + match (&key, &val) { + (PyMaybeRelocatable::Int(key), PyMaybeRelocatable::Int(val)) => { + self.manager + .borrow_mut() + .get_tracker_mut(&self.key) + .map_err(to_py_error)? + .insert_value(key, val); + + Ok(()) + } + _ => Err(PyKeyError::new_err(key.to_object(py))), + } + } +} diff --git a/src/ids.rs b/src/ids.rs index 9488a139..e3814752 100644 --- a/src/ids.rs +++ b/src/ids.rs @@ -84,14 +84,6 @@ impl PyIds { .ok_or_else(|| to_py_error(IDS_GET_ERROR_MSG))?; if let Some(cairo_type) = hint_ref.cairo_type.as_deref() { - let chars = cairo_type.chars().rev(); - let clear_ref = chars - .skip_while(|c| c == &'*') - .collect::() - .chars() - .rev() - .collect::(); - if self.struct_types.contains_key(cairo_type) { return Ok(PyTypedId { vm: self.vm.clone(), @@ -104,7 +96,10 @@ impl PyIds { struct_types: Rc::clone(&self.struct_types), } .into_py(py)); - } else if self.struct_types.contains_key(&clear_ref) { + } else if self + .struct_types + .contains_key(cairo_type.trim_end_matches('*')) + { let addr = compute_addr_from_reference(hint_ref, &self.vm.borrow(), &self.ap_tracking)?; @@ -118,7 +113,7 @@ impl PyIds { return Ok(PyTypedId { vm: self.vm.clone(), hint_value, - cairo_type: cairo_type.to_string(), + cairo_type: cairo_type.trim_end_matches('*').to_string(), struct_types: Rc::clone(&self.struct_types), } .into_py(py)); @@ -171,9 +166,9 @@ struct CairoStruct { } #[pyclass(unsendable)] -struct PyTypedId { +pub(crate) struct PyTypedId { vm: Rc>, - hint_value: Relocatable, + pub hint_value: Relocatable, cairo_type: String, struct_types: Rc>>, } diff --git a/src/lib.rs b/src/lib.rs index 204dcdfd..a4d3e45b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod cairo_run; pub mod cairo_runner; +mod dict_manager; mod ecdsa; pub mod ids; mod memory; diff --git a/src/memory_segments.rs b/src/memory_segments.rs index 353f97ec..356c24b6 100644 --- a/src/memory_segments.rs +++ b/src/memory_segments.rs @@ -10,7 +10,7 @@ use std::{cell::RefCell, rc::Rc}; #[pyclass(name = "MemorySegmentManager", unsendable)] pub struct PySegmentManager { - vm: Rc>, + pub(crate) vm: Rc>, #[pyo3(get)] memory: PyMemory, } diff --git a/src/relocatable.rs b/src/relocatable.rs index a339913f..33eb1ab6 100644 --- a/src/relocatable.rs +++ b/src/relocatable.rs @@ -182,6 +182,15 @@ impl From for PyMaybeRelocatable { } } +impl From<&Relocatable> for PyMaybeRelocatable { + fn from(val: &Relocatable) -> Self { + PyMaybeRelocatable::RelocatableValue(PyRelocatable { + segment_index: val.segment_index, + offset: val.offset, + }) + } +} + impl From for PyMaybeRelocatable { fn from(val: PyRelocatable) -> Self { PyMaybeRelocatable::RelocatableValue(val)