From b25709cdab1447fb58271bc26cdb57803514902f Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 21 Aug 2019 16:22:57 +0100 Subject: [PATCH 1/3] Tree and XgBoost tree parser --- .../xpack/ml/MachineLearning.java | 5 +- .../xpack/ml/inference/tree/Tree.java | 215 ++++++++++++++ .../ml/inference/tree/TreeEnsembleModel.java | 111 ++++++++ .../xpack/ml/inference/tree/TreeModel.java | 23 ++ .../ml/inference/tree/TreeModelLoader.java | 96 +++++++ .../tree/xgboost/XgBoostJsonParser.java | 79 ++++++ .../tree/xgboost/XgBoostJsonParserTests.java | 264 ++++++++++++++++++ 7 files changed, 792 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModel.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 8d0d7309e54c6..6c2ca83777870 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -203,6 +203,8 @@ import org.elasticsearch.xpack.ml.inference.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ModelLoader; import org.elasticsearch.xpack.ml.inference.sillymodel.SillyModelLoader; +import org.elasticsearch.xpack.ml.inference.tree.Tree; +import 
org.elasticsearch.xpack.ml.inference.tree.TreeModelLoader; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -628,7 +630,8 @@ public Map getProcessors(Processor.Parameters paramet } private Map getModelLoaders(Client client) { - return Map.of(SillyModelLoader.MODEL_TYPE, new SillyModelLoader(client)); + return Map.of(SillyModelLoader.MODEL_TYPE, new SillyModelLoader(client), + TreeModelLoader.MODEL_TYPE, new TreeModelLoader(client)); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java new file mode 100644 index 0000000000000..108fd5cec2726 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java @@ -0,0 +1,215 @@ +package org.elasticsearch.xpack.ml.inference.tree; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.BiPredicate; + + +/** + * A decision tree that can make predictions given a feature vector + */ +public class Tree { + private final List nodes; + + Tree(List nodes) { + this.nodes = Collections.unmodifiableList(nodes); + } + + /** + * Trace the route predicting on the feature vector takes. 
+ * @param features The feature vector + * @return The list of traversed nodes ordered from root to leaf + */ + public List trace(List features) { + return trace(features, 0, new ArrayList<>()); + } + + private List trace(List features, int nodeIndex, List visited) { + Node node = nodes.get(nodeIndex); + visited.add(node); + if (node.isLeaf()) { + return visited; + } + + int nextNode = node.compare(features); + return trace(features, nextNode, visited); + } + + /** + * Make a prediction based on the feature vector + * @param features The feature vector + * @return The prediction + */ + public Double predict(List features) { + return predict(features, 0); + } + + private Double predict(List features, int nodeIndex) { + Node node = nodes.get(nodeIndex); + if (node.isLeaf()) { + return node.value(); + } + + int nextNode = node.compare(features); + return predict(features, nextNode); + } + + /** + * Finds null nodes + * @return List of indexes to null nodes + */ + List missingNodes() { + List nullNodeIndices = new ArrayList<>(); + for (int i=0; i operator; + + Node(int leftChild, int rightChild, int featureIndex, boolean isDefaultLeft, double thresholdValue) { + this.leftChild = leftChild; + this.rightChild = rightChild; + this.featureIndex = featureIndex; + this.isDefaultLeft = isDefaultLeft; + this.thresholdValue = thresholdValue; + this.operator = (value, threshold) -> value < threshold; // less than + } + + Node(int leftChild, int rightChild, int featureIndex, boolean isDefaultLeft, double thresholdValue, + BiPredicate operator) { + this.leftChild = leftChild; + this.rightChild = rightChild; + this.featureIndex = featureIndex; + this.isDefaultLeft = isDefaultLeft; + this.thresholdValue = thresholdValue; + this.operator = operator; + } + + Node(double value) { + this(-1, -1, -1, false, value); + } + + boolean isLeaf() { + return leftChild < 1; + } + + int compare(List features) { + Double feature = features.get(featureIndex); + if (isMissing(feature)) { + return 
isDefaultLeft ? leftChild : rightChild; + } + + return operator.test(feature, thresholdValue) ? leftChild : rightChild; + } + + boolean isMissing(Double feature) { + return feature == null; + } + + Double value() { + return thresholdValue; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder("{\n"); + builder.append("left: ").append(leftChild).append('\n'); + builder.append("right: ").append(rightChild).append('\n'); + builder.append("isDefaultLeft: ").append(isDefaultLeft).append('\n'); + builder.append("isLeaf: ").append(isLeaf()).append('\n'); + builder.append("featureIndex: ").append(featureIndex).append('\n'); + builder.append("value: ").append(thresholdValue).append('\n'); + builder.append("}\n"); + return builder.toString(); + } + } + + + public static class TreeBuilder { + + private final ArrayList nodes; + private int numNodes; + + public static TreeBuilder newTreeBuilder() { + return new TreeBuilder(); + } + + TreeBuilder() { + nodes = new ArrayList<>(); + // allocate space in the root node and set to a leaf + nodes.add(null); + addLeaf(0, 0.0); + numNodes = 1; + } + + /** + * Add a decision node. Space for the child nodes is allocated + * @param nodeIndex Where to place the node. 
This is either 0 (root) or an existing child node index + * @param featureIndex The feature index the decision is made on + * @param isDefaultLeft Default left branch if the feature is missing + * @param decisionThreshold The decision threshold + * @return The created node + */ + public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) { +// assert nodeIndex < nodes.size() : "node index " + nodeIndex + " >= size " + nodes.size(); + + int leftChild = numNodes++; + int rightChild = numNodes++; + nodes.ensureCapacity(nodeIndex +1); + for (int i=nodes.size(); i trees; + private final Map featureMap; + + private TreeEnsembleModel(List trees, Map featureMap) { + this.trees = Collections.unmodifiableList(trees); + this.featureMap = featureMap; + } + + public int numFeatures() { + return featureMap.size(); + } + + public int numTrees() { + return trees.size(); + } + + public List checkForNull() { + List missing = new ArrayList<>(); + for (Tree tree : trees) { + missing.addAll(tree.missingNodes()); + } + return missing; + } + + public double predictFromDoc(Map features) { + List featureVec = docToFeatureVector(features); + List predictions = trees.stream().map(tree -> tree.predict(featureVec)).collect(Collectors.toList()); + return mergePredictions(predictions); + } + + public double predict(Map features) { + List featureVec = doubleDocToFeatureVector(features); + List predictions = trees.stream().map(tree -> tree.predict(featureVec)).collect(Collectors.toList()); + return mergePredictions(predictions); + } + + public List> trace(Map features) { + List featureVec = doubleDocToFeatureVector(features); + return trees.stream().map(tree -> tree.trace(featureVec)).collect(Collectors.toList()); + } + + double mergePredictions(List predictions) { + return predictions.stream().mapToDouble(f -> f).summaryStatistics().getSum(); + } + + List doubleDocToFeatureVector(Map features) { + List featureVec = Arrays.asList(new 
Double[featureMap.size()]); + + for (Map.Entry keyValue : features.entrySet()) { + if (featureMap.containsKey(keyValue.getKey())) { + featureVec.set(featureMap.get(keyValue.getKey()), keyValue.getValue()); + } + } + + return featureVec; + } + + List docToFeatureVector(Map features) { + List featureVec = Arrays.asList(new Double[featureMap.size()]); + + for (Map.Entry keyValue : features.entrySet()) { + if (featureMap.containsKey(keyValue.getKey())) { + Double value = (Double)keyValue.getValue(); + if (value != null) { + featureVec.set(featureMap.get(keyValue.getKey()), value); + } + } + } + + return featureVec; + } + + public static ModelBuilder modelBuilder(Map featureMap) { + return new ModelBuilder(featureMap); + } + + public static class ModelBuilder { + private List trees; + private Map featureMap; + + public ModelBuilder(Map featureMap) { + this.featureMap = featureMap; + trees = new ArrayList<>(); + } + + public ModelBuilder addTree(Tree tree) { + trees.add(tree); + return this; + } + + public TreeEnsembleModel build() { + return new TreeEnsembleModel(trees, featureMap); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java new file mode 100644 index 0000000000000..56e9fa48a039a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java @@ -0,0 +1,23 @@ +package org.elasticsearch.xpack.ml.inference.tree; + +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.ml.inference.Model; + +public class TreeModel implements Model { + + private String targetFieldName; + private TreeEnsembleModel ensemble; + + + TreeModel(TreeEnsembleModel ensemble, String targetFieldName) { + this.ensemble = ensemble; + this.targetFieldName = targetFieldName; + } + + @Override + public IngestDocument infer(IngestDocument document) { + Double prediction = 
ensemble.predictFromDoc(document.getSourceAndMetadata()); + document.setFieldValue(targetFieldName, prediction); + return document; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java new file mode 100644 index 0000000000000..c69e473c9469c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java @@ -0,0 +1,96 @@ +package org.elasticsearch.xpack.ml.inference.tree; + +import org.apache.log4j.LogManager; +import org.apache.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.ingest.ConfigurationUtils; +import org.elasticsearch.xpack.ml.inference.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.ModelLoader; +import org.elasticsearch.xpack.ml.inference.tree.xgboost.XgBoostJsonParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +public class TreeModelLoader implements ModelLoader { + + private static final Logger logger = LogManager.getLogger(TreeModelLoader.class); + + public static final String MODEL_TYPE = "tree"; + + private static String INDEX = "index"; + private static String TARGET_FIELD = "target_field"; + + private final Client client; + + public TreeModelLoader(Client client) { + this.client = client; + } + + @Override + public TreeModel load(String modelId, String processorTag, boolean ignoreMissing, Map config) throws Exception { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference modelRef = new AtomicReference<>(); + AtomicReference exception = new AtomicReference<>(); + + 
LatchedActionListener listener = new LatchedActionListener<>( + ActionListener.wrap(modelRef::set, exception::set), latch + ); + + String index = ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX); + String targetField = ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, TARGET_FIELD); + + load(modelId, index, listener); + latch.await(); + if (exception.get() != null) { + throw exception.get(); + } + + TreeEnsembleModel ensemble = modelRef.get(); + TreeModel model = new TreeModel(ensemble, targetField); + + logger.info("loaded model with " + ensemble.numTrees() + " trees"); + return model; + } + + public void load(String id, String index, ActionListener listener) { + client.prepareGet(index, null, id).execute(ActionListener.wrap( + response -> { + if (response.isExists()) { + listener.onResponse(createEnsemble(response.getSourceAsMap())); + } else { + listener.onFailure(new ResourceNotFoundException("missing model [{}], [{}]", id, index)); + } + }, + listener::onFailure + )); + } + + @SuppressWarnings("unchecked") + private TreeEnsembleModel createEnsemble(Map source) throws IOException { + Map featureMap = new HashMap<>(); + featureMap.put("f0", 0); + featureMap.put("f1", 1); + featureMap.put("f2", 2); + featureMap.put("f3", 3); + + List> trees = (List>) source.get("ensemble"); + if (trees == null) { + throw new IllegalStateException("missing trees"); + } + + return XgBoostJsonParser.parse(trees, featureMap); + } + + @Override + public void readConfiguration(String processorTag, Map config) { + ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX); + ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, TARGET_FIELD); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java new file mode 100644 index 0000000000000..b9590259c248d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java @@ -0,0 +1,79 @@ +package org.elasticsearch.xpack.ml.inference.tree.xgboost; + +import org.elasticsearch.xpack.ml.inference.tree.Tree; +import org.elasticsearch.xpack.ml.inference.tree.TreeEnsembleModel; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class XgBoostJsonParser { + private static String NODE_ID = "nodeid"; + private static String LEAF = "leaf"; + private static String SPLIT = "split"; + private static String SPLIT_CONDITION = "split_condition"; + private static String CHILDREN = "children"; + +// private static Logger logger = LogManager.getLogger(XgBoostJsonParser.class); + + + public static TreeEnsembleModel parse(List> json, Map featureMap) throws IOException { + + TreeEnsembleModel.ModelBuilder modelBuilder = TreeEnsembleModel.modelBuilder(featureMap); + + for (Map treeMap : json) { + Tree.TreeBuilder treeBuilder = Tree.TreeBuilder.newTreeBuilder(); + parseTree(treeMap, featureMap, treeBuilder); + modelBuilder.addTree(treeBuilder.build()); + } + + return modelBuilder.build(); + } + + + @SuppressWarnings("unchecked") + private static void parseTree(Map treeMap, Map featureMap, + Tree.TreeBuilder treeBuilder) throws IOException { + + if (treeMap.containsKey(NODE_ID) == false) { + throw new IOException("invalid object does not contain a nodeid"); + } + + Integer nodeId = (Integer)treeMap.get(NODE_ID); + + boolean isLeaf = treeMap.containsKey(LEAF); + if (isLeaf) { + Double value = (Double) treeMap.get(LEAF); + treeBuilder.addLeaf(nodeId, value); + } else { + List> children = (List>) treeMap.get(CHILDREN); + if (children == null) { + throw new IOException("none leaf node has no children"); + } + + if (children.size() != 2) { + 
throw new IOException("Node must have exactly 2 children, got " + children.size()); + } + + String feature = (String)treeMap.get(SPLIT); + if (feature == null) { + throw new IOException("missing feature for node [" + nodeId + "]"); + } + + Integer featureIndex = featureMap.get(feature); + if (featureIndex == null) { + throw new IOException("feature [" + feature + "] not found in the feature map"); + } + Double threshold = (Double)treeMap.get(SPLIT_CONDITION); + if (threshold == null) { + throw new IOException("field [SPLIT_CONDITION] in node [" + nodeId + "] is not a double"); + } + + treeBuilder.addJunction(nodeId, featureIndex, true, threshold); + + for (Map child : children) { + parseTree(child, featureMap, treeBuilder); + } + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java new file mode 100644 index 0000000000000..ebe755fdfbbf4 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java @@ -0,0 +1,264 @@ +package org.elasticsearch.xpack.ml.inference.tree.xgboost; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.inference.tree.Tree; +import org.elasticsearch.xpack.ml.inference.tree.TreeEnsembleModel; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; + +public class XgBoostJsonParserTests extends ESTestCase { + public void testSimpleParse() throws IOException { + String json = + "{ 
\"nodeid\": 0, \"depth\": 0, \"split\": \"f1\", \"split_condition\": -9.53674316e-07, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + + " { \"nodeid\": 1, \"depth\": 1, \"split\": \"f2\", \"split_condition\": -9.53674316e-07, \"yes\": 3, \"no\": 4, \"missing\": 3, \"children\": [\n" + + " { \"nodeid\": 3, \"leaf\": 0.78471756 },\n" + + " { \"nodeid\": 4, \"leaf\": -0.968530357 }\n" + + " ]},\n" + + " { \"nodeid\": 2, \"leaf\": -6.23624468 }\n" + + " ]}"; + + + Map featureMap = new HashMap<>(); + featureMap.put("f1", 0); + featureMap.put("f2", 1); + + TreeEnsembleModel model; + try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) { + Map map = parser.map(); + + model = XgBoostJsonParser.parse(Collections.singletonList(map), featureMap); + + Map doc = new HashMap<>(); + doc.put("f1", 0.1); + doc.put("f2", 0.1); + List> trace = model.trace(doc); + System.out.println(trace); + + double prediction = model.predict(doc); + System.out.println(prediction); + } + } + + @SuppressWarnings("unchecked") + public void testParse() throws IOException { + String json = "{\"ensemble\": [\n" + + " { \"nodeid\": 0, \"depth\": 0, \"split\": \"f0\", \"split_condition\": 18.3800011, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + + " { \"nodeid\": 1, \"leaf\": 470.438141 },\n" + + " { \"nodeid\": 2, \"leaf\": 441.225525 }\n" + + " ]},\n" + + " { \"nodeid\": 0, \"depth\": 0, \"split\": \"f0\", \"split_condition\": 10.9049997, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + + " { \"nodeid\": 1, \"depth\": 1, \"split\": \"f0\", \"split_condition\": 8.65499973, \"yes\": 3, \"no\": 4, \"missing\": 3, \"children\": [\n" + + " { \"nodeid\": 3, \"depth\": 2, \"split\": \"f0\", \"split_condition\": 7.04500008, \"yes\": 7, \"no\": 8, \"missing\": 7, \"children\": [\n" + + " { \"nodeid\": 7, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 4.77499962, \"yes\": 15, \"no\": 16, \"missing\": 15, \"children\": [\n" + + " { \"nodeid\": 15, 
\"leaf\": 17.8773727 },\n" + + " { \"nodeid\": 16, \"depth\": 4, \"split\": \"f3\", \"split_condition\": 77.4000015, \"yes\": 31, \"no\": 32, \"missing\": 31, \"children\": [\n" + + " { \"nodeid\": 31, \"leaf\": 16.7193699 },\n" + + " { \"nodeid\": 32, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 78.9400024, \"yes\": 61, \"no\": 62, \"missing\": 61, \"children\": [\n" + + " { \"nodeid\": 61, \"leaf\": 8.74667072 },\n" + + " { \"nodeid\": 62, \"leaf\": 13.6302109 }\n" + + " ]}\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 8, \"depth\": 3, \"split\": \"f2\", \"split_condition\": 1022.72498, \"yes\": 17, \"no\": 18, \"missing\": 17, \"children\": [\n" + + " { \"nodeid\": 17, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 57.9249992, \"yes\": 33, \"no\": 34, \"missing\": 33, \"children\": [\n" + + " { \"nodeid\": 33, \"depth\": 5, \"split\": \"f0\", \"split_condition\": 8.03999996, \"yes\": 63, \"no\": 64, \"missing\": 63, \"children\": [\n" + + " { \"nodeid\": 63, \"leaf\": 12.3259125 },\n" + + " { \"nodeid\": 64, \"leaf\": 10.1782122 }\n" + + " ]},\n" + + " { \"nodeid\": 34, \"leaf\": 1.59123743 }\n" + + " ]},\n" + + " { \"nodeid\": 18, \"depth\": 4, \"split\": \"f3\", \"split_condition\": 89.6999969, \"yes\": 35, \"no\": 36, \"missing\": 35, \"children\": [\n" + + " { \"nodeid\": 35, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 47.3199997, \"yes\": 65, \"no\": 66, \"missing\": 65, \"children\": [\n" + + " { \"nodeid\": 65, \"leaf\": 9.94041252 },\n" + + " { \"nodeid\": 66, \"leaf\": 1.25593567 }\n" + + " ]},\n" + + " { \"nodeid\": 36, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 41.5400009, \"yes\": 67, \"no\": 68, \"missing\": 67, \"children\": [\n" + + " { \"nodeid\": 67, \"leaf\": 5.41526508 },\n" + + " { \"nodeid\": 68, \"leaf\": -1.86407471 }\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 4, \"depth\": 2, \"split\": \"f3\", \"split_condition\": 92.9850006, \"yes\": 9, \"no\": 10, \"missing\": 9, 
\"children\": [\n" + + " { \"nodeid\": 9, \"depth\": 3, \"split\": \"f1\", \"split_condition\": 48.8600006, \"yes\": 19, \"no\": 20, \"missing\": 19, \"children\": [\n" + + " { \"nodeid\": 19, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 9.78999996, \"yes\": 37, \"no\": 38, \"missing\": 37, \"children\": [\n" + + " { \"nodeid\": 37, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 62.8600006, \"yes\": 69, \"no\": 70, \"missing\": 69, \"children\": [\n" + + " { \"nodeid\": 69, \"leaf\": 2.36037445 },\n" + + " { \"nodeid\": 70, \"leaf\": 7.76275396 }\n" + + " ]},\n" + + " { \"nodeid\": 38, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 38.2799988, \"yes\": 71, \"no\": 72, \"missing\": 71, \"children\": [\n" + + " { \"nodeid\": 71, \"leaf\": 3.42786431 },\n" + + " { \"nodeid\": 72, \"leaf\": 6.49917078 }\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 20, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 10.3050003, \"yes\": 39, \"no\": 40, \"missing\": 39, \"children\": [\n" + + " { \"nodeid\": 39, \"depth\": 5, \"split\": \"f0\", \"split_condition\": 9.76499939, \"yes\": 73, \"no\": 74, \"missing\": 73, \"children\": [\n" + + " { \"nodeid\": 73, \"leaf\": 1.2359314 },\n" + + " { \"nodeid\": 74, \"leaf\": 0.0513916016 }\n" + + " ]},\n" + + " { \"nodeid\": 40, \"leaf\": -2.12875366 }\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 10, \"depth\": 3, \"split\": \"f1\", \"split_condition\": 41.7750015, \"yes\": 21, \"no\": 22, \"missing\": 21, \"children\": [\n" + + " { \"nodeid\": 21, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 41.5350037, \"yes\": 41, \"no\": 42, \"missing\": 41, \"children\": [\n" + + " { \"nodeid\": 41, \"depth\": 5, \"split\": \"f2\", \"split_condition\": 1017.32001, \"yes\": 75, \"no\": 76, \"missing\": 75, \"children\": [\n" + + " { \"nodeid\": 75, \"leaf\": 3.39938235 },\n" + + " { \"nodeid\": 76, \"leaf\": -0.816912234 }\n" + + " ]},\n" + + " { \"nodeid\": 42, \"depth\": 5, \"split\": \"f2\", 
\"split_condition\": 1021.21497, \"yes\": 77, \"no\": 78, \"missing\": 77, \"children\": [\n" + + " { \"nodeid\": 77, \"leaf\": -6.331038 },\n" + + " { \"nodeid\": 78, \"leaf\": -1.37406921 }\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 22, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 9.25, \"yes\": 43, \"no\": 44, \"missing\": 43, \"children\": [\n" + + " { \"nodeid\": 43, \"leaf\": 7.57445431 },\n" + + " { \"nodeid\": 44, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 99.7200012, \"yes\": 79, \"no\": 80, \"missing\": 79, \"children\": [\n" + + " { \"nodeid\": 79, \"leaf\": 3.53175259 },\n" + + " { \"nodeid\": 80, \"leaf\": 0.886398315 }\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 2, \"depth\": 1, \"split\": \"f1\", \"split_condition\": 66.1500015, \"yes\": 5, \"no\": 6, \"missing\": 5, \"children\": [\n" + + " { \"nodeid\": 5, \"depth\": 2, \"split\": \"f0\", \"split_condition\": 18.3850002, \"yes\": 11, \"no\": 12, \"missing\": 11, \"children\": [\n" + + " { \"nodeid\": 11, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 14.4449997, \"yes\": 23, \"no\": 24, \"missing\": 23, \"children\": [\n" + + " { \"nodeid\": 23, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 12.6049995, \"yes\": 45, \"no\": 46, \"missing\": 45, \"children\": [\n" + + " { \"nodeid\": 45, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 88.3000031, \"yes\": 81, \"no\": 82, \"missing\": 81, \"children\": [\n" + + " { \"nodeid\": 81, \"leaf\": 2.6010797 },\n" + + " { \"nodeid\": 82, \"leaf\": -0.530121922 }\n" + + " ]},\n" + + " { \"nodeid\": 46, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 40.9049988, \"yes\": 83, \"no\": 84, \"missing\": 83, \"children\": [\n" + + " { \"nodeid\": 83, \"leaf\": -1.30897391 },\n" + + " { \"nodeid\": 84, \"leaf\": -3.96863937 }\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 24, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 45.0049973, \"yes\": 47, \"no\": 48, 
\"missing\": 47, \"children\": [\n" + + " { \"nodeid\": 47, \"depth\": 5, \"split\": \"f0\", \"split_condition\": 15.4849997, \"yes\": 85, \"no\": 86, \"missing\": 85, \"children\": [\n" + + " { \"nodeid\": 85, \"leaf\": -5.40670681 },\n" + + " { \"nodeid\": 86, \"leaf\": -8.27983284 }\n" + + " ]},\n" + + " { \"nodeid\": 48, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 56.6899986, \"yes\": 87, \"no\": 88, \"missing\": 87, \"children\": [\n" + + " { \"nodeid\": 87, \"leaf\": -12.0328913 },\n" + + " { \"nodeid\": 88, \"leaf\": -17.0602112 }\n" + + " ]}\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 12, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 21.5450001, \"yes\": 25, \"no\": 26, \"missing\": 25, \"children\": [\n" + + " { \"nodeid\": 25, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 45.7399979, \"yes\": 49, \"no\": 50, \"missing\": 49, \"children\": [\n" + + " { \"nodeid\": 49, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 60.5650024, \"yes\": 89, \"no\": 90, \"missing\": 89, \"children\": [\n" + + " { \"nodeid\": 89, \"leaf\": 18.9551449 },\n" + + " { \"nodeid\": 90, \"leaf\": 13.8164215 }\n" + + " ]},\n" + + " { \"nodeid\": 50, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 58.1699982, \"yes\": 91, \"no\": 92, \"missing\": 91, \"children\": [\n" + + " { \"nodeid\": 91, \"leaf\": 10.819478 },\n" + + " { \"nodeid\": 92, \"leaf\": 7.62072706 }\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 26, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 26.4650002, \"yes\": 51, \"no\": 52, \"missing\": 51, \"children\": [\n" + + " { \"nodeid\": 51, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 55.7350006, \"yes\": 93, \"no\": 94, \"missing\": 93, \"children\": [\n" + + " { \"nodeid\": 93, \"leaf\": 5.94795942 },\n" + + " { \"nodeid\": 94, \"leaf\": 2.39619088 }\n" + + " ]},\n" + + " { \"nodeid\": 52, \"depth\": 5, \"split\": \"f2\", \"split_condition\": 1007.45496, \"yes\": 95, \"no\": 96, \"missing\": 95, \"children\": 
[\n" + + " { \"nodeid\": 95, \"leaf\": -5.02249479 },\n" + + " { \"nodeid\": 96, \"leaf\": -0.807426214 }\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 6, \"depth\": 2, \"split\": \"f0\", \"split_condition\": 25.2750015, \"yes\": 13, \"no\": 14, \"missing\": 13, \"children\": [\n" + + " { \"nodeid\": 13, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 18.3549995, \"yes\": 27, \"no\": 28, \"missing\": 27, \"children\": [\n" + + " { \"nodeid\": 27, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 14.4949999, \"yes\": 53, \"no\": 54, \"missing\": 53, \"children\": [\n" + + " { \"nodeid\": 53, \"leaf\": -6.55451679 },\n" + + " { \"nodeid\": 54, \"leaf\": -16.1782494 }\n" + + " ]},\n" + + " { \"nodeid\": 28, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 22.5400009, \"yes\": 55, \"no\": 56, \"missing\": 55, \"children\": [\n" + + " { \"nodeid\": 55, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 69.0449982, \"yes\": 97, \"no\": 98, \"missing\": 97, \"children\": [\n" + + " { \"nodeid\": 97, \"leaf\": 5.74967909 },\n" + + " { \"nodeid\": 98, \"leaf\": 0.178715631 }\n" + + " ]},\n" + + " { \"nodeid\": 56, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 68.5200043, \"yes\": 99, \"no\": 100, \"missing\": 99, \"children\": [\n" + + " { \"nodeid\": 99, \"leaf\": 0.774093628 },\n" + + " { \"nodeid\": 100, \"leaf\": -3.68317914 }\n" + + " ]}\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 14, \"depth\": 3, \"split\": \"f2\", \"split_condition\": 1008.78003, \"yes\": 29, \"no\": 30, \"missing\": 29, \"children\": [\n" + + " { \"nodeid\": 29, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 26.3549995, \"yes\": 57, \"no\": 58, \"missing\": 57, \"children\": [\n" + + " { \"nodeid\": 57, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 75.3950043, \"yes\": 101, \"no\": 102, \"missing\": 101, \"children\": [\n" + + " { \"nodeid\": 101, \"leaf\": -3.79112887 },\n" + + " { \"nodeid\": 102, \"leaf\": -7.66311884 }\n" + 
+ " ]},\n" + + " { \"nodeid\": 58, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 73.7200012, \"yes\": 103, \"no\": 104, \"missing\": 103, \"children\": [\n" + + " { \"nodeid\": 103, \"leaf\": -9.83749485 },\n" + + " { \"nodeid\": 104, \"leaf\": -7.75816107 }\n" + + " ]}\n" + + " ]},\n" + + " { \"nodeid\": 30, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 29.2350006, \"yes\": 59, \"no\": 60, \"missing\": 59, \"children\": [\n" + + " { \"nodeid\": 59, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 71.4450073, \"yes\": 105, \"no\": 106, \"missing\": 105, \"children\": [\n" + + " { \"nodeid\": 105, \"leaf\": -4.53095436 },\n" + + " { \"nodeid\": 106, \"leaf\": -7.3271451 }\n" + + " ]},\n" + + " { \"nodeid\": 60, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 67.8099976, \"yes\": 107, \"no\": 108, \"missing\": 107, \"children\": [\n" + + " { \"nodeid\": 107, \"leaf\": -6.15153503 },\n" + + " { \"nodeid\": 108, \"leaf\": -8.42399311 }\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + " ]}\n" + + "]" + + "}"; + + + Map featureMap = new HashMap<>(); + featureMap.put("f0", 0); + featureMap.put("f1", 1); + featureMap.put("f2", 2); + featureMap.put("f3", 3); + + TreeEnsembleModel model; + try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) { + Map map = parser.map(); + + List> ensemble = (List>) map.get("ensemble"); + + model = XgBoostJsonParser.parse(ensemble, featureMap); + + assertEquals(2, model.numTrees()); + +// List missingNodes = model.checkForNull(); +// System.out.println(missingNodes); +// assertThat(missingNodes, hasSize(0)); + + + Map doc = new HashMap<>(); + doc.put("f0", 10.36); + doc.put("f1", 43.67); + doc.put("f2", 1012.1); + doc.put("f3", 77.04); + List> trace = model.trace(doc); + System.out.println(trace); + logger.info(trace); + + double prediction = model.predict(doc); + System.out.println(prediction); + logger.info("prediction: " + prediction); + assertThat(prediction, 
is(greaterThan(450.0))); + assertThat(prediction, is(lessThan(500.0))); + } + } +} From e6edcd50336b8b18a4d9f06e9e2e35d9e53f4614 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 29 Aug 2019 18:14:13 +0100 Subject: [PATCH 2/3] Checkstyle fixes --- .../xpack/ml/MachineLearning.java | 7 +- .../xpack/ml/inference/tree/Tree.java | 22 +- .../ml/inference/tree/TreeEnsembleModel.java | 6 + .../xpack/ml/inference/tree/TreeModel.java | 12 +- .../ml/inference/tree/TreeModelLoader.java | 32 ++- .../tree/xgboost/XgBoostJsonParser.java | 15 +- .../ml/inference/InferenceProcessorTests.java | 1 - .../tree/xgboost/XgBoostJsonParserTests.java | 225 ++++++++++++------ 8 files changed, 228 insertions(+), 92 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 6c2ca83777870..a2ed592248703 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -195,15 +195,14 @@ import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; -import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcessManager; -import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; -import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory; +import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; +import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import 
org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.inference.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ModelLoader; import org.elasticsearch.xpack.ml.inference.sillymodel.SillyModelLoader; -import org.elasticsearch.xpack.ml.inference.tree.Tree; import org.elasticsearch.xpack.ml.inference.tree.TreeModelLoader; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java index 108fd5cec2726..506f7983d27be 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + package org.elasticsearch.xpack.ml.inference.tree; import java.util.ArrayList; @@ -105,7 +111,7 @@ public static class Node { this(-1, -1, -1, false, value); } - boolean isLeaf() { + public boolean isLeaf() { return leftChild < 1; } @@ -122,10 +128,22 @@ boolean isMissing(Double feature) { return feature == null; } - Double value() { + public Double value() { return thresholdValue; } + public int getFeatureIndex() { + return featureIndex; + } + + public int getLeftChild() { + return leftChild; + } + + public int getRightChild() { + return rightChild; + } + @Override public String toString() { StringBuilder builder = new StringBuilder("{\n"); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModel.java index f1fcfb70fb422..e9688bc3b7c5f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModel.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + package org.elasticsearch.xpack.ml.inference.tree; import org.apache.log4j.LogManager; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java index 56e9fa48a039a..9481deaad81de 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModel.java @@ -1,8 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + package org.elasticsearch.xpack.ml.inference.tree; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xpack.ml.inference.Model; +import java.util.function.BiConsumer; + public class TreeModel implements Model { private String targetFieldName; @@ -15,9 +23,9 @@ public class TreeModel implements Model { } @Override - public IngestDocument infer(IngestDocument document) { + public void infer(IngestDocument document, BiConsumer handler) { Double prediction = ensemble.predictFromDoc(document.getSourceAndMetadata()); document.setFieldValue(targetFieldName, prediction); - return document; + handler.accept(document, null); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java index c69e473c9469c..50c2c3dde5e57 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeModelLoader.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + package org.elasticsearch.xpack.ml.inference.tree; import org.apache.log4j.LogManager; @@ -46,7 +52,7 @@ public TreeModel load(String modelId, String processorTag, boolean ignoreMissing String index = ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX); String targetField = ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, TARGET_FIELD); - load(modelId, index, listener); + load(modelId, index, config, listener); latch.await(); if (exception.get() != null) { throw exception.get(); @@ -59,11 +65,17 @@ public TreeModel load(String modelId, String processorTag, boolean ignoreMissing return model; } - public void load(String id, String index, ActionListener listener) { + @Override + public void consumeConfiguration(String processorTag, Map config) { + ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX); + ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, TARGET_FIELD); + } + + public void load(String id, String index, Map config, ActionListener listener) { client.prepareGet(index, null, id).execute(ActionListener.wrap( response -> { if (response.isExists()) { - listener.onResponse(createEnsemble(response.getSourceAsMap())); + listener.onResponse(createEnsemble(response.getSourceAsMap(), featureMapFromConfig(config))); } else { listener.onFailure(new ResourceNotFoundException("missing model [{}], [{}]", id, index)); } @@ -72,14 +84,18 @@ public void load(String id, String index, ActionListener list )); } - @SuppressWarnings("unchecked") - private TreeEnsembleModel createEnsemble(Map source) throws IOException { + private Map featureMapFromConfig(Map config) { + // TODO, this was hard coded for a demo Map featureMap = new HashMap<>(); featureMap.put("f0", 0); featureMap.put("f1", 1); featureMap.put("f2", 2); featureMap.put("f3", 3); + return featureMap; + } + @SuppressWarnings("unchecked") + private TreeEnsembleModel 
createEnsemble(Map source, Map featureMap) throws IOException { List> trees = (List>) source.get("ensemble"); if (trees == null) { throw new IllegalStateException("missing trees"); @@ -87,10 +103,4 @@ private TreeEnsembleModel createEnsemble(Map source) throws IOEx return XgBoostJsonParser.parse(trees, featureMap); } - - @Override - public void readConfiguration(String processorTag, Map config) { - ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX); - ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, TARGET_FIELD); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java index b9590259c248d..8241de7315ade 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParser.java @@ -1,3 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + + package org.elasticsearch.xpack.ml.inference.tree.xgboost; import org.elasticsearch.xpack.ml.inference.tree.Tree; @@ -14,8 +21,6 @@ public class XgBoostJsonParser { private static String SPLIT_CONDITION = "split_condition"; private static String CHILDREN = "children"; -// private static Logger logger = LogManager.getLogger(XgBoostJsonParser.class); - public static TreeEnsembleModel parse(List> json, Map featureMap) throws IOException { @@ -69,6 +74,12 @@ private static void parseTree(Map treeMap, Map throw new IOException("field [SPLIT_CONDITION] in node [" + nodeId + "] is not a double"); } + // Fortunately XGBoost numbers the nodes the same way we do (breadth first). + // Assuming a 1-based index, a non-leaf node at index X will have its child nodes + // at indices 2X and 2X + 1 (in practice we are using a zero-based index so adjust + // accordingly). XGBoost uses the same numbering system so we don't have to point + // the child indices somewhere else on the new node, the child nodes will + // automatically slot into the right place.
treeBuilder.addJunction(nodeId, featureIndex, true, threshold); for (Map child : children) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/InferenceProcessorTests.java index c0767ec950dc5..a844b893f2339 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/InferenceProcessorTests.java @@ -24,7 +24,6 @@ public class InferenceProcessorTests extends ESTestCase { public void testFactory() throws Exception { String processorTag = randomAlphaOfLength(10); String modelType = "test"; - Model mockModel = mock(Model.class); ModelLoader mockLoader = mock(ModelLoader.class); when(mockLoader.load(eq("k2"), eq(processorTag), eq(false), any())).thenReturn(mockModel); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java index ebe755fdfbbf4..1e60434234527 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + package org.elasticsearch.xpack.ml.inference.tree.xgboost; import org.elasticsearch.common.xcontent.XContentParser; @@ -13,15 +19,16 @@ import java.util.Map; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.lessThan; public class XgBoostJsonParserTests extends ESTestCase { public void testSimpleParse() throws IOException { String json = - "{ \"nodeid\": 0, \"depth\": 0, \"split\": \"f1\", \"split_condition\": -9.53674316e-07, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + - " { \"nodeid\": 1, \"depth\": 1, \"split\": \"f2\", \"split_condition\": -9.53674316e-07, \"yes\": 3, \"no\": 4, \"missing\": 3, \"children\": [\n" + + "{ \"nodeid\": 0, \"depth\": 0, \"split\": \"f1\", " + + "\"split_condition\": -9.53674316e-07, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + + " { \"nodeid\": 1, \"depth\": 1, \"split\": \"f2\", " + + "\"split_condition\": -9.53674316e-07, \"yes\": 3, \"no\": 4, \"missing\": 3, \"children\": [\n" + " { \"nodeid\": 3, \"leaf\": 0.78471756 },\n" + " { \"nodeid\": 4, \"leaf\": -0.968530357 }\n" + " ]},\n" + @@ -43,87 +50,117 @@ public void testSimpleParse() throws IOException { doc.put("f1", 0.1); doc.put("f2", 0.1); List> trace = model.trace(doc); - System.out.println(trace); + assertThat(trace, hasSize(1)); + List treeTrace = trace.get(0); + assertThat(treeTrace, hasSize(2)); + assertFalse(treeTrace.get(0).isLeaf()); + assertTrue(treeTrace.get(1).isLeaf()); double prediction = model.predict(doc); - System.out.println(prediction); + assertEquals(-6.23624468, prediction, 0.0001); } } @SuppressWarnings("unchecked") public void testParse() throws IOException { String json = "{\"ensemble\": [\n" + - " { \"nodeid\": 0, \"depth\": 0, \"split\": \"f0\", \"split_condition\": 18.3800011, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" 
+ + " { \"nodeid\": 0, \"depth\": 0, \"split\": \"f0\", " + + "\"split_condition\": 18.3800011, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + " { \"nodeid\": 1, \"leaf\": 470.438141 },\n" + " { \"nodeid\": 2, \"leaf\": 441.225525 }\n" + " ]},\n" + - " { \"nodeid\": 0, \"depth\": 0, \"split\": \"f0\", \"split_condition\": 10.9049997, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + - " { \"nodeid\": 1, \"depth\": 1, \"split\": \"f0\", \"split_condition\": 8.65499973, \"yes\": 3, \"no\": 4, \"missing\": 3, \"children\": [\n" + - " { \"nodeid\": 3, \"depth\": 2, \"split\": \"f0\", \"split_condition\": 7.04500008, \"yes\": 7, \"no\": 8, \"missing\": 7, \"children\": [\n" + - " { \"nodeid\": 7, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 4.77499962, \"yes\": 15, \"no\": 16, \"missing\": 15, \"children\": [\n" + + " { \"nodeid\": 0, \"depth\": 0, \"split\": \"f0\", " + + "\"split_condition\": 10.9049997, \"yes\": 1, \"no\": 2, \"missing\": 1, \"children\": [\n" + + " { \"nodeid\": 1, \"depth\": 1, \"split\": \"f0\", " + + "\"split_condition\": 8.65499973, \"yes\": 3, \"no\": 4, \"missing\": 3, \"children\": [\n" + + " { \"nodeid\": 3, \"depth\": 2, \"split\": \"f0\", " + + "\"split_condition\": 7.04500008, \"yes\": 7, \"no\": 8, \"missing\": 7, \"children\": [\n" + + " { \"nodeid\": 7, \"depth\": 3, \"split\": \"f0\", " + + "\"split_condition\": 4.77499962, \"yes\": 15, \"no\": 16, \"missing\": 15, \"children\": [\n" + " { \"nodeid\": 15, \"leaf\": 17.8773727 },\n" + - " { \"nodeid\": 16, \"depth\": 4, \"split\": \"f3\", \"split_condition\": 77.4000015, \"yes\": 31, \"no\": 32, \"missing\": 31, \"children\": [\n" + + " { \"nodeid\": 16, \"depth\": 4, \"split\": \"f3\", " + + "\"split_condition\": 77.4000015, \"yes\": 31, \"no\": 32, \"missing\": 31, \"children\": [\n" + " { \"nodeid\": 31, \"leaf\": 16.7193699 },\n" + - " { \"nodeid\": 32, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 78.9400024, \"yes\": 61, \"no\": 62, 
\"missing\": 61, \"children\": [\n" + + " { \"nodeid\": 32, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 78.9400024, \"yes\": 61, \"no\": 62, \"missing\": 61, \"children\": [\n" + " { \"nodeid\": 61, \"leaf\": 8.74667072 },\n" + " { \"nodeid\": 62, \"leaf\": 13.6302109 }\n" + " ]}\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 8, \"depth\": 3, \"split\": \"f2\", \"split_condition\": 1022.72498, \"yes\": 17, \"no\": 18, \"missing\": 17, \"children\": [\n" + - " { \"nodeid\": 17, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 57.9249992, \"yes\": 33, \"no\": 34, \"missing\": 33, \"children\": [\n" + - " { \"nodeid\": 33, \"depth\": 5, \"split\": \"f0\", \"split_condition\": 8.03999996, \"yes\": 63, \"no\": 64, \"missing\": 63, \"children\": [\n" + + " { \"nodeid\": 8, \"depth\": 3, \"split\": \"f2\", " + + "\"split_condition\": 1022.72498, \"yes\": 17, \"no\": 18, \"missing\": 17, \"children\": [\n" + + " { \"nodeid\": 17, \"depth\": 4, \"split\": \"f1\", " + + "\"split_condition\": 57.9249992, \"yes\": 33, \"no\": 34, \"missing\": 33, \"children\": [\n" + + " { \"nodeid\": 33, \"depth\": 5, \"split\": \"f0\", " + + "\"split_condition\": 8.03999996, \"yes\": 63, \"no\": 64, \"missing\": 63, \"children\": [\n" + " { \"nodeid\": 63, \"leaf\": 12.3259125 },\n" + " { \"nodeid\": 64, \"leaf\": 10.1782122 }\n" + " ]},\n" + " { \"nodeid\": 34, \"leaf\": 1.59123743 }\n" + " ]},\n" + - " { \"nodeid\": 18, \"depth\": 4, \"split\": \"f3\", \"split_condition\": 89.6999969, \"yes\": 35, \"no\": 36, \"missing\": 35, \"children\": [\n" + - " { \"nodeid\": 35, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 47.3199997, \"yes\": 65, \"no\": 66, \"missing\": 65, \"children\": [\n" + + " { \"nodeid\": 18, \"depth\": 4, \"split\": \"f3\", " + + "\"split_condition\": 89.6999969, \"yes\": 35, \"no\": 36, \"missing\": 35, \"children\": [\n" + + " { \"nodeid\": 35, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 47.3199997, \"yes\": 65, \"no\": 66, 
\"missing\": 65, \"children\": [\n" + " { \"nodeid\": 65, \"leaf\": 9.94041252 },\n" + " { \"nodeid\": 66, \"leaf\": 1.25593567 }\n" + " ]},\n" + - " { \"nodeid\": 36, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 41.5400009, \"yes\": 67, \"no\": 68, \"missing\": 67, \"children\": [\n" + + " { \"nodeid\": 36, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 41.5400009, \"yes\": 67, \"no\": 68, \"missing\": 67, \"children\": [\n" + " { \"nodeid\": 67, \"leaf\": 5.41526508 },\n" + " { \"nodeid\": 68, \"leaf\": -1.86407471 }\n" + " ]}\n" + " ]}\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 4, \"depth\": 2, \"split\": \"f3\", \"split_condition\": 92.9850006, \"yes\": 9, \"no\": 10, \"missing\": 9, \"children\": [\n" + - " { \"nodeid\": 9, \"depth\": 3, \"split\": \"f1\", \"split_condition\": 48.8600006, \"yes\": 19, \"no\": 20, \"missing\": 19, \"children\": [\n" + - " { \"nodeid\": 19, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 9.78999996, \"yes\": 37, \"no\": 38, \"missing\": 37, \"children\": [\n" + - " { \"nodeid\": 37, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 62.8600006, \"yes\": 69, \"no\": 70, \"missing\": 69, \"children\": [\n" + + " { \"nodeid\": 4, \"depth\": 2, \"split\": \"f3\", " + + "\"split_condition\": 92.9850006, \"yes\": 9, \"no\": 10, \"missing\": 9, \"children\": [\n" + + " { \"nodeid\": 9, \"depth\": 3, \"split\": \"f1\", " + + "\"split_condition\": 48.8600006, \"yes\": 19, \"no\": 20, \"missing\": 19, \"children\": [\n" + + " { \"nodeid\": 19, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 9.78999996, \"yes\": 37, \"no\": 38, \"missing\": 37, \"children\": [\n" + + " { \"nodeid\": 37, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 62.8600006, \"yes\": 69, \"no\": 70, \"missing\": 69, \"children\": [\n" + " { \"nodeid\": 69, \"leaf\": 2.36037445 },\n" + " { \"nodeid\": 70, \"leaf\": 7.76275396 }\n" + " ]},\n" + - " { \"nodeid\": 38, \"depth\": 5, \"split\": \"f1\", 
\"split_condition\": 38.2799988, \"yes\": 71, \"no\": 72, \"missing\": 71, \"children\": [\n" + + " { \"nodeid\": 38, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 38.2799988, \"yes\": 71, \"no\": 72, \"missing\": 71, \"children\": [\n" + " { \"nodeid\": 71, \"leaf\": 3.42786431 },\n" + " { \"nodeid\": 72, \"leaf\": 6.49917078 }\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 20, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 10.3050003, \"yes\": 39, \"no\": 40, \"missing\": 39, \"children\": [\n" + - " { \"nodeid\": 39, \"depth\": 5, \"split\": \"f0\", \"split_condition\": 9.76499939, \"yes\": 73, \"no\": 74, \"missing\": 73, \"children\": [\n" + + " { \"nodeid\": 20, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 10.3050003, \"yes\": 39, \"no\": 40, \"missing\": 39, \"children\": [\n" + + " { \"nodeid\": 39, \"depth\": 5, \"split\": \"f0\", " + + "\"split_condition\": 9.76499939, \"yes\": 73, \"no\": 74, \"missing\": 73, \"children\": [\n" + " { \"nodeid\": 73, \"leaf\": 1.2359314 },\n" + " { \"nodeid\": 74, \"leaf\": 0.0513916016 }\n" + " ]},\n" + " { \"nodeid\": 40, \"leaf\": -2.12875366 }\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 10, \"depth\": 3, \"split\": \"f1\", \"split_condition\": 41.7750015, \"yes\": 21, \"no\": 22, \"missing\": 21, \"children\": [\n" + - " { \"nodeid\": 21, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 41.5350037, \"yes\": 41, \"no\": 42, \"missing\": 41, \"children\": [\n" + - " { \"nodeid\": 41, \"depth\": 5, \"split\": \"f2\", \"split_condition\": 1017.32001, \"yes\": 75, \"no\": 76, \"missing\": 75, \"children\": [\n" + + " { \"nodeid\": 10, \"depth\": 3, \"split\": \"f1\", " + + "\"split_condition\": 41.7750015, \"yes\": 21, \"no\": 22, \"missing\": 21, \"children\": [\n" + + " { \"nodeid\": 21, \"depth\": 4, \"split\": \"f1\", " + + "\"split_condition\": 41.5350037, \"yes\": 41, \"no\": 42, \"missing\": 41, \"children\": [\n" + + " { \"nodeid\": 41, \"depth\": 5, \"split\": \"f2\", " + + 
"\"split_condition\": 1017.32001, \"yes\": 75, \"no\": 76, \"missing\": 75, \"children\": [\n" + " { \"nodeid\": 75, \"leaf\": 3.39938235 },\n" + " { \"nodeid\": 76, \"leaf\": -0.816912234 }\n" + " ]},\n" + - " { \"nodeid\": 42, \"depth\": 5, \"split\": \"f2\", \"split_condition\": 1021.21497, \"yes\": 77, \"no\": 78, \"missing\": 77, \"children\": [\n" + + " { \"nodeid\": 42, \"depth\": 5, \"split\": \"f2\", " + + "\"split_condition\": 1021.21497, \"yes\": 77, \"no\": 78, \"missing\": 77, \"children\": [\n" + " { \"nodeid\": 77, \"leaf\": -6.331038 },\n" + " { \"nodeid\": 78, \"leaf\": -1.37406921 }\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 22, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 9.25, \"yes\": 43, \"no\": 44, \"missing\": 43, \"children\": [\n" + + " { \"nodeid\": 22, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 9.25, \"yes\": 43, \"no\": 44, \"missing\": 43, \"children\": [\n" + " { \"nodeid\": 43, \"leaf\": 7.57445431 },\n" + - " { \"nodeid\": 44, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 99.7200012, \"yes\": 79, \"no\": 80, \"missing\": 79, \"children\": [\n" + + " { \"nodeid\": 44, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 99.7200012, \"yes\": 79, \"no\": 80, \"missing\": 79, \"children\": [\n" + " { \"nodeid\": 79, \"leaf\": 3.53175259 },\n" + " { \"nodeid\": 80, \"leaf\": 0.886398315 }\n" + " ]}\n" + @@ -131,87 +168,116 @@ public void testParse() throws IOException { " ]}\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 2, \"depth\": 1, \"split\": \"f1\", \"split_condition\": 66.1500015, \"yes\": 5, \"no\": 6, \"missing\": 5, \"children\": [\n" + - " { \"nodeid\": 5, \"depth\": 2, \"split\": \"f0\", \"split_condition\": 18.3850002, \"yes\": 11, \"no\": 12, \"missing\": 11, \"children\": [\n" + - " { \"nodeid\": 11, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 14.4449997, \"yes\": 23, \"no\": 24, \"missing\": 23, \"children\": [\n" + - " { \"nodeid\": 23, \"depth\": 4, \"split\": \"f0\", 
\"split_condition\": 12.6049995, \"yes\": 45, \"no\": 46, \"missing\": 45, \"children\": [\n" + - " { \"nodeid\": 45, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 88.3000031, \"yes\": 81, \"no\": 82, \"missing\": 81, \"children\": [\n" + + " { \"nodeid\": 2, \"depth\": 1, \"split\": \"f1\", " + + "\"split_condition\": 66.1500015, \"yes\": 5, \"no\": 6, \"missing\": 5, \"children\": [\n" + + " { \"nodeid\": 5, \"depth\": 2, \"split\": \"f0\", " + + "\"split_condition\": 18.3850002, \"yes\": 11, \"no\": 12, \"missing\": 11, \"children\": [\n" + + " { \"nodeid\": 11, \"depth\": 3, \"split\": \"f0\", " + + "\"split_condition\": 14.4449997, \"yes\": 23, \"no\": 24, \"missing\": 23, \"children\": [\n" + + " { \"nodeid\": 23, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 12.6049995, \"yes\": 45, \"no\": 46, \"missing\": 45, \"children\": [\n" + + " { \"nodeid\": 45, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 88.3000031, \"yes\": 81, \"no\": 82, \"missing\": 81, \"children\": [\n" + " { \"nodeid\": 81, \"leaf\": 2.6010797 },\n" + " { \"nodeid\": 82, \"leaf\": -0.530121922 }\n" + " ]},\n" + - " { \"nodeid\": 46, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 40.9049988, \"yes\": 83, \"no\": 84, \"missing\": 83, \"children\": [\n" + + " { \"nodeid\": 46, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 40.9049988, \"yes\": 83, \"no\": 84, \"missing\": 83, \"children\": [\n" + " { \"nodeid\": 83, \"leaf\": -1.30897391 },\n" + " { \"nodeid\": 84, \"leaf\": -3.96863937 }\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 24, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 45.0049973, \"yes\": 47, \"no\": 48, \"missing\": 47, \"children\": [\n" + - " { \"nodeid\": 47, \"depth\": 5, \"split\": \"f0\", \"split_condition\": 15.4849997, \"yes\": 85, \"no\": 86, \"missing\": 85, \"children\": [\n" + + " { \"nodeid\": 24, \"depth\": 4, \"split\": \"f1\", " + + "\"split_condition\": 45.0049973, \"yes\": 47, \"no\": 48, 
\"missing\": 47, \"children\": [\n" + + " { \"nodeid\": 47, \"depth\": 5, \"split\": \"f0\", " + + "\"split_condition\": 15.4849997, \"yes\": 85, \"no\": 86, \"missing\": 85, \"children\": [\n" + " { \"nodeid\": 85, \"leaf\": -5.40670681 },\n" + " { \"nodeid\": 86, \"leaf\": -8.27983284 }\n" + " ]},\n" + - " { \"nodeid\": 48, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 56.6899986, \"yes\": 87, \"no\": 88, \"missing\": 87, \"children\": [\n" + + " { \"nodeid\": 48, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 56.6899986, \"yes\": 87, \"no\": 88, \"missing\": 87, \"children\": [\n" + " { \"nodeid\": 87, \"leaf\": -12.0328913 },\n" + " { \"nodeid\": 88, \"leaf\": -17.0602112 }\n" + " ]}\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 12, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 21.5450001, \"yes\": 25, \"no\": 26, \"missing\": 25, \"children\": [\n" + - " { \"nodeid\": 25, \"depth\": 4, \"split\": \"f1\", \"split_condition\": 45.7399979, \"yes\": 49, \"no\": 50, \"missing\": 49, \"children\": [\n" + - " { \"nodeid\": 49, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 60.5650024, \"yes\": 89, \"no\": 90, \"missing\": 89, \"children\": [\n" + + " { \"nodeid\": 12, \"depth\": 3, \"split\": \"f0\", " + + "\"split_condition\": 21.5450001, \"yes\": 25, \"no\": 26, \"missing\": 25, \"children\": [\n" + + " { \"nodeid\": 25, \"depth\": 4, \"split\": \"f1\", " + + "\"split_condition\": 45.7399979, \"yes\": 49, \"no\": 50, \"missing\": 49, \"children\": [\n" + + " { \"nodeid\": 49, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 60.5650024, \"yes\": 89, \"no\": 90, \"missing\": 89, \"children\": [\n" + " { \"nodeid\": 89, \"leaf\": 18.9551449 },\n" + " { \"nodeid\": 90, \"leaf\": 13.8164215 }\n" + " ]},\n" + - " { \"nodeid\": 50, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 58.1699982, \"yes\": 91, \"no\": 92, \"missing\": 91, \"children\": [\n" + + " { \"nodeid\": 50, \"depth\": 5, \"split\": \"f1\", " + + 
"\"split_condition\": 58.1699982, \"yes\": 91, \"no\": 92, \"missing\": 91, \"children\": [\n" + " { \"nodeid\": 91, \"leaf\": 10.819478 },\n" + " { \"nodeid\": 92, \"leaf\": 7.62072706 }\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 26, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 26.4650002, \"yes\": 51, \"no\": 52, \"missing\": 51, \"children\": [\n" + - " { \"nodeid\": 51, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 55.7350006, \"yes\": 93, \"no\": 94, \"missing\": 93, \"children\": [\n" + + " { \"nodeid\": 26, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 26.4650002, \"yes\": 51, \"no\": 52, \"missing\": 51, \"children\": [\n" + + " { \"nodeid\": 51, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 55.7350006, \"yes\": 93, \"no\": 94, \"missing\": 93, \"children\": [\n" + " { \"nodeid\": 93, \"leaf\": 5.94795942 },\n" + " { \"nodeid\": 94, \"leaf\": 2.39619088 }\n" + " ]},\n" + - " { \"nodeid\": 52, \"depth\": 5, \"split\": \"f2\", \"split_condition\": 1007.45496, \"yes\": 95, \"no\": 96, \"missing\": 95, \"children\": [\n" + + " { \"nodeid\": 52, \"depth\": 5, \"split\": \"f2\", " + + "\"split_condition\": 1007.45496, \"yes\": 95, \"no\": 96, \"missing\": 95, \"children\": [\n" + " { \"nodeid\": 95, \"leaf\": -5.02249479 },\n" + " { \"nodeid\": 96, \"leaf\": -0.807426214 }\n" + " ]}\n" + " ]}\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 6, \"depth\": 2, \"split\": \"f0\", \"split_condition\": 25.2750015, \"yes\": 13, \"no\": 14, \"missing\": 13, \"children\": [\n" + - " { \"nodeid\": 13, \"depth\": 3, \"split\": \"f0\", \"split_condition\": 18.3549995, \"yes\": 27, \"no\": 28, \"missing\": 27, \"children\": [\n" + - " { \"nodeid\": 27, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 14.4949999, \"yes\": 53, \"no\": 54, \"missing\": 53, \"children\": [\n" + + " { \"nodeid\": 6, \"depth\": 2, \"split\": \"f0\", " + + "\"split_condition\": 25.2750015, \"yes\": 13, \"no\": 14, \"missing\": 13, \"children\": [\n" + 
+ " { \"nodeid\": 13, \"depth\": 3, \"split\": \"f0\", " + + "\"split_condition\": 18.3549995, \"yes\": 27, \"no\": 28, \"missing\": 27, \"children\": [\n" + + " { \"nodeid\": 27, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 14.4949999, \"yes\": 53, \"no\": 54, \"missing\": 53, \"children\": [\n" + " { \"nodeid\": 53, \"leaf\": -6.55451679 },\n" + " { \"nodeid\": 54, \"leaf\": -16.1782494 }\n" + " ]},\n" + - " { \"nodeid\": 28, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 22.5400009, \"yes\": 55, \"no\": 56, \"missing\": 55, \"children\": [\n" + - " { \"nodeid\": 55, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 69.0449982, \"yes\": 97, \"no\": 98, \"missing\": 97, \"children\": [\n" + + " { \"nodeid\": 28, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 22.5400009, \"yes\": 55, \"no\": 56, \"missing\": 55, \"children\": [\n" + + " { \"nodeid\": 55, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 69.0449982, \"yes\": 97, \"no\": 98, \"missing\": 97, \"children\": [\n" + " { \"nodeid\": 97, \"leaf\": 5.74967909 },\n" + " { \"nodeid\": 98, \"leaf\": 0.178715631 }\n" + " ]},\n" + - " { \"nodeid\": 56, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 68.5200043, \"yes\": 99, \"no\": 100, \"missing\": 99, \"children\": [\n" + + " { \"nodeid\": 56, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 68.5200043, \"yes\": 99, \"no\": 100, \"missing\": 99, \"children\": [\n" + " { \"nodeid\": 99, \"leaf\": 0.774093628 },\n" + " { \"nodeid\": 100, \"leaf\": -3.68317914 }\n" + " ]}\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 14, \"depth\": 3, \"split\": \"f2\", \"split_condition\": 1008.78003, \"yes\": 29, \"no\": 30, \"missing\": 29, \"children\": [\n" + - " { \"nodeid\": 29, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 26.3549995, \"yes\": 57, \"no\": 58, \"missing\": 57, \"children\": [\n" + - " { \"nodeid\": 57, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 75.3950043, \"yes\": 101, 
\"no\": 102, \"missing\": 101, \"children\": [\n" + + " { \"nodeid\": 14, \"depth\": 3, \"split\": \"f2\", " + + "\"split_condition\": 1008.78003, \"yes\": 29, \"no\": 30, \"missing\": 29, \"children\": [\n" + + " { \"nodeid\": 29, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 26.3549995, \"yes\": 57, \"no\": 58, \"missing\": 57, \"children\": [\n" + + " { \"nodeid\": 57, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 75.3950043, \"yes\": 101, \"no\": 102, \"missing\": 101, \"children\": [\n" + " { \"nodeid\": 101, \"leaf\": -3.79112887 },\n" + " { \"nodeid\": 102, \"leaf\": -7.66311884 }\n" + " ]},\n" + - " { \"nodeid\": 58, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 73.7200012, \"yes\": 103, \"no\": 104, \"missing\": 103, \"children\": [\n" + + " { \"nodeid\": 58, \"depth\": 5, \"split\": \"f1\", " + + "\"split_condition\": 73.7200012, \"yes\": 103, \"no\": 104, \"missing\": 103, \"children\": [\n" + " { \"nodeid\": 103, \"leaf\": -9.83749485 },\n" + " { \"nodeid\": 104, \"leaf\": -7.75816107 }\n" + " ]}\n" + " ]},\n" + - " { \"nodeid\": 30, \"depth\": 4, \"split\": \"f0\", \"split_condition\": 29.2350006, \"yes\": 59, \"no\": 60, \"missing\": 59, \"children\": [\n" + - " { \"nodeid\": 59, \"depth\": 5, \"split\": \"f3\", \"split_condition\": 71.4450073, \"yes\": 105, \"no\": 106, \"missing\": 105, \"children\": [\n" + + " { \"nodeid\": 30, \"depth\": 4, \"split\": \"f0\", " + + "\"split_condition\": 29.2350006, \"yes\": 59, \"no\": 60, \"missing\": 59, \"children\": [\n" + + " { \"nodeid\": 59, \"depth\": 5, \"split\": \"f3\", " + + "\"split_condition\": 71.4450073, \"yes\": 105, \"no\": 106, \"missing\": 105, \"children\": [\n" + " { \"nodeid\": 105, \"leaf\": -4.53095436 },\n" + " { \"nodeid\": 106, \"leaf\": -7.3271451 }\n" + " ]},\n" + - " { \"nodeid\": 60, \"depth\": 5, \"split\": \"f1\", \"split_condition\": 67.8099976, \"yes\": 107, \"no\": 108, \"missing\": 107, \"children\": [\n" + + " { \"nodeid\": 60, \"depth\": 
5, \"split\": \"f1\", " + + "\"split_condition\": 67.8099976, \"yes\": 107, \"no\": 108, \"missing\": 107, \"children\": [\n" + " { \"nodeid\": 107, \"leaf\": -6.15153503 },\n" + " { \"nodeid\": 108, \"leaf\": -8.42399311 }\n" + " ]}\n" + @@ -240,25 +306,44 @@ public void testParse() throws IOException { assertEquals(2, model.numTrees()); -// List missingNodes = model.checkForNull(); -// System.out.println(missingNodes); + List missingNodes = model.checkForNull(); + // TODO we are adding more nodes than are actually used // assertThat(missingNodes, hasSize(0)); - Map doc = new HashMap<>(); doc.put("f0", 10.36); doc.put("f1", 43.67); doc.put("f2", 1012.1); doc.put("f3", 77.04); + List> trace = model.trace(doc); - System.out.println(trace); - logger.info(trace); + assertThat(trace, hasSize(2)); // 2 trees + List tree1Trace = trace.get(0); + assertThat(tree1Trace, hasSize(2)); + assertTreeTrace(tree1Trace); + + List tree2Trace = trace.get(1); + assertThat(tree2Trace, hasSize(5)); + assertTreeTrace(tree2Trace); double prediction = model.predict(doc); - System.out.println(prediction); - logger.info("prediction: " + prediction); - assertThat(prediction, is(greaterThan(450.0))); - assertThat(prediction, is(lessThan(500.0))); + assertEquals(473.0, prediction, 1.0); } } + + private void assertTreeTrace(List treeTrace) { + for (int i=0; i Date: Fri, 30 Aug 2019 14:08:39 +0100 Subject: [PATCH 3/3] Fix over allocating nodes and add missing tests --- .../xpack/ml/inference/tree/Tree.java | 22 +- .../ml/inference/tree/TreeEnsembleModel.java | 5 - .../tree/TreeEnsembleModelTests.java | 120 +++++++++++ .../xpack/ml/inference/tree/TreeTests.java | 194 ++++++++++++++++++ .../tree/xgboost/XgBoostJsonParserTests.java | 6 +- 5 files changed, 327 insertions(+), 20 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModelTests.java create mode 100644 
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/TreeTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java index 506f7983d27be..aeb5e45bf3f13 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java @@ -62,8 +62,10 @@ private Double predict(List features, int nodeIndex) { } /** - * Finds null nodes - * @return List of indexes to null nodes + * Finds {@code null} nodes. If constructed properly there should be no {@code null} nodes. + * {@code null} nodes indicates missing leaf or junction nodes + * + * @return List of indexes to the {@code null} nodes */ List missingNodes() { List nullNodeIndices = new ArrayList<>(); @@ -185,8 +187,6 @@ public static TreeBuilder newTreeBuilder() { * @return The created node */ public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) { -// assert nodeIndex < nodes.size() : "node index " + nodeIndex + " >= size " + nodes.size(); - int leftChild = numNodes++; int rightChild = numNodes++; nodes.ensureCapacity(nodeIndex +1); @@ -194,13 +194,13 @@ public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, nodes.add(null); } - Node node = new Node(leftChild, rightChild, featureIndex, isDefaultLeft, decisionThreshold); nodes.set(nodeIndex, node); + // allocate space for the child nodes - nodes.add(null); - nodes.add(null); -// assert nodes.size() == numNodes : "nodes size " + nodes.size() + " != num nodes " + numNodes; + while (nodes.size() <= rightChild) { + nodes.add(null); + } return node; } @@ -212,15 +212,11 @@ public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, * @return this */ public TreeBuilder addLeaf(int nodeIndex, double value) { -// assert 
nodeIndex < nodes.size(); - for (int i=nodes.size(); i trees; private final Map featureMap; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModelTests.java new file mode 100644 index 0000000000000..83a6933cee5cd --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModelTests.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.tree; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.not; + +public class TreeEnsembleModelTests extends ESTestCase { + + public void testMergePredictions() { + TreeEnsembleModel emptyModel = TreeEnsembleModel.modelBuilder(Collections.emptyMap()).build(); + + List predictions = Arrays.asList(0.1, 0.2, 0.3); + // prediction is the sum of the component predictions + assertEquals(0.6, emptyModel.mergePredictions(predictions), 0.00001); + } + + public void testDocToFeatureVector() { + Map featureMap = new HashMap<>(); + + featureMap.put("price", 0); + featureMap.put("number_of_wheels", 1); + featureMap.put("phone_charger", 2); + + Map document = new HashMap<>(); + document.put("price", 1000.0); + document.put("number_of_wheels", 4.0); + document.put("phone_charger", 0.0); + document.put("unrelated", 1.0); + + TreeEnsembleModel emptyModel = TreeEnsembleModel.modelBuilder(featureMap).build(); + List featureVector = 
emptyModel.docToFeatureVector(document); + assertEquals(Arrays.asList(1000.0, 4.0, 0.0), featureVector); + + + featureMap.put("missing_field", 3); + emptyModel = TreeEnsembleModel.modelBuilder(featureMap).build(); + featureVector = emptyModel.docToFeatureVector(document); + assertEquals(Arrays.asList(1000.0, 4.0, 0.0, null), featureVector); + } + + public void testTrace() { + int numFeatures = randomIntBetween(1, 4); + int numTrees = randomIntBetween(1, 5); + TreeEnsembleModel model = buildRandomModel(numFeatures, numTrees); + + Map doc = new HashMap<>(); + doc.put("a", randomDouble()); + doc.put("b", randomDouble()); + doc.put("c", randomDouble()); + doc.put("d", randomDouble()); + List> traces = model.trace(doc); + assertThat(traces, hasSize(numTrees)); + for (var trace : traces) { + assertThat(trace, not(empty())); + } + } + + public void testPredict() { + Map featureMap = new HashMap<>(); + featureMap.put("a", 0); + featureMap.put("b", 1); + + TreeEnsembleModel.ModelBuilder modelBuilder = TreeEnsembleModel.modelBuilder(featureMap); + + // Simple tree + Tree.TreeBuilder treeBuilder1 = Tree.TreeBuilder.newTreeBuilder(); + Tree.Node rootNode = treeBuilder1.addJunction(0, 0, true, 0.5); + treeBuilder1.addLeaf(rootNode.leftChild, 0.1); + treeBuilder1.addLeaf(rootNode.rightChild, 0.2); + + Tree.TreeBuilder treeBuilder2 = Tree.TreeBuilder.newTreeBuilder(); + Tree.Node rootNode2 = treeBuilder2.addJunction(0, 1, true, 0.5); + treeBuilder2.addLeaf(rootNode2.leftChild, 0.1); + treeBuilder2.addLeaf(rootNode2.rightChild, 0.2); + + modelBuilder.addTree(treeBuilder1.build()); + modelBuilder.addTree(treeBuilder2.build()); + + TreeEnsembleModel model = modelBuilder.build(); + + // this doc should result in prediction 0.1 for tree1 and 0.2 for tree2 + // the prediction is the sum of the scores (boosted tree) + Map doc = new HashMap<>(); + doc.put("a", 0.4); + doc.put("b", 0.7); + + double prediction = model.predict(doc); + assertEquals(0.3, prediction, 0.00001); + } + + private 
TreeEnsembleModel buildRandomModel(int numFeatures, int numTrees) { + Map featureMap = new HashMap<>(); + char fieldName = 'a'; + int index = 0; + for (int i=0; i childNodes = List.of(node.leftChild, node.rightChild); + + for (int i=0; i nextNodes = new ArrayList<>(); + for (int nodeId : childNodes) { + + if (i == depth -2) { + builder.addLeaf(nodeId, randomDecisionThreshold()); + } else { + Tree.Node childNode = builder.addJunction(nodeId, randomFeatureIndex(numFeatures), true, randomDecisionThreshold()); + nextNodes.add(childNode.leftChild); + nextNodes.add(childNode.rightChild); + } + } + + childNodes = nextNodes; + } + + return builder.build(); + } + + static int randomFeatureIndex(int max) { + return randomIntBetween(0, max -1); + } + + static double randomDecisionThreshold() { + return randomDouble(); + } + + public void testPredict() { + // Build a tree with 2 nodes and 3 leaves using 2 features + // The leaves have unique values 0.1, 0.2, 0.3 + Tree.TreeBuilder builder = Tree.TreeBuilder.newTreeBuilder(); + Tree.Node rootNode = builder.addJunction(0, 0, true, 0.5); + builder.addLeaf(rootNode.rightChild, 0.3); + Tree.Node leftChildNode = builder.addJunction(rootNode.leftChild, 1, true, 0.8); + builder.addLeaf(leftChildNode.leftChild, 0.1); + builder.addLeaf(leftChildNode.rightChild, 0.2); + + Tree tree = builder.build(); + + // This feature vector should hit the right child of the root node + List featureVector = Arrays.asList(0.6, 0.0); + assertEquals(0.3, tree.predict(featureVector), 0.00001); + + // This should hit the left child of the left child of the root node + // i.e. it takes the path left, left + featureVector = Arrays.asList(0.3, 0.7); + assertEquals(0.1, tree.predict(featureVector), 0.00001); + + // This should hit the right child of the left child of the root node + // i.e. 
it takes the path left, right + featureVector = Arrays.asList(0.3, 0.8); + assertEquals(0.2, tree.predict(featureVector), 0.00001); + } + + public void testTrace() { + int numFeatures = randomIntBetween(1, 6); + int depth = 6; + Tree tree = buildRandomTree(numFeatures, depth); + + List features = new ArrayList<>(numFeatures); + for (int i=0; i trace = tree.trace(features); + assertThat(trace, hasSize(depth)); + for (int i=0; i features = List.of(0.1); + assertEquals(leftChild, node.compare(features)); + + features = List.of(0.9); + assertEquals(rightChild, node.compare(features)); + } + + public void testCompare_nonDefaultOperator() { + int leftChild = 1; + int rightChild = 2; + Tree.Node node = new Tree.Node(leftChild, rightChild, 0, true, 0.5, (value, threshold) -> value >= threshold); + + List features = List.of(0.1); + assertEquals(rightChild, node.compare(features)); + features = List.of(0.5); + assertEquals(leftChild, node.compare(features)); + features = List.of(0.9); + assertEquals(leftChild, node.compare(features)); + + node = new Tree.Node(leftChild, rightChild, 0, true, 0.5, (value, threshold) -> value <= threshold); + + features = List.of(0.1); + assertEquals(leftChild, node.compare(features)); + features = List.of(0.5); + assertEquals(leftChild, node.compare(features)); + features = List.of(0.9); + assertEquals(rightChild, node.compare(features)); + } + + public void testCompare_missingFeature() { + int leftChild = 1; + int rightChild = 2; + Tree.Node leftBiasNode = new Tree.Node(leftChild, rightChild, 0, true, 0.5); + List features = new ArrayList<>(); + features.add(null); + assertEquals(leftChild, leftBiasNode.compare(features)); + + Tree.Node rightBiasNode = new Tree.Node(leftChild, rightChild, 0, false, 0.5); + assertEquals(rightChild, rightBiasNode.compare(features)); + } + + public void testIsLeaf() { + Tree.Node leaf = new Tree.Node(0.0); + assertTrue(leaf.isLeaf()); + + Tree.Node node = new Tree.Node(1, 2, 0, false, 0.0); + 
assertFalse(node.isLeaf()); + } + + public void testMissingNodes() { + Tree.TreeBuilder builder = Tree.TreeBuilder.newTreeBuilder(); + Tree.Node rootNode = builder.addJunction(0, 0, true, randomDecisionThreshold()); + + Tree.Node node2 = builder.addJunction(rootNode.rightChild, 0, false, 0.1); + builder.addLeaf(node2.leftChild, 0.1); + + List missingNodeIndexes = builder.build().missingNodes(); + assertEquals(Arrays.asList(1, 4), missingNodeIndexes); + } +} \ No newline at end of file diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java index 1e60434234527..950b10ca5d5e5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/tree/xgboost/XgBoostJsonParserTests.java @@ -58,6 +58,9 @@ public void testSimpleParse() throws IOException { double prediction = model.predict(doc); assertEquals(-6.23624468, prediction, 0.0001); + + List missingNodes = model.checkForNull(); + assertThat(missingNodes, hasSize(0)); } } @@ -307,8 +310,7 @@ public void testParse() throws IOException { assertEquals(2, model.numTrees()); List missingNodes = model.checkForNull(); - // TODO we are adding more nodes than are actually used -// assertThat(missingNodes, hasSize(0)); + assertThat(missingNodes, hasSize(0)); Map doc = new HashMap<>(); doc.put("f0", 10.36);