diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModel.java
new file mode 100644
index 0000000000000..7c748be9ee785
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModel.java
@@ -0,0 +1,155 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.AsyncModel;
+
+import org.elasticsearch.action.get.GetResponse;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.ingest.IngestDocument;
+import org.elasticsearch.xpack.ml.inference.Model;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.function.BiConsumer;
+
+/**
+ * A model (implements the inference function) that has its model state loaded
+ * from an index document via an {@link AsyncModelLoader}. When the AsyncModelLoader
+ * has fetched the document it notifies this class, and subclasses then know how to
+ * construct the model.
+ *
+ * Ingest documents arriving while the model state is still loading are queued. They are
+ * inferred once loading succeeds, or failed with the loading exception if it does not.
+ *
+ * {@link #createModel(GetResponse)} should be implemented in subclasses to read
+ * the model state from the GetResponse supplied by the loader.
+ *
+ * {@link #inferPrivate(IngestDocument, BiConsumer)} does the actual inference.
+ */
+public abstract class AsyncModel implements Model {
+
+    private final boolean ignoreMissing;
+
+    // Written only while holding this object's monitor; volatile so infer() can read them without locking.
+    private volatile boolean isLoaded = false;
+    private volatile Exception error;
+
+    // Requests received before loading finished. Only touched while holding this object's monitor.
+    private final Queue<Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>>> documentQueue;
+
+    protected AsyncModel(boolean ignoreMissing) {
+        this.ignoreMissing = ignoreMissing;
+        documentQueue = new ConcurrentLinkedDeque<>();
+    }
+
+    /**
+     * Runs inference if the model has loaded, fails the request if loading failed,
+     * and otherwise queues the request until the loader reports an outcome.
+     *
+     * @param document The ingest document
+     * @param handler Ingest handler
+     */
+    @Override
+    public void infer(IngestDocument document, BiConsumer<IngestDocument, Exception> handler) {
+        if (isLoaded) {
+            inferPrivate(document, handler);
+            return;
+        }
+
+        Exception loadError = error;
+        if (loadError != null) {
+            handler.accept(null, loadError);
+            return;
+        }
+
+        // Still loading. queueRequest re-checks the state under the lock; a false return means
+        // loading finished in the meantime, so dispatch again against the now-final state.
+        if (queueRequest(document, handler) == false) {
+            infer(document, handler);
+        }
+    }
+
+    /**
+     * Should be threadsafe
+     * @param document The ingest document
+     * @param handler Ingest handler
+     */
+    protected abstract void inferPrivate(IngestDocument document, BiConsumer<IngestDocument, Exception> handler);
+
+    /**
+     * Called by the loader once the model document has been fetched. Builds the model
+     * and then runs inference for every request queued while loading.
+     *
+     * @param getResponse The response holding the model state
+     */
+    void imLoaded(GetResponse getResponse) {
+        createModel(getResponse);
+        drainQueuedToInfer();
+    }
+
+    /**
+     * Called by the loader when the model could not be loaded. Queued requests and
+     * requests arriving later are completed with {@code exception}.
+     *
+     * @param exception The reason loading failed
+     */
+    void setError(Exception exception) {
+        drainQueuedToError(exception);
+    }
+
+    /**
+     * Queue the request unless loading has already finished.
+     *
+     * @return {@code false} if loading finished and the request was not queued
+     */
+    private synchronized boolean queueRequest(IngestDocument document, BiConsumer<IngestDocument, Exception> handler) {
+        if (isLoaded || error != null) {
+            return false;
+        }
+        documentQueue.add(new Tuple<>(document, handler));
+        return true;
+    }
+
+    private void drainQueuedToInfer() {
+        List<Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>>> queued;
+        synchronized (this) {
+            // Flip the flag and empty the queue in one step so nothing can be queued after the drain
+            isLoaded = true;
+            queued = removeAllQueued();
+        }
+        // Inference runs outside the lock so a slow model does not block infer() callers
+        for (Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>> request : queued) {
+            inferPrivate(request.v1(), request.v2());
+        }
+    }
+
+    private void drainQueuedToError(Exception exception) {
+        List<Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>>> queued;
+        synchronized (this) {
+            // Publish the error before notifying so queued handlers receive the actual exception
+            error = exception;
+            queued = removeAllQueued();
+        }
+        for (Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>> request : queued) {
+            request.v2().accept(null, exception);
+        }
+    }
+
+    // Caller must hold this object's monitor
+    private List<Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>>> removeAllQueued() {
+        List<Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>>> queued = new ArrayList<>(documentQueue);
+        documentQueue.clear();
+        return queued;
+    }
+
+    public boolean isIgnoreMissing() {
+        return ignoreMissing;
+    }
+
+    protected abstract void createModel(GetResponse getResponse);
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelLoader.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelLoader.java new file mode 100644 index 0000000000000..43b05ad238e0d --- /dev/null +++
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelLoader.java
@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.AsyncModel;
+
+import org.apache.log4j.LogManager;
+import org.apache.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.get.GetResponse;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.ingest.ConfigurationUtils;
+import org.elasticsearch.xpack.ml.inference.InferenceProcessor;
+import org.elasticsearch.xpack.ml.inference.ModelLoader;
+
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
+
+// The general idea is that this class can be used to load any type of model where the
+// model state has to be fetched from an index. TODO the class is poorly named
+//
+// load() creates a subclass of AsyncModel (type T), starts an asynchronous GET for the
+// index document and returns the model straight away. The model is notified once the
+// fetch has finished or an error occurred. Every call to load() creates its own model
+// and issues its own GET, so each model is notified about the request made on its behalf.
+//
+// NOTE(review): the logger comes from org.apache.log4j (Log4j 1.x API); confirm that is
+// intended rather than org.apache.logging.log4j.
+
+public abstract class AsyncModelLoader<T extends AsyncModel> implements ModelLoader {
+
+    private static final Logger logger = LogManager.getLogger(AsyncModelLoader.class);
+
+    public static final String INDEX = "index";
+
+    private final Client client;
+    private final Function<Boolean, T> modelSupplier;
+
+    // Set once a GET issued by this loader has completed, successfully or not
+    private final AtomicBoolean loadingFinished = new AtomicBoolean(false);
+    // Outcome of the most recently completed GET
+    private volatile GetResponse response;
+    private volatile Exception loadingException;
+
+    protected AsyncModelLoader(Client client, Function<Boolean, T> modelSupplier) {
+        this.client = client;
+        this.modelSupplier = modelSupplier;
+    }
+
+    /**
+     * Create a model and start loading its state. The returned model may not be loaded yet.
+     *
+     * @param modelId The model Id
+     * @param processorTag Tag
+     * @param ignoreMissing Passed to the model supplier
+     * @param config The processor config
+     * @return The model that will be notified when loading completes
+     */
+    @Override
+    public T load(String modelId, String processorTag, boolean ignoreMissing, Map<String, Object> config) {
+        String index = readIndexName(processorTag, config);
+        String documentId = documentId(modelId, config);
+
+        // The model is handed to the request's listener rather than kept in a field, so a
+        // second call to load() cannot redirect the notification meant for this model.
+        T model = modelSupplier.apply(ignoreMissing);
+        load(documentId, index, model);
+        return model;
+    }
+
+    @Override
+    public void consumeConfiguration(String processorTag, Map<String, Object> config) {
+        readIndexName(processorTag, config);
+    }
+
+    /**
+     * Read the name of the index to get the model state from.
+     * The default is to read the string value of object {@value #INDEX}.
+     *
+     * @param processorTag Tag
+     * @param config The processor config
+     * @return The name of the index containing the model
+     */
+    protected String readIndexName(String processorTag, Map<String, Object> config) {
+        return ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX);
+    }
+
+    /**
+     * Construct the document Id used in the GET request.
+     * This function is intended to be overridden, this implementation simply returns {@code modelId}
+     *
+     * @param modelId The model Id
+     * @param config The processor config
+     * @return The document Id
+     */
+    protected String documentId(String modelId, Map<String, Object> config) {
+        return modelId;
+    }
+
+    private void load(String id, String index, T model) {
+        ActionListener<GetResponse> listener = ActionListener.wrap(
+            getResponse -> setResponse(getResponse, model),
+            e -> setLoadingException(e, model));
+
+        // loadingFinished is only set by the callbacks, i.e. after the request has completed
+        client.prepareGet(index, null, id).execute(listener);
+    }
+
+    private void setResponse(GetResponse response, T model) {
+        this.response = response;
+        loadingFinished.set(true);
+        model.imLoaded(response);
+    }
+
+    private void setLoadingException(Exception e, T model) {
+        this.loadingException = e;
+        loadingFinished.set(true);
+        // Tell the model, otherwise documents it queued while loading are never completed
+        model.setError(e);
+    }
+
+    public GetResponse getGetResponse() {
+        return response;
+    }
+
+    public Exception getLoadingException() {
+        return loadingException;
+    }
+
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModel.java index c10a291ba3766..11bb17dbd09d6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModel.java @@ -6,9 +6,10 @@ package org.elasticsearch.xpack.ml.inference.sillymodel; +import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.common.Randomness; import org.elasticsearch.ingest.IngestDocument; -import org.elasticsearch.xpack.ml.inference.Model; +import org.elasticsearch.xpack.ml.inference.AsyncModel.AsyncModel; import java.util.Random; import java.util.function.BiConsumer; @@ -16,18 +17,25 @@ /** * Trivial model whose only purpose is to aid code design */ -public class SillyModel implements Model { +public class
SillyModel extends AsyncModel { private static final String TARGET_FIELD = "hotdog_or_not"; private final Random random; - public SillyModel() { + public SillyModel(boolean ignoreMissing) { + super(ignoreMissing); random = Randomness.get(); } - public void infer(IngestDocument document, BiConsumer handler) { + @Override + public void inferPrivate(IngestDocument document, BiConsumer handler) { document.setFieldValue(TARGET_FIELD, random.nextBoolean() ? "hotdog" : "not"); handler.accept(document, null); } + + @Override + protected void createModel(GetResponse getResponse) { + + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModelLoader.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModelLoader.java index 9cb9e575ae53a..8dc89908c3be8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModelLoader.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/sillymodel/SillyModelLoader.java @@ -6,68 +6,13 @@ package org.elasticsearch.xpack.ml.inference.sillymodel; -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.client.Client; -import org.elasticsearch.ingest.ConfigurationUtils; -import org.elasticsearch.xpack.ml.inference.InferenceProcessor; -import org.elasticsearch.xpack.ml.inference.Model; -import org.elasticsearch.xpack.ml.inference.ModelLoader; +import org.elasticsearch.xpack.ml.inference.AsyncModel.AsyncModelLoader; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; - -public class SillyModelLoader implements ModelLoader { - - public static final String MODEL_TYPE = "model_stored_in_index"; - - private static String INDEX = "index"; - - private final Client client; +public class SillyModelLoader extends 
AsyncModelLoader { + public static final String MODEL_TYPE = "silly"; public SillyModelLoader(Client client) { - this.client = client; - } - - @Override - public Model load(String modelId, String processorTag, boolean ignoreMissing, Map config) throws Exception { - CountDownLatch latch = new CountDownLatch(1); - AtomicReference model = new AtomicReference<>(); - AtomicReference exception = new AtomicReference<>(); - - LatchedActionListener listener = new LatchedActionListener<>( - ActionListener.wrap(model::set, exception::set), latch - ); - - String index = ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX); - - load(modelId, index, listener); - latch.await(); - if (exception.get() != null) { - throw exception.get(); - } - - return model.get(); - } - - @Override - public void consumeConfiguration(String processorTag, Map config) { - ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX); - } - - - private void load(String id, String index, ActionListener listener) { - client.prepareGet(index, null, id).execute(ActionListener.wrap( - response -> { - if (response.isExists()) { - listener.onResponse(new SillyModel()); - } else { - listener.onFailure(new ResourceNotFoundException("missing model [{}], [{}]", id, index)); - } - }, - listener::onFailure - )); + super(client, SillyModel::new); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelLoaderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelLoaderTests.java new file mode 100644 index 0000000000000..1f09e2d12da1b --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelLoaderTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.AsyncModel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequestBuilder; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.index.VersionType; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; +import java.util.function.Function; + +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AsyncModelLoaderTests extends ESTestCase { + + private static final String PROCESSOR_TAG = "model_processor"; + + private TestThreadPool threadPool; + private int sequentialDocumentId = 0; + + private class ConcreteLoader extends AsyncModelLoader { + + ConcreteLoader(Client client, Function modelSupplier) { + super(client, modelSupplier); + } + } + + @Before + public void setUp() { + threadPool = new TestThreadPool("AsyncModelLoaderTests"); + } + + @After + public void tearDown() { + terminate(threadPool); + } + + private class ConcreteModel extends AsyncModel { + + private List>> requests = new ArrayList<>(); + private AtomicLong createModelCount = new AtomicLong(); + + private ConcreteModel(boolean ignoreMissing) { + 
super(ignoreMissing); + } + + @Override + protected void inferPrivate(IngestDocument document, BiConsumer handler) { + requests.add(new Tuple<>(document, handler)); + } + + @Override + protected void createModel(GetResponse getResponse) { + createModelCount.incrementAndGet(); + } + } + + public void testSlowLoad() throws Exception { + + CountDownLatch latch = new CountDownLatch(1); + + Client client = mockLatchedClient(mock(GetResponse.class), latch); + ConcreteLoader loader = new ConcreteLoader(client, ConcreteModel::new); + + ConcreteModel loadingModel = loader.load("foo", PROCESSOR_TAG, false, config()); + + Handler handler = new Handler(); + + int preloadedCount = randomIntBetween(0, 10); + for (int i=0; i { + assertThat(handler.docs, hasSize(preloadedCount + 10)); + }); + } + + @SuppressWarnings("unchecked") + private Client mockClient(GetResponse response, TimeValue delay) { + + Client client = mock(Client.class); + GetRequestBuilder requestBuilder = mock(GetRequestBuilder.class); + when(client.prepareGet(any(), any(), any())).thenReturn(requestBuilder); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[0]; + threadPool.schedule(() -> listener.onResponse(response), delay, "same"); + return null; + }).when(requestBuilder).execute(any()); + + return client; + } + + @SuppressWarnings("unchecked") + private Client mockLatchedClient(GetResponse response, CountDownLatch latch) { + + Client client = mock(Client.class); + GetRequestBuilder requestBuilder = mock(GetRequestBuilder.class); + when(client.prepareGet(any(), any(), any())).thenReturn(requestBuilder); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[0]; + threadPool.generic().submit(() -> { + try { + latch.await(); + } catch (InterruptedException e) { + listener.onFailure(e); + } + listener.onResponse(response); + }); + + return null; + }).when(requestBuilder).execute(any()); + + return client; + } + + 
private IngestDocument ingestDocument() { // each call yields a document with a fresh sequential id
+        return new IngestDocument("index", "type", "id_" + sequentialDocumentId++, "route",
+            0L, VersionType.INTERNAL, Collections.emptyMap());
+    }
+
+    private Map<String, Object> config() { // minimal processor config: only the index name is set
+        Map<String, Object> mutableMap = new HashMap<>();
+        mutableMap.put(AsyncModelLoader.INDEX, "some_index");
+        return mutableMap;
+    }
+
+    private class Handler { // records every callback: the document on success, the exception otherwise
+        List<IngestDocument> docs = new ArrayList<>();
+        List<Exception> errors = new ArrayList<>();
+
+        void handle(IngestDocument doc, Exception e) {
+            if (doc != null) {
+                docs.add(doc);
+            } else {
+                errors.add(e);
+            }
+        }
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelTests.java
new file mode 100644
index 0000000000000..7000570760e31
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/AsyncModel/AsyncModelTests.java
@@ -0,0 +1,12 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.AsyncModel;
+
+import org.elasticsearch.test.ESTestCase;
+
+public class AsyncModelTests extends ESTestCase {
+}