forked from tensorflow/tfjs-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sequence_utils.js
64 lines (59 loc) · 2.02 KB
/
sequence_utils.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Utilities for sequential data.
*/
export const PAD_INDEX = 0; // Index of the padding character.
export const OOV_INDEX = 2; // Index of the out-of-vocabulary (OOV) character.
/**
 * Pad and truncate all sequences to the same length.
 *
 * The input sequences are never modified: every element of the returned array
 * is a newly allocated array.
 *
 * @param {number[][]} sequences The sequences represented as an array of array
 *   of numbers.
 * @param {number} maxLen Maximum length (expected to be a non-negative
 *   integer). Sequences longer than `maxLen` will be truncated. Sequences
 *   shorter than `maxLen` will be padded.
 * @param {'pre'|'post'} padding Padding type: `'pre'` inserts the padding
 *   before the sequence; any other value appends it after.
 * @param {'pre'|'post'} truncating Truncation type: `'pre'` drops items from
 *   the start of the sequence (keeping the last `maxLen` items); any other
 *   value drops items from the end (keeping the first `maxLen` items).
 * @param {number} value Padding value.
 * @returns {number[][]} New sequences, each of length `maxLen` when `maxLen`
 *   is a non-negative integer.
 */
export function padSequences(
    sequences, maxLen, padding = 'pre', truncating = 'pre', value = PAD_INDEX) {
  // TODO(cais): This perhaps should be refined and moved into tfjs-preproc.
  return sequences.map((seq) => {
    // Perform truncation. `slice` returns a copy instead of editing in place,
    // so the caller's arrays are left intact.
    let result;
    if (seq.length > maxLen) {
      result = truncating === 'pre' ?
          seq.slice(seq.length - maxLen) :
          seq.slice(0, maxLen);
    } else {
      result = seq.slice();
    }
    // Perform padding.
    const padLength = maxLen - result.length;
    if (padLength > 0) {
      const pad = Array.from({length: padLength}, () => value);
      result = padding === 'pre' ? pad.concat(result) : result.concat(pad);
    }
    return result;
  });
}