Skip to content

Commit 56490ea

Browse files
pefontanaPedro Fontana
andauthored
Add ecdsa builtin (#140)
* Add Hash impl to Pyrelocatable * Add ecdsa Builtin * Add ecdsa.cairo * Get ecdsa_builtin from globals * Add mod ecdsa * Add __getattr__ for refenrences * Update ecdsa.cairo * Update ecdsa.cairo * Update ecdsa.cairo * Remove prints * Update ecdsa.cairo * Update hints_tests.py * Add error hundling to add_signature * Modify execute_hint flow * Update Cargo.toml * Update ecdsa.cairo * Remove unwrap() * Update hints_tests.py * ecdsa_builtin.borrow() * Exclude dict_read and dict_update.cairo tests * cargo clippy Co-authored-by: Pedro Fontana <[email protected]>
1 parent b4674ea commit 56490ea

File tree

11 files changed

+137
-15
lines changed

11 files changed

+137
-15
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ default = ["pyo3/num-bigint", "pyo3/auto-initialize"]
1313

1414
[dependencies]
1515
pyo3 = { version = "0.16.5" }
16-
cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "4f36aaf46dea8cac158d0da5e80537388e048c01" }
16+
cairo-rs = { git = "https://github.com/lambdaclass/cairo-rs.git", rev = "8c47dda53e874545895b34d675be6254878a9e7b" }
1717
num-bigint = "0.4"
1818
lazy_static = "1.4.0"
1919

cairo_programs/ecdsa.cairo

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
%builtins output pedersen ecdsa
2+
3+
from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin
4+
from starkware.cairo.common.hash import hash2
5+
from starkware.cairo.common.signature import verify_ecdsa_signature
6+
7+
func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, ecdsa_ptr : SignatureBuiltin*}():
8+
alloc_locals
9+
10+
let your_eth_addr = 874739451078007766457464989774322083649278607533249481151382481072868806602
11+
let signature_r = 1839793652349538280924927302501143912227271479439798783640887258675143576352
12+
let signature_s = 1819432147005223164874083361865404672584671743718628757598322238853218813979
13+
let msg = 0000000000000000000000000000000000000000000000000000000000000002
14+
15+
verify_ecdsa_signature(
16+
msg,
17+
your_eth_addr,
18+
signature_r,
19+
signature_s,
20+
)
21+
22+
23+
assert [output_ptr] = your_eth_addr
24+
let output_ptr = output_ptr + 1
25+
26+
return ()
27+
end

comparer_tracer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def new_runner(program_name: str):
88

99
if __name__ == "__main__":
1010
program_name = sys.argv[1]
11-
if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write"]:
11+
if program_name in ["blake2s_felt", "blake2s_finalize", "blake2s_integration_tests", "blake2s_hello_world_hash", "dict_squash", "squash_dict", "dict_write", "dict_read", "dict_update"]:
1212
pass
1313
else:
1414
new_runner(program_name)

hints_tests.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def test_program(program_name: str):
3232
test_program("memcpy")
3333
test_program("memset")
3434
test_program("dict_new")
35-
test_program("dict_read")
35+
# test_program("dict_read") # Waiting on starkware PR
3636
# test_program("dict_write") # ValueError: Custom Hint Error: AttributeError: 'PyTypeId' object has no attribute 'segment_index'
37-
test_program("dict_update")
37+
# test_program("dict_update") # Waiting on starkware PR
3838
test_program("default_dict_new")
3939
# test_program("squash_dict") # ValueError: Custom Hint Error: ValueError: Failed to get ids value
4040
# test_program("dict_squash") # Custom Hint Error: AttributeError: 'PyTypeId' object has no attribute 'segment_index'
@@ -69,4 +69,5 @@ def test_program(program_name: str):
6969
test_program("blake2s_finalize")
7070
test_program("blake2s_felt")
7171
test_program("blake2s_integration_tests")
72+
test_program("ecdsa")
7273
print("\nAll test have passed")

src/ecdsa.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use std::collections::HashMap;
2+
3+
use cairo_rs::{
4+
types::relocatable::Relocatable,
5+
vm::{errors::vm_errors::VirtualMachineError, runners::builtin_runner::SignatureBuiltinRunner},
6+
};
7+
8+
use num_bigint::BigInt;
9+
use pyo3::prelude::*;
10+
11+
use crate::relocatable::PyRelocatable;
12+
13+
#[pyclass(name = "Signature")]
14+
#[derive(Clone, Debug)]
15+
pub struct PySignature {
16+
signatures: HashMap<PyRelocatable, (BigInt, BigInt)>,
17+
}
18+
19+
#[pymethods]
20+
impl PySignature {
21+
#[new]
22+
pub fn new() -> Self {
23+
Self {
24+
signatures: HashMap::new(),
25+
}
26+
}
27+
28+
pub fn add_signature(&mut self, address: PyRelocatable, pair: (BigInt, BigInt)) {
29+
self.signatures.insert(address, pair);
30+
}
31+
}
32+
33+
impl PySignature {
34+
pub fn update_signature(
35+
&self,
36+
signature_builtin: &mut SignatureBuiltinRunner,
37+
) -> Result<(), VirtualMachineError> {
38+
for (address, pair) in self.signatures.iter() {
39+
signature_builtin
40+
.add_signature(Relocatable::from(address), pair)
41+
.map_err(VirtualMachineError::MemoryError)?
42+
}
43+
Ok(())
44+
}
45+
}
46+
47+
impl Default for PySignature {
48+
fn default() -> Self {
49+
Self::new()
50+
}
51+
}
52+
53+
impl ToPyObject for PySignature {
54+
fn to_object(&self, py: Python<'_>) -> PyObject {
55+
self.clone().into_py(py)
56+
}
57+
}

src/ids.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ impl PyIds {
8484
.ok_or_else(|| to_py_error(IDS_GET_ERROR_MSG))?;
8585

8686
if let Some(cairo_type) = hint_ref.cairo_type.as_deref() {
87+
let chars = cairo_type.chars().rev();
88+
let clear_ref = chars
89+
.skip_while(|c| c == &'*')
90+
.collect::<String>()
91+
.chars()
92+
.rev()
93+
.collect::<String>();
94+
8795
if self.struct_types.contains_key(cairo_type) {
8896
return Ok(PyTypedId {
8997
vm: self.vm.clone(),
@@ -96,6 +104,24 @@ impl PyIds {
96104
struct_types: Rc::clone(&self.struct_types),
97105
}
98106
.into_py(py));
107+
} else if self.struct_types.contains_key(&clear_ref) {
108+
let addr =
109+
compute_addr_from_reference(hint_ref, &self.vm.borrow(), &self.ap_tracking)?;
110+
111+
let hint_value = self
112+
.vm
113+
.borrow()
114+
.get_relocatable(&addr)
115+
.map_err(to_py_error)?
116+
.into_owned();
117+
118+
return Ok(PyTypedId {
119+
vm: self.vm.clone(),
120+
hint_value,
121+
cairo_type: cairo_type.to_string(),
122+
struct_types: Rc::clone(&self.struct_types),
123+
}
124+
.into_py(py));
99125
}
100126
}
101127

@@ -156,11 +182,10 @@ struct PyTypedId {
156182
impl PyTypedId {
157183
#[getter]
158184
fn __getattr__(&self, py: Python, name: &str) -> PyResult<PyObject> {
159-
let struct_type = self.struct_types.get(&self.cairo_type).unwrap();
160-
161185
if name == "address_" {
162186
return Ok(PyMaybeRelocatable::from(self.hint_value.clone()).to_object(py));
163187
}
188+
let struct_type = self.struct_types.get(&self.cairo_type).unwrap();
164189

165190
match struct_type.get(name) {
166191
Some(member) => {

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod cairo_run;
22
pub mod cairo_runner;
3+
mod ecdsa;
34
pub mod ids;
45
mod memory;
56
mod memory_segments;

src/relocatable.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub enum PyMaybeRelocatable {
1818
}
1919

2020
#[pyclass(name = "Relocatable")]
21-
#[derive(Clone, Debug, PartialEq, Eq)]
21+
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
2222
pub struct PyRelocatable {
2323
#[pyo3(get)]
2424
pub segment_index: isize,

src/vm_core.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::ecdsa::PySignature;
12
use crate::ids::PyIds;
23
use crate::pycell;
34
use crate::scope_manager::{PyEnterScope, PyExitScope};
@@ -23,7 +24,7 @@ use std::any::Any;
2324
use std::collections::HashMap;
2425
use std::{cell::RefCell, rc::Rc};
2526

26-
const GLOBAL_NAMES: [&str; 17] = [
27+
const GLOBAL_NAMES: [&str; 18] = [
2728
"memory",
2829
"segments",
2930
"ap",
@@ -33,6 +34,7 @@ const GLOBAL_NAMES: [&str; 17] = [
3334
"vm_exit_scope",
3435
"to_felt_or_relocatable",
3536
"range_check_builtin",
37+
"ecdsa_builtin",
3638
"PRIME",
3739
"__doc__",
3840
"__annotations__",
@@ -76,8 +78,8 @@ impl PyVM {
7678
Python::with_gil(|py| -> Result<(), VirtualMachineError> {
7779
let memory = PyMemory::new(self);
7880
let segments = PySegmentManager::new(self, memory.clone());
79-
let ap = PyRelocatable::from(self.vm.borrow().get_ap());
80-
let fp = PyRelocatable::from(self.vm.borrow().get_fp());
81+
let ap = PyRelocatable::from((*self.vm).borrow().get_ap());
82+
let fp = PyRelocatable::from((*self.vm).borrow().get_fp());
8183
let ids = PyIds::new(
8284
self,
8385
&hint_data.ids_data,
@@ -88,8 +90,9 @@ impl PyVM {
8890
let enter_scope = pycell!(py, PyEnterScope::new());
8991
let exit_scope = pycell!(py, PyExitScope::new());
9092
let range_check_builtin =
91-
PyRangeCheck::from(self.vm.borrow().get_range_check_builtin());
92-
let prime = self.vm.borrow().get_prime().clone();
93+
PyRangeCheck::from((*self.vm).borrow().get_range_check_builtin());
94+
let ecdsa_builtin = pycell!(py, PySignature::new());
95+
let prime = (*self.vm).borrow().get_prime().clone();
9396
let to_felt_or_relocatable = ToFeltOrRelocatableFunc;
9497

9598
// This line imports Python builtins. If not imported, this will run only with Python 3.10
@@ -126,6 +129,9 @@ impl PyVM {
126129
globals
127130
.set_item("range_check_builtin", range_check_builtin)
128131
.map_err(to_vm_error)?;
132+
globals
133+
.set_item("ecdsa_builtin", ecdsa_builtin)
134+
.map_err(to_vm_error)?;
129135
globals.set_item("PRIME", prime).map_err(to_vm_error)?;
130136
globals
131137
.set_item(
@@ -155,6 +161,11 @@ impl PyVM {
155161
py,
156162
);
157163

164+
if self.vm.borrow_mut().get_signature_builtin().is_ok() {
165+
ecdsa_builtin
166+
.borrow()
167+
.update_signature(self.vm.borrow_mut().get_signature_builtin()?)?;
168+
}
158169
enter_scope.borrow().update_scopes(exec_scopes)?;
159170
exit_scope.borrow().update_scopes(exec_scopes)
160171
})?;
@@ -171,7 +182,7 @@ impl PyVM {
171182
struct_types: Rc<HashMap<String, HashMap<String, Member>>>,
172183
constants: &HashMap<String, BigInt>,
173184
) -> Result<(), VirtualMachineError> {
174-
let pc_offset = self.vm.borrow().get_pc().offset;
185+
let pc_offset = (*self.vm).borrow().get_pc().offset;
175186

176187
if let Some(hint_list) = hint_data_dictionary.get(&pc_offset) {
177188
for hint_data in hint_list.iter() {

tests/compare_vm_state.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ for file in $(ls $tests_path | grep .cairo$ | sed -E 's/\.cairo$//'); do
2222
path_file="$tests_path/$file"
2323

2424
echo "$file"
25-
if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ]); then
25+
if ! ([ "$file" = "blake2s_felt" ] || [ "$file" = "blake2s_finalize" ] || [ "$file" = "blake2s_integration_tests" ] || [ "$file" = "blake2s_hello_world_hash" ] || [ "$file" = "dict_squash" ] || [ "$file" = "squash_dict" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_write" ] || [ "$file" = "dict_update" ] || [ "$file" = "dict_read" ]); then
2626
if $trace; then
2727
if ! diff -q $path_file.trace $path_file.rs.trace; then
2828
echo "Traces for $file differ"

0 commit comments

Comments
 (0)