2 changes: 2 additions & 0 deletions .gitignore
@@ -131,3 +131,5 @@ dmypy.json
# vim
*.swp
*.pkl
.vscode/settings.json
test/wrap.exe
22 changes: 22 additions & 0 deletions README.md
@@ -220,6 +220,28 @@ A few notes:
| `celu` | Available | | Only for alpha=1 (celu == elu) |


## Vectorization
Matrix multiplication can optionally use AVX2 vectorization and cache-aware blocking. Since this
may hinder the portability of the produced code, it must be enabled explicitly.

To enable vectorization with single-precision floating point representation (32 bits), compile with
```
-mfma -mavx2 -DFLOAT_T=float -DUSE_AVX2_32
```
To use vectorization with double-precision representation (64 bits), use
```
-mfma -mavx2 -DFLOAT_T=double -DUSE_AVX2_64
```

These flags:
* Enable AVX2 and FMA instructions.
* Define the floating-point type used throughout the code.
* Select the correct vectorized implementation based on precision.
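
For example, assuming the converter produced a file named `model.C` (a placeholder name), it can be built into a shared library with the 32-bit vectorized path enabled, roughly as the test harness in `test/wrap.py` does:
```
gcc model.C -o model.so --shared -fPIC -lm -mfma -mavx2 -DFLOAT_T=float -DUSE_AVX2_32
```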

### Portability notice
AVX2 and FMA are supported on most modern x86_64 CPUs, but not guaranteed on older or embedded systems.
If unsure, compile without any of the flags above to fall back to the scalar implementation.

## Running tests
In order to install the full dependencies needed to test the whole package,
install with the tag `fql`.
19 changes: 11 additions & 8 deletions scikinC/ColumnTransformerConverter.py
@@ -1,4 +1,5 @@
import numpy as np
from copy import deepcopy

from sklearn.preprocessing import FunctionTransformer

@@ -30,10 +31,10 @@ def convert(self, model, name=None):
if key is None:
key = "Preprocessor"
if key in keys:
key.append (str(1+len(keys)))
key = key + str(1+len(keys))

if isinstance(transformer, (FunctionTransformer,)):
if transformer.func is None and transformer.inverse_func is None:
if getattr(transformer, 'func', None) is None and getattr(transformer, 'inverse_func', None) is None:
transformer = 'passthrough'
else:
transformer.n_features_in_ = len(columns)
@@ -46,10 +47,8 @@ def convert(self, model, name=None):
scikinC.convert({k: t for k,t,_ in transformers if t != 'passthrough'})
)

mapping = {k: c for k,_,c in transformers}

nFeatures = 1+max(index_mapping)

nFeatures = max(1+max(index_mapping), len(index_mapping))

lines.append("""
extern "C"
FLOAT_T* %(name)s (FLOAT_T* ret, const FLOAT_T *input)
@@ -62,13 +61,16 @@ def convert(self, model, name=None):
)
)

# This is a copy where used indices are set to None once processed
id_map_tmp = deepcopy(index_mapping)
for key, transformer, columns in transformers:
lines.append("// Transforming %s columns" % key)
if transformer == 'passthrough':
for column in columns:
lines.append("""
ret [%(output)d] = input[%(column)d];
"""%dict(output=index_mapping.index(column), column=column))
"""%dict(output=id_map_tmp.index(column), column=column))
id_map_tmp[id_map_tmp.index(column)] = None
else:
for iCol, column in enumerate(columns):
lines.append(""" bufin [%(iCol)d] = input[%(column)d];"""%
@@ -77,7 +79,8 @@ def convert(self, model, name=None):
% dict(name=key))
for iCol, column in enumerate(columns):
lines.append(""" ret[%(index_out)d] = bufout[%(iCol)d];"""%
dict(index_out=index_mapping.index(column), iCol=iCol))
dict(index_out=id_map_tmp.index(column), iCol=iCol))
id_map_tmp[id_map_tmp.index(column)] = None

lines.append ("""
return ret;
6 changes: 3 additions & 3 deletions scikinC/layers/BaseLayerConverter.py
@@ -13,13 +13,13 @@ def activate (self, x):
activation = activation.__name__

if activation == 'sigmoid':
return "%(x)s = 1. / (1+exp(-%(x)s));" % {'x':x}
return "%(x)s = 1. / (1+expf(-%(x)s));" % {'x':x}
elif activation == 'tanh':
return "%(x)s = tanh(%(x)s);" % {'x':x}
return "%(x)s = tanhf(%(x)s);" % {'x':x}
elif activation == 'relu':
return "%(x)s = %(x)s > 0. ? %(x)s : 0.;" % {'x':x}
elif activation == 'celu' or activation == 'elu':
return "%(x)s = %(x)s > 0. ? %(x)s : exp(%(x)s) - 1;" % {'x': x}
return "%(x)s = %(x)s > 0. ? %(x)s : expf(%(x)s) - 1;" % {'x': x}
elif activation == 'linear':
return ""
else:
136 changes: 107 additions & 29 deletions scikinC/layers/Dense.py
@@ -8,43 +8,121 @@ class Dense (BaseLayerConverter):

def definition(self):
"""Return the definition of the layer function"""
ret = []

nX, nY = self.layer.kernel.shape

kernel, bias = self.layer.get_weights()
c_code = """
#if defined(USE_AVX2_32) || defined(USE_AVX2_64)
#include <immintrin.h>
#include <string.h>
#endif

#ifndef CACHE_LINE_SIZE
#define CACHE_LINE_SIZE 64
#endif

extern "C"
FLOAT_T* %(layername)s (FLOAT_T* ret, const FLOAT_T* input)
{
int i, j, ii, jj;
static const FLOAT_T kernel[%(nY)d][%(nX)d] = %(kernel_values)s;
static const FLOAT_T bias[%(nY)d] = %(bias_values)s;

// Block sizes
const int BLOCK_I = 32;
const int BLOCK_J = CACHE_LINE_SIZE / sizeof(FLOAT_T);


#if defined(USE_AVX2_32)
const int word_size = 8; // 256 bits / 32 bits
memcpy(ret, bias, sizeof(FLOAT_T)*%(nY)d);

  // Blocked vectorized version for float with AVX2
for (ii = 0; ii < %(nY)d; ii += BLOCK_I) {
int i_max = (ii + BLOCK_I < %(nY)d) ? ii + BLOCK_I : %(nY)d;
for (jj = 0; jj < %(nX)d; jj += BLOCK_J) {
const int j_max = (jj + BLOCK_J < %(nX)d) ? jj + BLOCK_J : %(nX)d;
for (i = ii; i < i_max; ++i) {
__m256 sum = _mm256_setzero_ps();
for (j = jj; j + word_size <= j_max; j += word_size) {
__m256 in_vec = _mm256_loadu_ps(&input[j]);
__m256 ker_vec = _mm256_loadu_ps(&kernel[i][j]);
sum = _mm256_fmadd_ps(in_vec, ker_vec, sum);
}

ret += ["""
extern "C"
FLOAT_T* %(layername)s (FLOAT_T* ret, const FLOAT_T* input)
{
int i, j;
const FLOAT_T kernel[%(nY)d][%(nX)d] = %(kernel_values)s;
const FLOAT_T bias[%(nY)d] = %(bias_values)s;

for (i=0; i < %(nY)d; ++i)
{
ret[i] = bias[i];
const FLOAT_T *row = kernel[i];
for (j=0; j<%(nX)d; ++j)
ret[i] += input[j] * row[j];

%(activate)s
float temp[word_size];
_mm256_storeu_ps(temp, sum);
for (int k = 0; k < word_size; ++k)
ret[i] += temp[k];

// Scalar tail
for (; j < j_max; ++j) {
ret[i] += input[j] * kernel[i][j];
}
}
}
}
#elif defined(USE_AVX2_64)
const int word_size = 4; // 256 bits / 64 bits
memcpy(ret, bias, sizeof(FLOAT_T)*%(nY)d);

  // Blocked vectorized version for double with AVX2
for (ii = 0; ii < %(nY)d; ii += BLOCK_I) {
int i_max = (ii + BLOCK_I < %(nY)d) ? ii + BLOCK_I : %(nY)d;
for (jj = 0; jj < %(nX)d; jj += BLOCK_J) {
const int j_max = (jj + BLOCK_J < %(nX)d) ? jj + BLOCK_J : %(nX)d;
for (i = ii; i < i_max; ++i) {
__m256d sum = _mm256_setzero_pd();
for (j = jj; j + word_size <= j_max; j += word_size) {
__m256d in_vec = _mm256_loadu_pd(&input[j]);
__m256d ker_vec = _mm256_loadu_pd(&kernel[i][j]);
sum = _mm256_fmadd_pd(in_vec, ker_vec, sum);
}

double temp[word_size];
_mm256_storeu_pd(temp, sum);
for (int k = 0; k < word_size; ++k)
ret[i] += temp[k];

return ret;
// Scalar tail
for (; j < j_max; ++j) {
ret[i] += input[j] * kernel[i][j];
}
}
}
}
#else
for (i = 0; i < %(nY)d; ++i)
ret[i] = bias[i];

  // Blocked scalar fallback (used when AVX2 is not enabled)
for (ii = 0; ii < %(nY)d; ii += BLOCK_I) {
int i_max = (ii + BLOCK_I < %(nY)d) ? ii + BLOCK_I : %(nY)d;
for (jj = 0; jj < %(nX)d; jj += BLOCK_J) {
int j_max = (jj + BLOCK_J < %(nX)d) ? jj + BLOCK_J : %(nX)d;
for (i = ii; i < i_max; ++i) {
for (j = jj; j < j_max; ++j) {
ret[i] += input[j] * kernel[i][j];
}
}
}
""" % dict(
layername = self.name,
nX = nX,
nY = nY,
kernel_values = array2c (kernel.T),
bias_values = array2c (bias),
activate = self.activate('ret[i]'),
)]
}
#endif

for (i = 0; i < %(nY)d; ++i) {
%(activate)s
}

return "\n".join(ret)
return ret;
}
""" % dict(
layername=self.name,
nX=nX,
nY=nY,
kernel_values=array2c(kernel.T),
bias_values=array2c(bias),
activate=self.activate('ret[i]'),
)
return c_code

def call(self, obuffer, ibuffer):
"""Return the call to the layer function"""
12 changes: 12 additions & 0 deletions test/test_ColumnTransformerConverter.py
@@ -29,6 +29,18 @@ def double_passthrough_transformer():
transformer_.fit (X)
return transformer_

@fixtures.register()
def expanding_passthrough_transformer():
transformer_ = ColumnTransformer([
('keep1', 'passthrough', [0]),
('keep2', 'passthrough', [0, 1]),
('keep3', 'passthrough', [0, 1]),
])
X = np.random.uniform (20,30,(1000, 10))
transformer_.fit (X)
return transformer_



@fixtures.register('invertible')
def ss_and_passthrough_transformer():
7 changes: 7 additions & 0 deletions test/test_keras.py
@@ -53,6 +53,13 @@ def test_dense (classifier_dense):
deployed = deploy_keras("keras_dense", classifier_dense)
assert eval_error (classifier_dense, deployed) < 1e-5

def test_dense_with_avx2 (classifier_dense):
deployed = deploy_keras("keras_dense_with_avx2", classifier_dense, use_avx2=True)
assert eval_error (classifier_dense, deployed) < 1e-5

def test_dense_with_doubles_and_avx2 (classifier_dense):
deployed = deploy_keras("keras_dense_with_doubles_and_avx2", classifier_dense, use_avx2=True, float_t='double')
assert eval_error (classifier_dense, deployed) < 1e-5

################################################################################
### PReLU layer
27 changes: 22 additions & 5 deletions test/wrap.py
@@ -8,9 +8,11 @@


class DeployedModel:
def __init__(self, filename, compiled = 'test.so'):
def __init__(self, filename, compiled='test.so', use_avx2=False, float_t='float'):
self.filename = filename
self.compiled = compiled
self.use_avx2 = use_avx2
self.float_t = float_t

self.compile()
self.funcnames = self.get_funcnames()
@@ -22,8 +24,22 @@ def __del__ (self):
os.system ( "rm %s" % (self.compiled, ) )

def compile (self):
avx2_flags = [] if not self.use_avx2 else [
"-DUSE_AVX2_32" if self.float_t == 'float' else "-DUSE_AVX2_64",
"-mavx2", "-mfma"
]

float_t_flag = ["-DFLOAT_T=%s" % self.float_t]

if self.use_avx2:
print ("Compiling with AVX2 support")
else:
print ("Compiling without AVX2 support")

output = subprocess.check_output(
["gcc", self.filename, "-o", self.compiled, "--shared", "-fPIC", "-lm"]
+ float_t_flag
+ avx2_flags
)
if str(output, 'ASCII') not in ["", "\n"]:
raise Exception("Compilation error %s" % str(output, 'ASCII'))
@@ -37,7 +53,8 @@ def get_funcnames(self):
ret = []
for line in str(output, 'ASCII').split('\n'):
tokens = [a for a in line.split(' ') if len(a)]
if len(tokens) != 3: continue
if len(tokens) != 3:
continue
addr, type_, name = tokens
if type_ in "T":
ret.append (name)
@@ -76,13 +93,13 @@ def deploy_pickle (name, obj, float_t = "float"):
)


ret = DeployedModel(tmpfile+".C", compiled = './%s.so' % tmpfile)
ret = DeployedModel(tmpfile+".C", compiled = './%s.so' % tmpfile, float_t=float_t)

os.system ("rm %(tmpfile)s.pkl %(tmpfile)s.C" % {'tmpfile': tmpfile} )

return ret

def deploy_keras (name, obj, float_t = "float"):
def deploy_keras (name, obj, float_t = "float", use_avx2=False):
### Randomize UID
s = string.ascii_letters
uid = [s[np.random.randint(len(s))] for _ in range(16)]
@@ -101,7 +118,7 @@ def deploy_keras (name, obj, float_t = "float", use_avx2=False):
)


ret = DeployedModel(tmpfile+".C", compiled = './%s.so' % tmpfile)
ret = DeployedModel(tmpfile+".C", compiled='./%s.so' % tmpfile, use_avx2=use_avx2, float_t=float_t)

os.system ("rm -r %(tmpfile)s %(tmpfile)s.C" % {'tmpfile': tmpfile} )
