From ef10157f7b7ac10ec99def426c34e5395616af37 Mon Sep 17 00:00:00 2001 From: "Alexey S. Kiselev" Date: Sat, 28 Sep 2019 18:28:16 +0100 Subject: [PATCH] Crossvalidation type added for k-fold method, bug fixed (#6) --- Readme.md | 22 +++++--- core/array_manipulation/kfold.js | 92 +++++++++++++++++++++++++++----- core/types.js | 11 ++++ package.json | 2 +- test/array_manipulation.test.js | 69 ++++++++++++++++++++++++ test/testBuild.js | 5 +- 6 files changed, 180 insertions(+), 21 deletions(-) diff --git a/Readme.md b/Readme.md index 0d29479..c975324 100644 --- a/Readme.md +++ b/Readme.md @@ -4,7 +4,7 @@ # Unirand A JavaScript module for generating seeded random distributions and its statistical analysis. -Implemented in pure JavaScript with no dependencies, designed to work in Node.js and fully asynchronous, tested *with ~600 tests*. +Implemented in pure JavaScript with no dependencies, designed to work in Node.js and fully asynchronous, tested *with 600+ tests*. [Supported distributions](./core/methods/) @@ -17,7 +17,7 @@ const unirand = require('unirand') ``` ### PRNG -Unirand supports different PRNGs: *default JS generator, tuchei sedded generator*. By default unirand uses **tuchei** generator. +Unirand supports different PRNGs: *default JS generator, tuchei seeded generator*. By default unirand uses **tuchei** generator. Our seeded generator supports *seed*, *random*, *next* methods. A name of current using PRNG is stored in: ```javascript @@ -158,8 +158,8 @@ Sample method is **3 times faster** for arrays and **7 times faster** for string ### k-fold Splits array into *k* subarrays. Requires at least 2 arguments: array itself and *k*. Also supports *options*. 
-- *type*: output type, **list** (default) for output like `[, , , ...]`, **set** for output like `{0: , 1: , 2: , ...}` -- *derange*: items will be shuffled as *random permutation* (default) or *random derangement* +- *type*: output type, **list** (default) for output like `[, , , ...]`, **set** for output like `{0: , 1: , 2: , ...}`, **crossvalidation** for output like `[{test: , data: }, ...]` +- *derange*: items will be shuffled as *random permutation* (default, `derange: false`) or *random derangement* (`derange: true`) ```javascript const kfold = unirand.kfold; kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3); // [ [ 9, 8, 2, 10 ], [ 1, 7, 3 ], [ 4, 5, 6 ] ] @@ -168,7 +168,17 @@ kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3); // [ [ 9, 8, 2, 10 ], [ 1, 7, 3 ], [ kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3, { type: 'set', derange: true -}); // { '0': [ 8, 10, 7, 1 ], '1': [ 6, 4, 9 ], '2': [ 5, 2, 3 ] } +}); +// { '0': [ 8, 10, 7, 1 ], '1': [ 6, 4, 9 ], '2': [ 5, 2, 3 ] } + +// cross validation +kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3, { + type: 'crossvalidation', + derange: true +}) +// [ { id: 0, test: [ 5, 6, 9, 7 ], data: [ 4, 1, 10, 2, 8, 3 ] }, +// { id: 1, test: [ 4, 1, 10 ], data: [ 5, 6, 9, 7, 2, 8, 3 ] }, +// { id: 2, test: [ 2, 8, 3 ], data: [ 5, 6, 9, 7, 4, 1, 10 ] } ] ``` For permutation unirand uses seeded PRNG. With *seed* k-fold will always return same result. @@ -197,7 +207,7 @@ Winsorization is the transformation of statistics by limiting extreme values in Parameters: - *input*: array of numbers - *limits*: single number, represent same value trimming value from left and right (should be 0 < limit < 0.5), or an array \[left trim value, right trim value\] (values should be 0 < left trim value < right trim value < 1) -- *mutate*: value (default *true*). If true - mutate ofiginal array, otherwise - no +- *mutate*: value (default *true*). 
If true - mutate original array, otherwise - no ```javascript const winsorize = unirand.winsorize; diff --git a/core/array_manipulation/kfold.js b/core/array_manipulation/kfold.js index 2cef8cd..583febe 100644 --- a/core/array_manipulation/kfold.js +++ b/core/array_manipulation/kfold.js @@ -9,7 +9,8 @@ import ArrayManipulation from './base'; import Shuffle from './shuffle'; -import type { KFoldOptions, RandomArrayNumberString, RandomArrayStringObject, RandomArrayString } from '../types'; +import type { KFoldOptions, KFoldCrossValidation, RandomArrayNumberString, + RandomArrayStringObject, RandomArrayString } from '../types'; import type { IKFold, IShuffle } from '../interfaces'; class KFold extends ArrayManipulation implements IKFold { @@ -28,7 +29,7 @@ class KFold extends ArrayManipulation implements IKFold { getKFold(input: RandomArrayNumberString, k: number, options: KFoldOptions = { type: 'list', derange: false - }): RandomArrayStringObject { + }): RandomArrayStringObject | KFoldCrossValidation { this._validateInput(input, false); if (typeof k !== 'number') { @@ -39,24 +40,89 @@ class KFold extends ArrayManipulation implements IKFold { throw new Error('k-fold: Parameter "k" should be greater then 0 and less input.length'); } - let result: RandomArrayStringObject; - - if (options.type === 'list') { - result = []; - } else if (options.type === 'set') { - result = {}; - } else { - throw new Error('k-fold: Wrong output type, should be "list" or "set"'); - } - + const folds: Array = this._getFolds(input.length, k); let permutedInput: RandomArrayString; if (!options.derange) { permutedInput = this._shuffle.getPermutation(input); } else { permutedInput = this._shuffle.getDerangement(input); } + + if (options.type === 'list') { + return this._getListSetKFold(permutedInput, folds, []); + } else if (options.type === 'set') { + return this._getListSetKFold(permutedInput, folds, {}); + } else if (options.type === 'crossvalidation') { + return 
this._getCrossValidationKFold(permutedInput, folds); + } + throw new Error('k-fold: Wrong output type, should be "list", "set" or "crossvalidation"'); + } - const folds: Array = this._getFolds(input.length, k); + /** + * Generates kfold output for "crossvalidation" type + * @param {*} permutedInput + * @param {*} folds + * @returns {KFoldCrossValidation} one {id, test, data} item per fold + */ + _getCrossValidationKFold( + permutedInput: RandomArrayString, + folds: Array + ): KFoldCrossValidation { + const result = []; + const listFolds: RandomArrayStringObject = this._getListSetKFold(permutedInput, folds, []); + for (let i = 0; i < listFolds.length; i += 1) { + result.push({ + id: i, + test: listFolds[i].slice(), + data: this._generateData(listFolds, i) + }); + } + + return result; + } + + /** + * Generates data for crossvalidation + * Collects all data from all folds except fold[i] + * @param {RandomArrayStringObject} listFolds + * @param {number} i + * @private + */ + _generateData(listFolds: RandomArrayStringObject, i: number): Array> { + const result: Array> = []; + for (let j = 0; j < i; j += 1) { + this._addSubData(listFolds[j], result); + } + for (let j = i + 1; j < listFolds.length; j += 1) { + this._addSubData(listFolds[j], result); + } + + return result; + } + + /** + * @param {RandomArrayStringObject} listFolds + * @param {Array>} result + * @private + */ + _addSubData(listFolds: RandomArrayStringObject, result: Array>): void { + for (let k = 0; k < listFolds.length; k += 1) { + result.push(listFolds[k]); + } + } + + /** + * Generates kfold output for "list" and "set" types + * @param {RandomArrayString} permutedInput + * @param {Array} folds + * @param {RandomArrayStringObject} result + * @private + */ + _getListSetKFold( + permutedInput: RandomArrayString, + folds: Array, + result: RandomArrayStringObject + ): RandomArrayStringObject { let pindex: number = 0; let subResult: RandomArrayNumberString = []; diff --git a/core/types.js b/core/types.js index 1b86816..0b9c5a0 100644 --- a/core/types.js +++ 
b/core/types.js @@ -73,3 +73,14 @@ export type NumberString = number | string; * Array or number */ export type RandomArrayNumber = RandomArray | number; + +/** + * kfold crossvalidation + */ +export type KFoldCrossValidationItem = { + id: number, + test: RandomArrayStringObject, + data: RandomArrayStringObject +}; + +export type KFoldCrossValidation = Array<KFoldCrossValidationItem>; diff --git a/package.json b/package.json index 55bc1ae..80852bf 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "unirand", - "version": "2.5.1", + "version": "2.5.2", "description": "Random numbers and Distributions generation", "main": "./lib/index.js", "scripts": { diff --git a/test/array_manipulation.test.js b/test/array_manipulation.test.js index 38fbe5a..19c7641 100644 --- a/test/array_manipulation.test.js +++ b/test/array_manipulation.test.js @@ -786,5 +786,74 @@ describe('Array manipulation methods', () => { } done(); }); + it('should have correct output structure for type "crossvalidation"', () => { + const kfold = new KFold(); + const randomInput = generateInput(); + const res = kfold.getKFold(randomInput, 10, { + type: 'crossvalidation' + }); + + expect(res.length).to.be.equal(10); + expect(Array.isArray(res)).to.be.equal(true); + expect(res[0].id).to.be.equal(0); + expect(res[0].test).to.be.not.equal(undefined); + expect(res[0].data).to.be.not.equal(undefined); + expect(Object.keys(res[0]).length).to.be.equal(3); + expect(Array.isArray(res[0].test)).to.be.equal(true); + expect(Array.isArray(res[0].data)).to.be.equal(true); + for (let i = 0; i < res.length; i += 1) { + expect(res[i].test.length + res[i].data.length).to.be.equal(randomInput.length); + } + }); + it('should generate correct data for type "crossvalidation"', function(done) { + this.timeout(480000); + const kfold = new KFold(); + let input = []; + let res; + const checkExistance = (data, test) => { + let ht = {}; + let fail = false; + for (let i = 0; i < test.length; i += 1) { + ht[test[i]] = 1; + } + for (let i = 
0; i < data.length; i += 1) { + if (ht[data[i]]) { + fail = true; + break; + } + } + expect(fail).to.be.equal(false); + }; + + const checkUniqueness = (data, test) => { + const ht = {}; + for (let i = 0; i < data.length; i += 1) { + ht[data[i]] = 1; + } + + for (let i = 0; i < test.length; i += 1) { + ht[test[i]] = 1; + } + + expect(Object.keys(ht).length).to.be.equal(data.length + test.length); + }; + + for (let i = 0; i < 5000; i += 1) { + input[i] = i; + } + + for (let i = 0; i < 400; i += 1) { + res = kfold.getKFold(input, 200, { + type: 'crossvalidation' + }); + + for (let j = 0; j < res.length; j += 1) { + checkUniqueness(res[j].data, res[j].test); + checkExistance(res[j].data, res[j].test); + } + } + + done(); + }); }); }); diff --git a/test/testBuild.js b/test/testBuild.js index 609edd1..a8e48c7 100644 --- a/test/testBuild.js +++ b/test/testBuild.js @@ -6,4 +6,7 @@ let unirand = require('../lib'); unirand.seed(); -console.log(unirand.laplace(1, 2).distributionSync(4)); +console.log(unirand.kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3, { + derange: true, + type: 'crossvalidation' +}));