Skip to content

Commit

Permalink
More examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Oct 12, 2018
1 parent cc0711c commit 3c69da1
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ callgrind.out.*

# Compiled source #
###################
Programs/Source/*
Programs/Bytecode/*
Programs/Schedules/*
Programs/Public-Input/*
Expand Down
100 changes: 100 additions & 0 deletions Programs/Source/blink.mpc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import math
import util

n_threads = 64
xor_op = lambda x, y: x ^ y
n_bits = 64
full_t = sbits.get_type(64)
sbits.n = n_bits

if len(program.args) > 1:
n_batches = int(program.args[1])
else:
n_batches = 78

batch_size = 64
n = n_batches * batch_size
l = 16
a = Matrix(n, l, full_t)
b = Matrix(n, l, full_t)
t = sbitint.get_type(int(math.ceil(math.log(batch_size * l, 2))) + 1)
matches = Matrix(n, n, t.bit_type)
mismatches = Matrix(n, n, t)
threshold = MemValue(t(10))

for i in range(n):
for j in range(l):
a[i][j] = full_t.get_input_from(0)
b[i][j] = full_t.get_input_from(1)

# test, create match between a[0] and b[1] but no match for a[1]
a.assign_all(0)
b.assign_all(0)
a[0][0] = -1
b[1][0] = -1
a[1][1] = -1

@for_range_multithread(n_batches, 1, n)
def _(i):
print_ln('%s', i)
@for_range_parallel(100, n_batches)
def _(j):
j = j * batch_size
av = sbitintvec.from_matrix((a[i][kk] for _ in range(batch_size)) \
for kk in range(l))
bv = sbitintvec.from_matrix((b[j + k][kk] for k in range(batch_size)) \
for kk in range(l))
res = xor_op(av, bv).popcnt()
mismatches[i].set_range(j, (t(x) for x in res.elements()))

@for_range_multithread(n_batches, 8, n)
def _(i):
print_ln('%s', i)
@for_range_parallel(100, n_batches)
def _(j):
j = j * batch_size
v = sbitintvec(mismatches[i].get_range(j, batch_size))
vv = sbitintvec([threshold.read()] * batch_size)
matches[i].set_range(j, v.less_than(vv, 10).elements())

mg = MultiArray([n_batches, n, t.n], full_t)
ag = Matrix(n_batches, n, full_t)

@for_range_multithread(n_batches, 1, n_batches)
def _(i):
m = mg[i]
a = ag[i]
i = i * batch_size
print_ln('best %s', i)
@for_range(n)
def _(j):
m[j].assign(sbitintvec(mismatches[i + k][j]
for k in range(batch_size)).v)
m = [sbitintvec.from_vec(m[j]) for j in range(n)]
def reducer(a, b):
c = a[0].less_than(b[0])
return util.if_else(c, (a[0], a[1] + [0] * len(b[1])),
(b[0], [0] * len(a[1]) + b[1]))
mm = util.tree_reduce(reducer, ((x, [2**batch_size - 1]) for x in m))
a.assign(mm[1])
@for_range_parallel(100, len(a))
def _(j):
x = a[j]
pm = sbitintvec(matches[i + k][j] for k in range(batch_size))
x = sbitintvec.from_vec([x])
for k, y in enumerate((pm & x).elements()):
matches[i + k][j] = y

def test(result, expected):
print_ln('%s ?= %s', result.reveal(), expected)

test(matches[0][1], 1)
test(matches[0][0], 0)
test(matches[1][0], 0)
test(matches[1][1], 0)
test(sum(matches[2]), 1)

test(mismatches[0][1], 0)
test(mismatches[0][0], 64)
test(mismatches[1][0], 64)
test(mismatches[1][1], 128)
26 changes: 26 additions & 0 deletions Programs/Source/gc_and.mpc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from Compiler.GC.types import sbits, sbit, cbits


import random

n = 4096
m = 1

if len(program.args) > 1:
n = int(program.args[1])

if len(program.args) > 2:
m = int(program.args[2])

pack = min(n, 50)
n = (n + pack - 1) / pack

a = sbit(1)
b = sbit(1, n=pack)

start_timer(1)
@for_range(m)
def f(_):
for i in range(n):
a * b
stop_timer(1)
44 changes: 44 additions & 0 deletions Programs/Source/gc_fixed_point_tutorial.mpc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
sfix = sbitfix
sint = sbitint.get_type(20)

sfix.set_precision(16, 32)

n = 10
m = 5

# array of fixed points
A = Array(n, sfix)

for i in range(n):
A[i] = sfix(i)

print_ln('mrray of fixed points')
for i in range(n):
print_ln('%s', A[i].reveal())

# matrix of fixed points
M = Matrix(n, m, sfix)

for i in range(n):
for j in range(m):
M[i][j] = sfix(i*j)

print_ln('matrix of fixed points')
for i in range(n):
for j in range(m):
print_str('%s ', M[i][j].reveal())
print_ln(' ')


# assign scalar to sfix
A[5] = sfix(1.12345)
print_ln('%s', A[5].reveal())

# assign sint to sfix
s = sint(10)
sa = sfix(); sa.load_int(s)
print_ln('successfully assigned sint to sfix %s', sa.reveal())

# division between fixed points
sb = sfix(2.5)
print_ln('division between %s %s = %s', sa.reveal(), sb.reveal(), (sa/sb).reveal())
50 changes: 50 additions & 0 deletions Programs/Source/gc_tutorial.mpc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# sbitint: factory for signed integer types

sint = sbitint.get_type(32)

def test(a, b, value_type=None):
try:
a = a.reveal()
except AttributeError:
pass
import inspect
print_ln('line %s: diff %s, got %s, expected %s',
inspect.currentframe().f_back.f_lineno, \
(a ^ cbits(b, n=a.n)).reveal(), a, hex(b))

a = sint(1)
b = sint(2)

test(a + b, 3)
test(a + a, 2)
test(a * b, 2)
test(a * a, 1)
test(a - b, -1)
test(a < b, 1)
test(a <= b, 1)
test(a >= b, 0)
test(a > b, 0)
test(a == b, 0)
test(a != b, 1)

clear_a = a.reveal()

# arrays and loops

a = Array(100, sint)

@for_range(100)
def f(i):
a[i] = sint(i)**2

test(a[99], 99**2)

# conditional

if_then(regint(0))
a[0] = 123
else_then()
a[0] = 789
end_if()

test(a[0], 789)
53 changes: 53 additions & 0 deletions Programs/Source/test_sbitfix.mpc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from Compiler.GC.types import sbitfix, cbits

#sbitfix.set_precision(3, 7)

def test(a, b, value_type=None):
try:
b = int(round((b * (1 << a.f))))
a = a.v.reveal()
except AttributeError:
pass
try:
a = a.reveal()
except AttributeError:
pass
import inspect
print_ln('%s: %s %s %s', inspect.currentframe().f_back.f_lineno, \
(a ^ cbits(b)).reveal(), a, (b))

aa = 5321.0
bb = 142.0

for a_sign, b_sign in (1, -1), (-1, -1):
a = a_sign * aa
b = b_sign * bb

sa = sbitfix(a)
sb = sbitfix(b)

test(sa + sb, a+b)
test(sa - sb, a-b)
test(sa * sb, a*b)
test(sa / sb, a/b)

test(-sa, -a)

a = 126
b = 125
sa = sbitfix(a)
sb = sbitfix(b)

test(sa < sb, int(a<b))
test(sa < sa, int(a<a))
test(sa < sa + sbitfix(1), int(a<a+1))
test(-sa < sa, int(-a<a))
test(-sb < sb, int(-b<b))
test(sa < -sb, int(a<-b))
test(-sa < -sb, int(-a<-b))
test(sa > sb, int(a>b))
test(sa <= sb, int(a<=b))
test(sa >= sb, int(a>=b))
test(sa == sb, int(a==b))
test(sa != sb, int(a!=b))
test(sa != sa, int(a!=a))
38 changes: 38 additions & 0 deletions Programs/Source/test_sbitint.mpc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
program.options.merge_opens = False

from Compiler.GC.types import *

def test(a, b, value_type=None):
try:
a = a.reveal()
except AttributeError:
pass
import inspect
print_ln('%s: %s %s %s', inspect.currentframe().f_back.f_lineno, \
(a ^ cbits(b)).reveal(), a, hex(b))

si32 = sbitint.get_type(32)

test(si32(3) + si32(2), 5)
test(si32(3) - si32(2), 1)
test(si32(3) < si32(2), 0)
test(si32(3) > si32(2), 1)
test(si32(2) <= si32(2), 1)
test((si32(0) < si32(1)).if_else(si32(1), si32(2)) + si32(3), 4)

test(si32(3) * si32(2), 6)
test(3 * si32(2), 6)
test(si32(3) * 2, 6)

test(si32(-1), 2**32 - 1)
test(si32(-1) + si32(3), 2)
test(si32(-1) - si32(-2), 1)

test(si32(1) * 2 * 2, 4)

for i in range(3, 32):
t = sbitint.get_type(i)
test(t(3) + t(2), 5)

test(abs(si32(-2)), 2)
test(abs(si32(2)), 2)

0 comments on commit 3c69da1

Please sign in to comment.