diff --git a/.gitignore b/.gitignore
index 2039841..d407dc9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,3 +131,5 @@ dmypy.json
 # vim
 *.swp
 *.pkl
+.vscode/settings.json
+test/wrap.exe
diff --git a/README.md b/README.md
index fe02336..2048802 100644
--- a/README.md
+++ b/README.md
@@ -220,6 +220,28 @@ A few notes:
 | `celu`    | Available |   | Only for alpha=1 (celu == elu)  |
 
+## Vectorization
+Vectorization and cache-aware blocking are available for the matrix multiplication, but they
+may hinder the portability of the produced code and must therefore be enabled explicitly.
+
+To enable vectorization with single-precision floating-point representation (32 bits), compile with
+```
+ -mfma -mavx2 -DFLOAT_T=float -DUSE_AVX2_32
+```
+To enable vectorization with double-precision representation (64 bits), use
+```
+ -mfma -mavx2 -DFLOAT_T=double -DUSE_AVX2_64
+```
+
+These flags:
+ * Enable AVX2 and FMA instructions.
+ * Define the floating-point type used throughout the code.
+ * Select the correct vectorized implementation based on precision.
+
+### Portability notice
+AVX2 and FMA are supported on most modern x86_64 CPUs, but they are not guaranteed on older or embedded systems.
+If unsure, compile without any of the flags above to fall back to the scalar implementation.
+
 ## Running tests
 
 In order to install the full dependencies needed to test the whole package, install with the tag `fql`.
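For reference, a minimal usage sketch of code produced by scikinC, built with the flags described in the README section above. The file and symbol names (`mymodel.C`, `mymodel.so`, `mymodel`) are hypothetical and the build commands may differ in practice (test/wrap.py below shows how the test suite builds the shared library); only the `FLOAT_T* name(FLOAT_T* ret, const FLOAT_T* input)` signature is taken from the generated code.

```c
/* main.c -- usage sketch; mymodel.C, mymodel.so and mymodel are hypothetical names.
 *
 * Possible build, assuming single precision and the AVX2 flags above:
 *   gcc mymodel.C --shared -fPIC -lm -mfma -mavx2 -DFLOAT_T=float -DUSE_AVX2_32 -o mymodel.so
 *   gcc main.c ./mymodel.so -o main -DFLOAT_T=float
 */
#include <stdio.h>

/* Declaration of the generated entry point; FLOAT_T must be defined to the same
 * type used when compiling the generated code (here via -DFLOAT_T=float). */
extern FLOAT_T *mymodel (FLOAT_T *ret, const FLOAT_T *input);

int main (void)
{
  FLOAT_T input[4]  = {0.1, 0.2, 0.3, 0.4};   /* input size depends on the model */
  FLOAT_T output[2];                          /* output size depends on the model */

  mymodel (output, input);

  printf ("%f %f\n", (double) output[0], (double) output[1]);
  return 0;
}
```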
(""" return ret; diff --git a/scikinC/layers/BaseLayerConverter.py b/scikinC/layers/BaseLayerConverter.py index e505149..6482d9a 100644 --- a/scikinC/layers/BaseLayerConverter.py +++ b/scikinC/layers/BaseLayerConverter.py @@ -13,13 +13,13 @@ def activate (self, x): activation = activation.__name__ if activation == 'sigmoid': - return "%(x)s = 1. / (1+exp(-%(x)s));" % {'x':x} + return "%(x)s = 1. / (1+expf(-%(x)s));" % {'x':x} elif activation == 'tanh': - return "%(x)s = tanh(%(x)s);" % {'x':x} + return "%(x)s = tanhf(%(x)s);" % {'x':x} elif activation == 'relu': return "%(x)s = %(x)s > 0. ? %(x)s : 0.;" % {'x':x} elif activation == 'celu' or activation == 'elu': - return "%(x)s = %(x)s > 0. ? %(x)s : exp(%(x)s) - 1;" % {'x': x} + return "%(x)s = %(x)s > 0. ? %(x)s : expf(%(x)s) - 1;" % {'x': x} elif activation == 'linear': return "" else: diff --git a/scikinC/layers/Dense.py b/scikinC/layers/Dense.py index 15b5964..7716abc 100644 --- a/scikinC/layers/Dense.py +++ b/scikinC/layers/Dense.py @@ -8,43 +8,121 @@ class Dense (BaseLayerConverter): def definition(self): """Return the definition of the layer function""" - ret = [] - nX, nY = self.layer.kernel.shape - kernel, bias = self.layer.get_weights() + c_code = """ +#if defined(USE_AVX2_32) || defined(USE_AVX2_64) + #include + #include +#endif + +#ifndef CACHE_LINE_SIZE +#define CACHE_LINE_SIZE 64 +#endif + + extern "C" + FLOAT_T* %(layername)s (FLOAT_T* ret, const FLOAT_T* input) + { + int i, j, ii, jj; + static const FLOAT_T kernel[%(nY)d][%(nX)d] = %(kernel_values)s; + static const FLOAT_T bias[%(nY)d] = %(bias_values)s; + + // Block sizes + const int BLOCK_I = 32; + const int BLOCK_J = CACHE_LINE_SIZE / sizeof(FLOAT_T); + + +#if defined(USE_AVX2_32) + const int word_size = 8; // 256 bits / 32 bits + memcpy(ret, bias, sizeof(FLOAT_T)*%(nY)d); + + // Blocked scalar version for float with AVX2 + for (ii = 0; ii < %(nY)d; ii += BLOCK_I) { + int i_max = (ii + BLOCK_I < %(nY)d) ? ii + BLOCK_I : %(nY)d; + for (jj = 0; jj < %(nX)d; jj += BLOCK_J) { + const int j_max = (jj + BLOCK_J < %(nX)d) ? jj + BLOCK_J : %(nX)d; + for (i = ii; i < i_max; ++i) { + __m256 sum = _mm256_setzero_ps(); + for (j = jj; j + word_size <= j_max; j += word_size) { + __m256 in_vec = _mm256_loadu_ps(&input[j]); + __m256 ker_vec = _mm256_loadu_ps(&kernel[i][j]); + sum = _mm256_fmadd_ps(in_vec, ker_vec, sum); + } - ret += [""" - extern "C" - FLOAT_T* %(layername)s (FLOAT_T* ret, const FLOAT_T* input) - { - int i, j; - const FLOAT_T kernel[%(nY)d][%(nX)d] = %(kernel_values)s; - const FLOAT_T bias[%(nY)d] = %(bias_values)s; - - for (i=0; i < %(nY)d; ++i) - { - ret[i] = bias[i]; - const FLOAT_T *row = kernel[i]; - for (j=0; j<%(nX)d; ++j) - ret[i] += input[j] * row[j]; - - %(activate)s + float temp[word_size]; + _mm256_storeu_ps(temp, sum); + for (int k = 0; k < word_size; ++k) + ret[i] += temp[k]; + + // Scalar tail + for (; j < j_max; ++j) { + ret[i] += input[j] * kernel[i][j]; } + } + } + } +#elif defined(USE_AVX2_64) + const int word_size = 4; // 256 bits / 64 bits + memcpy(ret, bias, sizeof(FLOAT_T)*%(nY)d); + + // Blocked scalar version for double with AVX2 + for (ii = 0; ii < %(nY)d; ii += BLOCK_I) { + int i_max = (ii + BLOCK_I < %(nY)d) ? ii + BLOCK_I : %(nY)d; + for (jj = 0; jj < %(nX)d; jj += BLOCK_J) { + const int j_max = (jj + BLOCK_J < %(nX)d) ? 
diff --git a/test/test_ColumnTransformerConverter.py b/test/test_ColumnTransformerConverter.py
index 3d67225..5ae1ab4 100644
--- a/test/test_ColumnTransformerConverter.py
+++ b/test/test_ColumnTransformerConverter.py
@@ -29,6 +29,18 @@ def double_passthrough_transformer():
   transformer_.fit (X)
   return transformer_
 
+@fixtures.register()
+def expanding_passthrough_transformer():
+  transformer_ = ColumnTransformer([
+    ('keep1', 'passthrough', [0]),
+    ('keep2', 'passthrough', [0, 1]),
+    ('keep3', 'passthrough', [0, 1]),
+    ])
+  X = np.random.uniform (20,30,(1000, 10))
+  transformer_.fit (X)
+  return transformer_
+
+
 @fixtures.register('invertible')
 def ss_and_passthrough_transformer():
diff --git a/test/test_keras.py b/test/test_keras.py
index cc95f72..167298b 100644
--- a/test/test_keras.py
+++ b/test/test_keras.py
@@ -53,6 +53,13 @@ def test_dense (classifier_dense):
   deployed = deploy_keras("keras_dense", classifier_dense)
   assert eval_error (classifier_dense, deployed) < 1e-5
 
+def test_dense_with_avx2 (classifier_dense):
+  deployed = deploy_keras("keras_dense_with_avx2", classifier_dense, use_avx2=True)
+  assert eval_error (classifier_dense, deployed) < 1e-5
+
+def test_dense_with_doubles_and_avx2 (classifier_dense):
+  deployed = deploy_keras("keras_dense_with_doubles_and_avx2", classifier_dense, use_avx2=True, float_t='double')
+  assert eval_error (classifier_dense, deployed) < 1e-5
 
 ################################################################################
 ### PReLU layer
diff --git a/test/wrap.py b/test/wrap.py
index 31752d5..27f3d89 100644
--- a/test/wrap.py
+++ b/test/wrap.py
@@ -8,9 +8,11 @@ class DeployedModel:
 
-  def __init__(self, filename, compiled = 'test.so'):
+  def __init__(self, filename, compiled='test.so', use_avx2=False, float_t='float'):
     self.filename = filename
     self.compiled = compiled
+    self.use_avx2 = use_avx2
+    self.float_t = float_t
 
     self.compile()
     self.funcnames = self.get_funcnames()
@@ -22,8 +24,22 @@ def __del__ (self):
     os.system ( "rm %s" % (self.compiled, ) )
 
   def compile (self):
+    avx2_flags = [] if not self.use_avx2 else [
+        "-DUSE_AVX2_32" if self.float_t == 'float' else "-DUSE_AVX2_64",
+        "-mavx2", "-mfma"
+      ]
+
+    float_t_flag = ["-DFLOAT_T=%s" % self.float_t]
+
+    if self.use_avx2:
+      print ("Compiling with AVX2 support")
+    else:
+      print ("Compiling without AVX2 support")
+
     output = subprocess.check_output( ["gcc", self.filename, "-o", self.compiled, "--shared", "-fPIC", "-lm"]
+                                      + float_t_flag
+                                      + avx2_flags )
     if str(output, 'ASCII') not in ["", "\n"]:
       raise Exception("Compilation error %s" % str(output, 'ASCII'))
@@ -37,7 +53,8 @@ def get_funcnames(self):
     ret = []
     for line in str(output, 'ASCII').split('\n'):
       tokens = [a for a in line.split(' ') if len(a)]
-      if len(tokens) != 3: continue
+      if len(tokens) != 3:
+        continue
       addr, type_, name = tokens
       if type_ in "T":
         ret.append (name)
@@ -76,13 +93,13 @@ def deploy_pickle (name, obj, float_t = "float"):
       )
 
 
-  ret = DeployedModel(tmpfile+".C", compiled = './%s.so' % tmpfile)
+  ret = DeployedModel(tmpfile+".C", compiled = './%s.so' % tmpfile, float_t=float_t)
   os.system ("rm %(tmpfile)s.pkl %(tmpfile)s.C" % {'tmpfile': tmpfile} )
 
   return ret
 
 
-def deploy_keras (name, obj, float_t = "float"):
+def deploy_keras (name, obj, float_t = "float", use_avx2=False):
   ### Randomize UID
   s = string.ascii_letters
   uid = [s[np.random.randint(len(s))] for _ in range(16)]
@@ -101,7 +118,7 @@ def deploy_keras (name, obj, float_t = "float"):
       )
 
 
-  ret = DeployedModel(tmpfile+".C", compiled = './%s.so' % tmpfile)
+  ret = DeployedModel(tmpfile+".C", compiled='./%s.so' % tmpfile, use_avx2=use_avx2, float_t=float_t)
   os.system ("rm -r %(tmpfile)s %(tmpfile)s.C" % {'tmpfile': tmpfile} )