Skip to content

Commit

Permalink
Fix bugs in mathematical functions using binary circuits.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Feb 4, 2022
1 parent d50e97f commit 61d40b7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
10 changes: 9 additions & 1 deletion Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def n_bits(self):
def store_in_mem(self, address):
for i, x in enumerate(self.elements()):
x.store_in_mem(address + i)
def bit_decompose(self, n_bits=None, security=None):
def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None):
return self.v[:n_bits]
bit_compose = from_vec
def reveal(self):
Expand Down Expand Up @@ -1267,6 +1267,9 @@ class sbitfixvec(_fix):
int_type = sbitintvec.get_type(sbitfix.k)
float_type = type(None)
clear_type = cbitfix
@property
def bit_type(self):
return type(self.v[0])
@classmethod
def set_precision(cls, f, k=None):
super(sbitfixvec, cls).set_precision(f=f, k=k)
Expand All @@ -1284,6 +1287,8 @@ def __init__(self, value=None, *args, **kwargs):
if isinstance(value, (list, tuple)):
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]))
else:
if isinstance(value, sbitvec):
value = self.int_type(value)
super(sbitfixvec, self).__init__(value, *args, **kwargs)
def elements(self):
return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()]
Expand All @@ -1293,9 +1298,12 @@ def mul(self, other):
else:
return super(sbitfixvec, self).mul(other)
def __xor__(self, other):
if util.is_zero(other):
return self
return self._new(self.v ^ other.v)
def __and__(self, other):
return self._new(self.v & other.v)
__rxor__ = __xor__
@staticmethod
def multipliable(other, k, f, size):
class cls(_fix):
Expand Down
10 changes: 5 additions & 5 deletions Compiler/mpc_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class my_fix(type(a)):
# how many bits to use from integer part
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
n_bits = a.f + n_int_bits
sint = types.sint
sint = a.int_type
if types.program.options.ring and not as19:
intbitint = types.intbitint
n_shift = int(types.program.options.ring) - a.k
Expand Down Expand Up @@ -367,17 +367,17 @@ class my_fix(type(a)):
bits = a.v.bit_decompose(a.k, maybe_mixed=True)
lower = sint.bit_compose(bits[:a.f])
higher_bits = bits[a.f:n_bits]
s = sint.conv(bits[-1])
s = a.bit_type.conv(bits[-1])
bits_to_check = bits[n_bits:-1]
if not as19:
c = types.sfix._new(lower, k=a.k, f=a.f)
c = a._new(lower, k=a.k, f=a.f)
assert(len(higher_bits) == n_bits - a.f)
pow2_bits = [sint.conv(x) for x in higher_bits]
d = floatingpoint.Pow2_from_bits(pow2_bits)
g = exp_from_parts(d, c)
small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits,
2 ** n_int_bits, signed=False,
nearest=types.sfix.round_nearest),
nearest=a.round_nearest),
k=a.k, f=a.f)
if zero_output:
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
Expand Down
1 change: 1 addition & 0 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4274,6 +4274,7 @@ class sfix(_fix):
:params _v: int/float/regint/cint/sint/sfloat
"""
int_type = sint
bit_type = sintbit
clear_type = cfix

@vectorized_classmethod
Expand Down

0 comments on commit 61d40b7

Please sign in to comment.