diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 94d520825..13619c7f9 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -811,7 +811,7 @@ def n_bits(self): def store_in_mem(self, address): for i, x in enumerate(self.elements()): x.store_in_mem(address + i) - def bit_decompose(self, n_bits=None, security=None): + def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None): return self.v[:n_bits] bit_compose = from_vec def reveal(self): @@ -1267,6 +1267,9 @@ class sbitfixvec(_fix): int_type = sbitintvec.get_type(sbitfix.k) float_type = type(None) clear_type = cbitfix + @property + def bit_type(self): + return type(self.v[0]) @classmethod def set_precision(cls, f, k=None): super(sbitfixvec, cls).set_precision(f=f, k=k) @@ -1284,6 +1287,8 @@ def __init__(self, value=None, *args, **kwargs): if isinstance(value, (list, tuple)): self.v = self.int_type.from_vec(sbitvec([x.v for x in value])) else: + if isinstance(value, sbitvec): + value = self.int_type(value) super(sbitfixvec, self).__init__(value, *args, **kwargs) def elements(self): return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()] @@ -1293,9 +1298,12 @@ def mul(self, other): else: return super(sbitfixvec, self).mul(other) def __xor__(self, other): + if util.is_zero(other): + return self return self._new(self.v ^ other.v) def __and__(self, other): return self._new(self.v & other.v) + __rxor__ = __xor__ @staticmethod def multipliable(other, k, f, size): class cls(_fix): diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 322989b34..47253dc43 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -290,7 +290,7 @@ class my_fix(type(a)): # how many bits to use from integer part n_int_bits = int(math.ceil(math.log(a.k - a.f, 2))) n_bits = a.f + n_int_bits - sint = types.sint + sint = a.int_type if types.program.options.ring and not as19: intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k @@ -367,17 +367,17 @@ class my_fix(type(a)): bits = a.v.bit_decompose(a.k, maybe_mixed=True) lower = sint.bit_compose(bits[:a.f]) higher_bits = bits[a.f:n_bits] - s = sint.conv(bits[-1]) + s = a.bit_type.conv(bits[-1]) bits_to_check = bits[n_bits:-1] if not as19: - c = types.sfix._new(lower, k=a.k, f=a.f) + c = a._new(lower, k=a.k, f=a.f) assert(len(higher_bits) == n_bits - a.f) pow2_bits = [sint.conv(x) for x in higher_bits] d = floatingpoint.Pow2_from_bits(pow2_bits) g = exp_from_parts(d, c) - small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits, + small_result = a._new(g.v.round(a.f + 2 ** n_int_bits, 2 ** n_int_bits, signed=False, - nearest=types.sfix.round_nearest), + nearest=a.round_nearest), k=a.k, f=a.f) if zero_output: t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y), diff --git a/Compiler/types.py b/Compiler/types.py index 3fdc6cf05..0063fdc16 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -4274,6 +4274,7 @@ class sfix(_fix): :params _v: int/float/regint/cint/sint/sfloat """ int_type = sint + bit_type = sintbit clear_type = cfix @vectorized_classmethod