diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index d5fc5ed438..25a179c5d0 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -30,11 +30,7 @@ from onnx import ModelProto from deepsparse.log import get_main_logger -from deepsparse.utils.onnx import ( - _MODEL_DIR_ONNX_NAME, - model_to_path, - truncate_onnx_model, -) +from deepsparse.utils.onnx import MODEL_ONNX_NAME, model_to_path, truncate_onnx_model from sparsezoo.utils import save_onnx @@ -55,6 +51,7 @@ def setup_transformers_pipeline( sequence_length: int, tokenizer_padding_side: str = "left", engine_kwargs: Optional[Dict] = None, + onnx_model_name: Optional[str] = None, ) -> Tuple[ str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer, Dict[str, Any] ]: @@ -66,30 +63,27 @@ def setup_transformers_pipeline( :param tokenizer_padding_side: The side to pad on for the tokenizer, either "left" or "right" :param engine_kwargs: The kwargs to pass to the engine + :param onnx_model_name: The name of the onnx model to be loaded. + If not specified, defaults are used (see setup_onnx_file_path) :return The model path, config, tokenizer, and engine kwargs """ - model_path, config, tokenizer = fetch_onnx_file_path(model_path, sequence_length) + model_path, config, tokenizer = setup_onnx_file_path( + model_path, sequence_length, onnx_model_name + ) tokenizer.padding_side = tokenizer_padding_side if not tokenizer.pad_token: tokenizer.pad_token = tokenizer.eos_token engine_kwargs = engine_kwargs or {} - if engine_kwargs.get("model_path"): - raise ValueError( - "The engine kwargs already specify " - f"a model path: {engine_kwargs['model_path']}, " - f"but a model path was also provided: {model_path}. " - "Please only provide one." - ) engine_kwargs["model_path"] = model_path return model_path, config, tokenizer, engine_kwargs -def fetch_onnx_file_path( +def setup_onnx_file_path( model_path: str, sequence_length: int, - task: Optional[str] = None, + onnx_model_name: Optional[str] = None, ) -> Tuple[str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer]: """ Parses ONNX model from the `model_path` provided. It additionally @@ -97,17 +91,18 @@ def fetch_onnx_file_path( derived from the `model_path` provided. :param model_path: path to the model to be parsed :param sequence_length: maximum sequence length of the model + :param onnx_model_name: optionally, the precise name of the ONNX model + of interest may be specified. 
If not specified, the default ONNX model + name will be used (refer to `get_deployment_path` for details) :return: file path to the processed ONNX file for the engine to compile """ - deployment_path, onnx_path = get_deployment_path(model_path) + deployment_path, onnx_path = get_deployment_path(model_path, onnx_model_name) hf_logger = logging.getLogger("transformers") hf_logger_level = hf_logger.level hf_logger.setLevel(logging.ERROR) - config = transformers.PretrainedConfig.from_pretrained( - deployment_path, finetuning_task=task - ) + config = transformers.PretrainedConfig.from_pretrained(deployment_path) hf_logger.setLevel(hf_logger_level) trust_remote_code = False @@ -126,7 +121,9 @@ def fetch_onnx_file_path( return onnx_path, config, tokenizer -def get_deployment_path(model_path: str) -> Tuple[str, str]: +def get_deployment_path( + model_path: str, onnx_model_name: Optional[str] = None +) -> Tuple[str, str]: """ Returns the path to the deployment directory for the given model path and the path to the mandatory @@ -135,9 +132,13 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: for running the transformers model in the deepsparse pipeline :param model_path: path to model directory, sparsezoo stub, or ONNX file + :param onnx_model_name: optionally, the precise name of the ONNX model + of interest may be specified. If not specified, the default ONNX model + name will be used. :return: path to the deployment directory and path to the ONNX file inside the deployment directory """ + onnx_model_name = onnx_model_name or MODEL_ONNX_NAME if os.path.isfile(model_path): # return the parent directory of the ONNX file return os.path.dirname(model_path), model_path @@ -145,13 +146,13 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: if os.path.isdir(model_path): model_files = os.listdir(model_path) - if _MODEL_DIR_ONNX_NAME not in model_files: + if onnx_model_name not in model_files: raise ValueError( - f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory " + f"{onnx_model_name} not found in transformers model directory " f"{model_path}. 
Be sure that an export of the model is written to " - f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}" + f"{os.path.join(model_path, onnx_model_name)}" ) - return model_path, os.path.join(model_path, _MODEL_DIR_ONNX_NAME) + return model_path, os.path.join(model_path, onnx_model_name) elif model_path.startswith("zoo:") or model_path.startswith("hf:"): onnx_model_path = model_to_path(model_path) diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index ae0913ffd7..e4b41f3286 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -56,12 +56,12 @@ "has_model_kv_cache", "CACHE_INPUT_PREFIX", "CACHE_OUTPUT_PREFIX", - "_MODEL_DIR_ONNX_NAME", + "MODEL_ONNX_NAME", ] _LOGGER = logging.getLogger(__name__) -_MODEL_DIR_ONNX_NAME = "model.onnx" +MODEL_ONNX_NAME = "model.onnx" CACHE_INPUT_PREFIX = "past_key_values" CACHE_OUTPUT_PREFIX = "present" @@ -132,7 +132,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: model.deployment.path # default to the main onnx file for the model - model = model.deployment.get_file(_MODEL_DIR_ONNX_NAME).path + model = model.deployment.get_file(MODEL_ONNX_NAME).path elif File is not object and isinstance(model, File): # get the downloaded_path -- will auto download if not on local system @@ -143,10 +143,10 @@ def model_to_path(model: Union[str, Model, File]) -> str: from huggingface_hub import snapshot_download deployment_path = snapshot_download(repo_id=model.replace("hf:", "", 1)) - onnx_path = os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) + onnx_path = os.path.join(deployment_path, MODEL_ONNX_NAME) if not os.path.isfile(onnx_path): raise ValueError( - f"Could not find the ONNX model file '{_MODEL_DIR_ONNX_NAME}' in the " + f"Could not find the ONNX model file '{MODEL_ONNX_NAME}' in the " f"Hugging Face Hub repository located at {deployment_path}. Please " f"ensure the model has been correctly exported to ONNX format and " f"exists in the repository." @@ -161,7 +161,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: model_path = Path(model) if model_path.is_dir(): - return str(model_path / _MODEL_DIR_ONNX_NAME) + return str(model_path / MODEL_ONNX_NAME) return model diff --git a/src/deepsparse/v2/__init__.py b/src/deepsparse/v2/__init__.py new file mode 100644 index 0000000000..29fcd4126c --- /dev/null +++ b/src/deepsparse/v2/__init__.py @@ -0,0 +1,21 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .operators import * +from .pipeline import * +from .routers import * +from .schedulers import * +from .utils import * diff --git a/src/deepsparse/v2/image_classification/__init__.py b/src/deepsparse/v2/image_classification/__init__.py new file mode 100644 index 0000000000..8668227df7 --- /dev/null +++ b/src/deepsparse/v2/image_classification/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+from .postprocess_operator import *
+from .preprocess_operator import *
+
+
+from .pipeline import *  # isort:skip
diff --git a/src/deepsparse/v2/image_classification/pipeline.py b/src/deepsparse/v2/image_classification/pipeline.py
new file mode 100644
index 0000000000..3d7887a701
--- /dev/null
+++ b/src/deepsparse/v2/image_classification/pipeline.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import warnings
+from typing import Dict, Optional, Tuple, Union
+
+from deepsparse.v2.image_classification.postprocess_operator import (
+    ImageClassificationPostProcess,
+)
+from deepsparse.v2.image_classification.preprocess_operator import (
+    ImageClassificationPreProcess,
+)
+from deepsparse.v2.operators.engine_operator import EngineOperator
+from deepsparse.v2.pipeline import Pipeline
+from deepsparse.v2.routers.router import LinearRouter
+from deepsparse.v2.schedulers.scheduler import OperatorScheduler
+
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["ImageClassificationPipeline"]
+
+
+class ImageClassificationPipeline(Pipeline):
+    def __init__(
+        self,
+        model_path: str,
+        engine_kwargs: Optional[Dict] = None,
+        class_names: Union[None, str, Dict[str, str]] = None,
+        image_size: Optional[Tuple[int]] = None,
+        top_k: int = 1,
+    ):
+        if not engine_kwargs:
+            engine_kwargs = {}
+            engine_kwargs["model_path"] = model_path
+        elif engine_kwargs.get("model_path") != model_path:
+            warnings.warn(f"Updating engine_kwargs to include {model_path}")
+
+        engine = EngineOperator(**engine_kwargs)
+        preprocess = ImageClassificationPreProcess(
+            model_path=engine.model_path, image_size=image_size
+        )
+        postprocess = ImageClassificationPostProcess(
+            top_k=top_k, class_names=class_names
+        )
+
+        ops = [preprocess, engine, postprocess]
+        router = LinearRouter(end_route=len(ops))
+        scheduler = [OperatorScheduler()]
+        super().__init__(ops=ops, router=router, schedulers=scheduler)
diff --git a/src/deepsparse/v2/image_classification/postprocess_operator.py b/src/deepsparse/v2/image_classification/postprocess_operator.py
new file mode 100644
index 0000000000..9231113368
--- /dev/null
+++ b/src/deepsparse/v2/image_classification/postprocess_operator.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Dict, List, Union
+
+import numpy
+from pydantic import BaseModel, Field
+
+from deepsparse.v2.operators import Operator
+
+
+class ImageClassificationOutput(BaseModel):
+    """
+    Output model for image classification
+    """
+
+    labels: List[Union[int, str, List[int], List[str]]] = Field(
+        description="List of labels, one for each prediction"
+    )
+    scores: List[Union[float, List[float]]] = Field(
+        description="List of scores, one for each prediction"
+    )
+
+
+__all__ = ["ImageClassificationPostProcess"]
+
+
+class ImageClassificationPostProcess(Operator):
+    """
+    Image Classification post-processing Operator. This Operator is responsible for
+    processing outputs from the engine and returning the classification results to
+    the user, using the ImageClassificationOutput structure.
+    """
+
+    input_schema = None
+    output_schema = ImageClassificationOutput
+
+    def __init__(
+        self, top_k: int = 1, class_names: Union[None, str, Dict[str, str]] = None
+    ):
+        self.top_k = top_k
+        if isinstance(class_names, str) and class_names.endswith(".json"):
+            self._class_names = json.load(open(class_names))
+        elif isinstance(class_names, dict):
+            self._class_names = class_names
+        else:
+            self._class_names = None
+
+    def run(self, inp: "EngineOperatorOutputs", **kwargs) -> Dict:  # noqa: F821
+        labels, scores = [], []
+        inp = inp.engine_outputs
+        for prediction_batch in inp[0]:
+            label = (-prediction_batch).argsort()[: self.top_k]
+            score = prediction_batch[label]
+            labels.append(label)
+            scores.append(score.tolist())
+
+        if self._class_names is not None:
+            labels = numpy.vectorize(self._class_names.__getitem__)(labels)
+            labels = labels.tolist()
+
+        if isinstance(labels[0], numpy.ndarray):
+            labels = [label.tolist() for label in labels]
+
+        if len(labels) == 1:
+            labels = labels[0]
+            scores = scores[0]
+
+        return {"scores": scores, "labels": labels}
diff --git a/src/deepsparse/v2/image_classification/preprocess_operator.py b/src/deepsparse/v2/image_classification/preprocess_operator.py
new file mode 100644
index 0000000000..9b4517a44c
--- /dev/null
+++ b/src/deepsparse/v2/image_classification/preprocess_operator.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
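
For orientation, a minimal usage sketch of the new v2 image classification flow wired together by the files above and below. The model path and image file are placeholder assumptions, not part of this patch; kwargs passed to the pipeline are validated against ImageClassificationInput by the preprocess operator:

    from deepsparse.v2.image_classification import ImageClassificationPipeline

    pipeline = ImageClassificationPipeline(model_path="model.onnx", top_k=3)
    output = pipeline(images=["sample.jpg"])  # -> ImageClassificationOutput
    print(output.labels, output.scores)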
+
+from typing import Dict, List, Optional, Tuple
+
+import numpy
+import onnx
+from PIL import Image
+from torchvision import transforms
+
+from deepsparse.image_classification.constants import (
+    IMAGENET_RGB_MEANS,
+    IMAGENET_RGB_STDS,
+)
+from deepsparse.pipelines.computer_vision import ComputerVisionSchema
+from deepsparse.v2.operators import Operator
+
+
+class ImageClassificationInput(ComputerVisionSchema):
+    """
+    Input model for image classification
+    """
+
+
+__all__ = ["ImageClassificationPreProcess"]
+
+
+class ImageClassificationPreProcess(Operator):
+    """
+    Image Classification pre-processing operator. This Operator is expected to process
+    the user inputs and prepare them for the engine. Inputs to this Operator are
+    expected to follow the ImageClassificationInput schema.
+    """
+
+    input_schema = ImageClassificationInput
+    output_schema = None
+
+    def __init__(self, model_path: str, image_size: Optional[Tuple[int]] = None):
+        self.model_path = model_path
+        self._image_size = image_size or self._infer_image_size()
+        non_rand_resize_scale = 256.0 / 224.0  # standard used
+        self._pre_normalization_transforms = transforms.Compose(
+            [
+                transforms.Resize(
+                    tuple(
+                        [
+                            round(non_rand_resize_scale * size)
+                            for size in self._image_size
+                        ]
+                    )
+                ),
+                transforms.CenterCrop(self._image_size),
+            ]
+        )
+
+    def run(self, inp: ImageClassificationInput, **kwargs) -> Dict:
+        """
+        Pre-Process the Inputs for DeepSparse Engine
+
+        :param inp: inputs following the ImageClassificationInput schema
+        :return: list of preprocessed numpy arrays
+        """
+
+        if isinstance(inp.images, numpy.ndarray):
+            image_batch = inp.images
+        else:
+            if isinstance(inp.images, str):
+                inp.images = [inp.images]
+
+            image_batch = list(map(self._preprocess_image, inp.images))
+
+            # build batch
+            image_batch = numpy.stack(image_batch, axis=0)
+
+        original_dtype = image_batch.dtype
+        image_batch = numpy.ascontiguousarray(image_batch, dtype=numpy.float32)
+
+        if original_dtype == numpy.uint8:
+            image_batch /= 255
+            # normalize entire batch
+            image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1))
+            image_batch /= numpy.asarray(IMAGENET_RGB_STDS).reshape((-1, 3, 1, 1))
+
+        return {"engine_inputs": [image_batch]}
+
+    def _preprocess_image(self, image) -> numpy.ndarray:
+        if isinstance(image, List):
+            # image given as raw list
+            image = numpy.asarray(image)
+            if image.dtype == numpy.float32:
+                # image is already processed, return it as-is
+                return image
+            # assume raw image input
+            # put image in PIL format for torchvision processing
+            image = image.astype(numpy.uint8)
+            if image.shape[0] < image.shape[-1]:
+                # put channel last
+                image = numpy.einsum("cwh->whc", image)
+            image = Image.fromarray(image)
+        elif isinstance(image, str):
+            # load image from string filepath
+            image = Image.open(image).convert("RGB")
+        elif isinstance(image, numpy.ndarray):
+            image = image.astype(numpy.uint8)
+            if image.shape[0] < image.shape[-1]:
+                # put channel last
+                image = numpy.einsum("cwh->whc", image)
+            image = Image.fromarray(image)
+
+        if not isinstance(image, Image.Image):
+            raise ValueError(
+                f"inputs to {self.__class__.__name__} must be a string image "
+                "file path(s), a list representing a raw image, "
+                "PIL.Image.Image object(s), or a numpy array representing "
+                f"the entire pre-processed batch. Found {type(image)}"
+            )
+
+        # apply resize and center crop
+        image = self._pre_normalization_transforms(image)
+        image_numpy = numpy.array(image)
+        image.close()
+
+        # make channel first dimension
+        image_numpy = image_numpy.transpose(2, 0, 1)
+        return image_numpy
+
+    def _infer_image_size(self) -> Tuple[int, ...]:
+        """
+        Infer and return the expected shape of the input tensor
+
+        :return: The expected shape of the input tensor from onnx graph
+        """
+        onnx_model = onnx.load(self.model_path)
+        input_tensor = onnx_model.graph.input[0]
+        return (
+            input_tensor.type.tensor_type.shape.dim[2].dim_value,
+            input_tensor.type.tensor_type.shape.dim[3].dim_value,
+        )
diff --git a/src/deepsparse/v2/operators/__init__.py b/src/deepsparse/v2/operators/__init__.py
new file mode 100644
index 0000000000..9d1a9812ac
--- /dev/null
+++ b/src/deepsparse/v2/operators/__init__.py
@@ -0,0 +1,16 @@
+# flake8: noqa
+
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .operator import *
diff --git a/src/deepsparse/v2/operators/engine_operator.py b/src/deepsparse/v2/operators/engine_operator.py
new file mode 100644
index 0000000000..c2fc562c63
--- /dev/null
+++ b/src/deepsparse/v2/operators/engine_operator.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from copy import deepcopy +from typing import Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +from deepsparse import Context as EngineContext +from deepsparse import Engine, MultiModelEngine, Scheduler +from deepsparse.benchmark import ORTEngine +from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs +from deepsparse.v2.operators import Operator + + +DEEPSPARSE_ENGINE = "deepsparse" +ORT_ENGINE = "onnxruntime" + +SUPPORTED_PIPELINE_ENGINES = [DEEPSPARSE_ENGINE, ORT_ENGINE] + +__all__ = ["EngineOperator"] + + +class EngineOperatorInputs(BaseModel): + engine_inputs: List = Field(description="engine_inputs") + engine: Optional[Engine] = Field( + description="override the engine to run forward pass with", + default=None, + ) + + class Config: + arbitrary_types_allowed = True + + +class EngineOperatorOutputs(BaseModel): + engine_outputs: List = Field(description="engine outputs") + + +class EngineOperator(Operator): + input_schema = EngineOperatorInputs + output_schema = EngineOperatorOutputs + + def __init__( + self, + model_path: str, + engine_type: str = DEEPSPARSE_ENGINE, + num_cores: int = None, + num_streams: int = None, + scheduler: Scheduler = None, + input_shapes: List[List[int]] = None, + engine_context: Optional[EngineContext] = None, + engine_kwargs: Dict = None, + ): + self.model_path = model_to_path(model_path) + self._batch_size = 1 + self.engine_context = engine_context + + if self.engine_context is not None: + num_cores = num_cores or self.engine_context.num_cores + if self.engine_context.num_cores != num_cores: + raise ValueError( + f"num_cores mismatch. Expected {self.engine_context.num_cores} " + f"from passed context, but got {num_cores} while " + f"instantiating Pipeline" + ) + + engine_args = dict( + batch_size=self._batch_size, + num_cores=num_cores, + input_shapes=input_shapes, + ) + if engine_type.lower() == DEEPSPARSE_ENGINE: + engine_args["scheduler"] = scheduler + engine_args["num_streams"] = num_streams + + self._engine_args = engine_args + self._engine_type = engine_type + + if not engine_kwargs: + engine_kwargs = {} + + self.engine = self.create_engine(**engine_kwargs) + + @property + def batch_size(self) -> int: + """ + :return: the batch size this engine operator is compiled at + """ + return self._batch_size + + def create_engine( + self, + **kwargs, + ) -> Union[Engine, MultiModelEngine, ORTEngine]: + """ + Create an inference engine for a given ONNX model + + :param kwargs: overrides to engine_args used as kwargs for engine + constructor/compilation + :return: inference engine + """ + onnx_file_path = self.model_path + engine_args = deepcopy(self._engine_args) + engine_args.update(kwargs) + engine_type = self._engine_type.lower() + + if engine_type == DEEPSPARSE_ENGINE: + if self.engine_context is not None and isinstance( + self.engine_context, EngineContext + ): + engine_args.pop("num_cores", None) + engine_args.pop("scheduler", None) + engine_args.pop("num_streams", None) + engine_args["context"] = self.engine_context + return MultiModelEngine( + model=onnx_file_path, + **engine_args, + ) + engine_args.pop("cache_output_bools", None) + return Engine(onnx_file_path, **engine_args) + + if engine_type == ORT_ENGINE: + return ORTEngine(onnx_file_path, **engine_args) + + raise ValueError( + f"Unknown engine_type {engine_type}. 
Supported values include: "
+            f"{SUPPORTED_PIPELINE_ENGINES}"
+        )
+
+    def run(self, inp: EngineOperatorInputs, **kwargs) -> Dict:
+        if inp.engine:
+            # run with custom engine, do not split/join since custom engine
+            # may run at any batch size, returning here as code below has a
+            # planned refactor
+            engine_outputs = inp.engine(inp.engine_inputs)
+            return {"engine_outputs": engine_outputs}
+        inp = inp.engine_inputs
+        batches, orig_batch_size = self.expand_inputs(engine_inputs=inp)
+        batches_outputs = list(map(self.engine, batches))
+        engine_outputs = self.condense_inputs(
+            batch_outputs=batches_outputs, orig_batch_size=orig_batch_size
+        )
+        return {"engine_outputs": engine_outputs}
+
+    def expand_inputs(self, **kwargs):
+        return split_engine_inputs(kwargs["engine_inputs"], self._batch_size)
+
+    def condense_inputs(self, **kwargs):
+        batch_outputs = kwargs["batch_outputs"]
+        orig_batch_size = kwargs["orig_batch_size"]
+        return join_engine_outputs(batch_outputs, orig_batch_size)
diff --git a/src/deepsparse/v2/operators/operator.py b/src/deepsparse/v2/operators/operator.py
new file mode 100644
index 0000000000..b3963d8223
--- /dev/null
+++ b/src/deepsparse/v2/operators/operator.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+from deepsparse.v2.utils import InferenceState, PipelineState
+
+
+__all__ = ["Operator"]
+
+
+class Operator(ABC):
+    """
+    Base operator class - an operator should be defined for each atomic, functional
+    part of the pipeline.
+    """
+
+    # expected structured input and output types, to be defined by child classes
+    input_schema: Optional[Type[BaseModel]] = None
+    output_schema: Optional[Type[BaseModel]] = None
+
+    @classmethod
+    def has_input_schema(cls) -> bool:
+        """
+        :return: True if this class has a defined pydantic input schema
+        """
+        if not cls.input_schema:
+            return False
+
+        return issubclass(cls.input_schema, BaseModel)
+
+    @classmethod
+    def has_output_schema(cls) -> bool:
+        """
+        :return: True if this class has a defined pydantic output schema
+        """
+        if not cls.output_schema:
+            return False
+
+        return issubclass(cls.output_schema, BaseModel)
+
+    def __call__(
+        self,
+        *args,
+        inference_state: InferenceState,
+        pipeline_state: PipelineState,
+        **kwargs,
+    ) -> Any:
+        """
+        Parses inputs to this Operator and runs the run() method of this operator
+
+        :param args: an unnamed arg may only be provided if it is of the type of the
+            input_schema
+        :param inference_state: inference_state for the pipeline.
+        :param pipeline_state: pipeline_state for the pipeline. The values in the state
+            are created during pipeline creation and are read-only during inference.
+        :param kwargs: kwargs when not initializing from an instantiated schema
+        :return: operator output
+        """
+        if self.has_input_schema():
+            if len(args) > 1:
+                raise ValueError(
+                    f"The operator requires an {self.input_schema}. Too many arguments "
+                    "provided."
+                )
+            elif args and isinstance(args[0], self.input_schema):
+                inference_input = args[0]
+            elif kwargs:
+                inference_input = self.input_schema(**kwargs)
+            else:
+                raise ValueError(
+                    "Can't resolve inputs. The values for the schema must be provided "
+                    "in the form of a dictionary or an instance of the input_schema "
+                    "object"
+                )
+            run_output = self.run(
+                inference_input,
+                inference_state=inference_state,
+                pipeline_state=pipeline_state,
+            )
+        else:
+            run_output = self.run(
+                *args,
+                inference_state=inference_state,
+                pipeline_state=pipeline_state,
+                **kwargs,
+            )
+
+        if self.has_output_schema():
+            return self.output_schema(**run_output)
+        return run_output
+
+    @abstractmethod
+    def run(self, *args, **kwargs) -> Any:
+        """
+        :return: result of this operator as the defined output schema if applicable
+        """
+        raise NotImplementedError
+
+    def can_operate(self, inp: Any) -> bool:
+        """
+        Whether or not the given operator can run, based on input
+        """
+        return True
+
+    def expand_inputs(self, **kwargs):
+        """
+        Generic function to handle expanding values.
+        """
+        raise NotImplementedError
+
+    def condense_inputs(self, **kwargs):
+        """
+        Generic function to handle condensing values.
+        """
+        raise NotImplementedError
+
+    def yaml(self):
+        pass
+
+    def json(self):
+        pass
diff --git a/src/deepsparse/v2/pipeline.py b/src/deepsparse/v2/pipeline.py
new file mode 100644
index 0000000000..0a8c8b2f93
--- /dev/null
+++ b/src/deepsparse/v2/pipeline.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Dict, List, Union
+
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.routers import Router
+from deepsparse.v2.schedulers import OperatorScheduler, SchedulerGroup
+from deepsparse.v2.utils import InferenceState, PipelineState
+
+
+__all__ = ["Pipeline"]
+
+
+class Pipeline(Operator):
+    """
+    Pipeline accepts a series of operators, schedulers, and a router. Calling a pipeline
+    will use the router to run through all the defined operators. The operators should
+    be implemented using the Operator class and each implemented operator should be
+    responsible for a functional component of the pipeline. The flow of inputs/outputs
+    between the operators and the steps in the pipeline should be defined by the router
+    (based on the Router class), which dictates the next operator in the pipeline.
+    Execution of the operators will be handled by the provided schedulers.
+
+    :param ops: Operators to run within the pipeline. Can either be a list of operators
+        or dictionary of operators.
+    :param router: A Router which dictates the next operator to call.
+    :param schedulers: A list of schedulers to run operators.
+    :param pipeline_state: pipeline_state created during pipeline initialization
+
+    """
+
+    def __init__(
+        self,
+        ops: Union[Dict[str, Operator], List[Operator]],
+        router: Router,
+        schedulers: List[OperatorScheduler],
+        pipeline_state: PipelineState = None,
+    ):
+
+        self.ops = ops
+        self.router = router
+        self.schedulers = schedulers
+        self.pipeline_state = pipeline_state
+        self.validate()
+
+        # SchedulerGroup handles running all schedulers in order of priority
+        self._scheduler_group = SchedulerGroup(self.schedulers)
+
+    def run(
+        self,
+        *args,
+        inference_state: InferenceState,
+        pipeline_state: PipelineState,
+        **kwargs,
+    ):
+        """
+        Run through the operators using the provided router and scheduler.
+        The input to a given operator is the output of the previous operator.
+
+        :param inference_state: inference_state for the pipeline.
+        :param pipeline_state: pipeline_state for the pipeline. The values in the state
+            are created during pipeline creation and are read-only during inference.
+        """
+        next_step = self.router.START_ROUTE
+        operator_output = None
+
+        while next_step != self.router.END_ROUTE:
+            # Either a dictionary key or valid index
+            operator = self.ops[next_step]
+            if next_step == self.router.START_ROUTE:
+                output_future = self._scheduler_group.submit(
+                    *args,
+                    inference_state=inference_state,
+                    operator=operator,
+                    pipeline_state=pipeline_state,
+                    **kwargs,
+                )
+            else:
+                if isinstance(operator_output, dict):
+                    output_future = self._scheduler_group.submit(
+                        inference_state=inference_state,
+                        operator=operator,
+                        pipeline_state=pipeline_state,
+                        **operator_output,
+                    )
+                else:
+                    output_future = self._scheduler_group.submit(
+                        operator_output,
+                        inference_state=inference_state,
+                        pipeline_state=pipeline_state,
+                        operator=operator,
+                    )
+
+            operator_output = output_future.result()
+            if isinstance(operator_output, tuple):
+                state_update = operator_output[-1]
+                operator_output = operator_output[0]
+                inference_state.update_state(state_update)
+
+            next_step = self.router.next(next_step, self.ops, operator_output)
+
+        return operator_output
+
+    def __call__(self, *args, **kwargs):
+        """
+        Consolidate any provided inference_state or pipeline_state objects and pass
+        any other operator inputs to run().
+
+        :return: output of the pipeline operators ran with the router for the given
+            input
+        """
+        if kwargs.get("inference_state"):
+            inference_state = kwargs.pop("inference_state")
+        else:
+            inference_state = InferenceState()
+            inference_state.create_state({})
+
+        if "pipeline_state" in kwargs:
+            self.pipeline_state = kwargs.get("pipeline_state")
+
+        kwargs["inference_state"] = inference_state
+        kwargs["pipeline_state"] = self.pipeline_state
+
+        return self.run(*args, **kwargs)
+
+    def validate(self):
+        """
+        Validate the compatibility of the router and operators provided.
+        """
+        router_validation = self.router.validate(self.ops)
+
+        if router_validation is False:
+            # default error message
+            op_types = [type(op) for op in self.ops]
+            raise ValueError(f"Invalid Router: {type(self.router)} for ops: {op_types}")
+        elif isinstance(router_validation, str):
+            raise ValueError(f"Invalid Router for operators: {router_validation}")
diff --git a/src/deepsparse/v2/routers/__init__.py b/src/deepsparse/v2/routers/__init__.py
new file mode 100644
index 0000000000..8718bedeb4
--- /dev/null
+++ b/src/deepsparse/v2/routers/__init__.py
@@ -0,0 +1,17 @@
+# flake8: noqa
+
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .router import *
diff --git a/src/deepsparse/v2/routers/router.py b/src/deepsparse/v2/routers/router.py
new file mode 100644
index 0000000000..d1110d4ca7
--- /dev/null
+++ b/src/deepsparse/v2/routers/router.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional, Union
+
+from deepsparse.v2.operators import Operator
+
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["Router", "LinearRouter", "GraphRouter"]
+
+
+class Router:
+    """
+    Routers dictate the next operator to run. Each Router must implement a next function,
+    which dictates the index or key of the next operator to run.
+
+    :param start_route: the start index or key of the router
+    :param end_route: the end index or key of the router
+    :param route: the route that the router has to traverse through
+
+    """
+
+    def __init__(
+        self,
+        end_route: Union[str, int],
+        start_route: Union[str, int],
+        route: Optional[Dict] = None,
+    ):
+        self.START_ROUTE = start_route
+        self.END_ROUTE = end_route
+        self.route = route
+
+    @abstractmethod
+    def next(
+        self,
+        past: Union[str, int],
+        ops: Optional[Union[List[Operator], Dict[str, Operator]]],
+        inp: Optional[Any],
+    ) -> Union[str, int]:
+        """
+        Determines the index or dictionary key for the next operator which should run.
+
+        :param past: the previous index or key. This should uniquely determine the next
+            operator to run
+        :param ops: list or dictionary of operators
+        :param inp: operator input
+        :returns: the next index or dictionary key for the next operator to run
+        """
+        raise NotImplementedError
+
+    def yaml(self):
+        pass
+
+    def json(self):
+        pass
+
+
+class LinearRouter(Router):
+    """
+    LinearRouter runs a list of Operators in sequential order. end_route should
+    be the length of the list and the start_route should be the start index.
+ """ + + def __init__(self, end_route: int, start_route: int = 0): + super().__init__(end_route=end_route, start_route=start_route) + + def next( + self, past: int, ops: Optional[List[Operator]] = None, inp: Optional[Any] = None + ) -> int: + new_index = past + 1 + if new_index < self.END_ROUTE: + return new_index + return self.END_ROUTE + + @staticmethod + def validate(operators: List[Operator]) -> bool: + """ + :param operators: operators that this Router could potentially run over + :return: True if this Router can run this series of operators. Base Router + runs any series of operators that is non empty and whose input and output + schemas align. If not valid, either False or an error string will be + returned + """ + if len(operators) < 1: + _LOGGER.info("No operators provided") + return False + + for idx in range(len(operators) - 1): + current_output_schema = operators[idx].output_schema + next_input_schema = operators[idx + 1].input_schema + + if current_output_schema is None or next_input_schema is None: + # if no input/output schema defined, assume operator can run + # without schema + continue + + if current_output_schema != next_input_schema: + _LOGGER.info( + f"Operator at idx {idx}: {type(operators[idx])} has invalid " + f"output schema {current_output_schema} for next operator " + f"{type(operators[idx + 1])} which requires {next_input_schema}" + ) + return False + return True + + +class GraphRouter(Router): + """ + Router for a DAG. Expects graphs be presented in the form of a dictionary, where + keys are the nodes of the graph and the values are the connected nodes. For + nodes with multiple ouput edges, all the nodes will be visited and the first node + where `can_operate` returns True will run. Paths should be deterministic. + """ + + def __init__(self, end_route: str, start_route: str, route: Dict): + super().__init__(end_route=end_route, start_route=start_route, route=route) + + def next( + self, + past: str, + ops: Dict[str, Operator], + inp: Any, + ) -> int: + node = past + if isinstance(self.route[node], str): + return self.route[node] + else: + for neighbour_node in self.route[node]: + neighbour_node_op = ops[neighbour_node] + if neighbour_node_op.can_operate(inp): + return neighbour_node + raise ValueError("Cannot operate on any of the nodes") + + @staticmethod + def validate(ops) -> bool: + pass diff --git a/src/deepsparse/v2/schedulers/__init__.py b/src/deepsparse/v2/schedulers/__init__.py new file mode 100644 index 0000000000..04c37077e1 --- /dev/null +++ b/src/deepsparse/v2/schedulers/__init__.py @@ -0,0 +1,18 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .scheduler import * +from .scheduler_group import * diff --git a/src/deepsparse/v2/schedulers/scheduler.py b/src/deepsparse/v2/schedulers/scheduler.py new file mode 100644 index 0000000000..78a58e3389 --- /dev/null +++ b/src/deepsparse/v2/schedulers/scheduler.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from concurrent.futures import Future, ThreadPoolExecutor + +from deepsparse.v2.operators import Operator + + +__all__ = ["OperatorScheduler"] + + +class OperatorScheduler: + """ + OperatorSchedulers should implement a `submit` function that asynchronously + runs an operator and its input and returns a Future. Priority of operators + to run and resources they are run on are deferred to specific OperatorScheduler + implementations + + Base OperatorScheduler behaves as a simple queue deferring to ThreadPoolExecutor + + :param max_workers: maximum number of threads to execute at once + """ + + def __init__(self, max_workers: int = 1): + self._threadpool = ThreadPoolExecutor(max_workers=max_workers) + + def submit( + self, + *args, + operator: Operator, + **kwargs, + ) -> Future: + """ + :param operator: operator to run + :return: future referencing the asynchronously run output of the operator + """ + return self._threadpool.submit( + operator, + *args, + **kwargs, + ) + + def can_process( + self, + *args, + operator: Operator, + **kwargs, + ) -> bool: + """ + :param operator: operator to check + :return: True if this Operator can process the given operator and input. + Base OperatorScheduler always returns True + """ + return True diff --git a/src/deepsparse/v2/schedulers/scheduler_group.py b/src/deepsparse/v2/schedulers/scheduler_group.py new file mode 100644 index 0000000000..40b5695f22 --- /dev/null +++ b/src/deepsparse/v2/schedulers/scheduler_group.py @@ -0,0 +1,77 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from concurrent.futures import Future +from typing import List + +from deepsparse.v2.operators import Operator +from deepsparse.v2.schedulers.scheduler import OperatorScheduler + + +__all__ = ["SchedulerGroup"] + + +class SchedulerGroup(OperatorScheduler): + """ + Wrapper for a series of schedulers. 
Runs submitted operators on the first + scheduler that can process a given input + + :param schedulers: list of schedulers to pass operators to + """ + + def __init__(self, schedulers: List[OperatorScheduler]): + self.schedulers = schedulers + + def submit( + self, + *args, + operator: Operator, + **kwargs, + ) -> Future: + """ + :param operator: operator to run + :return: future referencing the asynchronously run output of the operator + """ + for scheduler in self.schedulers: + if scheduler.can_process( + *args, + operator=operator, + **kwargs, + ): + return scheduler.submit( + *args, + operator=operator, + **kwargs, + ) + + def can_process( + self, + *args, + operator: Operator, + **kwargs, + ) -> bool: + """ + :param operator: operator to check + :return: True if this Operator can process the given operator and input. + SchedulerGroup always returns True + """ + return any( + scheduler.can_process( + *args, + operator=operator, + **kwargs, + ) + for scheduler in self.schedulers + ) diff --git a/src/deepsparse/v2/text_generation/__init__.py b/src/deepsparse/v2/text_generation/__init__.py new file mode 100644 index 0000000000..21cd7e2acd --- /dev/null +++ b/src/deepsparse/v2/text_generation/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa +from .autoregressive_preprocess_operator import * +from .compile_generated_tokens import * +from .compile_generations import * +from .compile_logits import * +from .generate_new_token import * +from .kv_cache_operator import * +from .multi_engine_prefill_operator import * +from .nl_engine_operator import * +from .prep_for_prefill import * +from .process_inputs import * +from .process_outputs import * + + +from .token_generator import * # isort:skip +from .prep_for_generation import * # isort:skip + +from .pipeline import * # isort:skip diff --git a/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py new file mode 100644 index 0000000000..6e97412e43 --- /dev/null +++ b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py @@ -0,0 +1,98 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Any + +import numpy + +from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import PipelineState + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["AutoRegressiveOperatorPreprocess"] + + +class AutoRegressiveOperatorPreprocess(Operator): + def __init__(self, sequence_length: int, prompt_sequence_length: int): + """ + Prepare the tokens for the single-token engine. This requires creating the + attention mask, positions, and causal mask. The output contains these three + arrays to be passed into the single-token engine. + """ + self.sequence_length = sequence_length + self.prompt_sequence_length = prompt_sequence_length + + _LOGGER.warn( + "This operator requires the PipelineState to be set-up with the " + "onnx_input_names_no_cache attribute set from the NLEngineOperator." + ) + + def can_operate(self, inp: Any) -> bool: + """ + Can run this Operator if the number of tokens left to process is greater than + 0 but less than the self.prompt_sequence_length. + """ + tokens = inp.get("tokens") + kv_cache = inp.get("kv_cache") + + if inp.get("in_generation"): + return True + + remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens + can_process = ( + remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length + ) + if can_process and inp.get("in_generation") is None: + return True + return False + + def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs): + kv_cache.set_capacity(self.sequence_length - 1) + + num_total_processed_tokens = kv_cache.total_num_processed_tokens + new_token = tokens[num_total_processed_tokens] + engine_input_names = pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ) + + # padding is added to left, so attention mask is 1s from the + # right up to the number of total tokens (prompt + generated) + attention_mask = numpy.zeros((1, self.sequence_length), dtype=numpy.int64) + num_attention_entries_to_unmask = min( + num_total_processed_tokens + 1, self.sequence_length + ) # cap by seq len + attention_mask[:, -num_attention_entries_to_unmask:] = 1 + positions = numpy.array([[num_total_processed_tokens]], dtype=numpy.int64) + input_ids = numpy.array([[new_token]]) + causal_mask = create_causal_mask(input_ids, attention_mask) + + engine_inputs_map = dict( + input_ids=input_ids, + attention_mask=attention_mask, + causal_mask=causal_mask, + positions=positions, + ) + + engine_inputs = [engine_inputs_map[name] for name in engine_input_names] + + return { + "engine_inputs": engine_inputs, + "kv_cache": kv_cache, + "tokens": tokens, + "in_generation": kwargs.get("in_generation"), + } diff --git a/src/deepsparse/v2/text_generation/compile_generated_tokens.py b/src/deepsparse/v2/text_generation/compile_generated_tokens.py new file mode 100644 index 0000000000..c87436ab3a --- /dev/null +++ b/src/deepsparse/v2/text_generation/compile_generated_tokens.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["CompileGeneratedTokens"] + + +class CompileGeneratedTokens(Operator): + def run( + self, + new_token, + logits, + finish_reason, + kv_cache, + tokens, + inference_state: InferenceState, + **kwargs, + ): + in_generation = True + + generated_tokens = inference_state.current_state.get("generated_tokens") + generated_logits = inference_state.current_state.get("generated_logits") + finished_reason = inference_state.current_state.get("finished_reason") + + generated_tokens.append(new_token) + generated_logits.append(logits) + finished_reason.append(finish_reason) + + if finish_reason is not None: + in_generation = False + + state_update = { # TODO: check if necessary + "finished_reason": finished_reason, + "generated_tokens": generated_tokens, + "generated_logits": generated_logits, + } + + output = { + "tokens": tokens, + "kv_cache": kv_cache, + "in_generation": in_generation, + } + return output, state_update diff --git a/src/deepsparse/v2/text_generation/compile_generations.py b/src/deepsparse/v2/text_generation/compile_generations.py new file mode 100644 index 0000000000..ed8297ac01 --- /dev/null +++ b/src/deepsparse/v2/text_generation/compile_generations.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
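
A note on the convention the text-generation operators here rely on: Pipeline.run() treats a tuple return value as (output, state_update) and merges the update into the shared InferenceState before routing the output onward. Schematically (names illustrative, mirroring CompileGeneratedTokens above):

    def run(self, tokens, kv_cache, inference_state: InferenceState, **kwargs):
        generated = inference_state.current_state.get("generated_tokens")
        # ... append the newly generated token to `generated` ...
        # first element is routed to the next operator; second updates state
        return {"tokens": tokens, "kv_cache": kv_cache}, {"generated_tokens": generated}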
+from typing import Any + +import numpy +from pydantic import BaseModel, Field + +from deepsparse.transformers.pipelines.text_generation import FinishReason +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["CompileGenerations", "CompileGenerationsOutput"] + + +class CompileGenerationsOutput(BaseModel): + generated_tokens: Any = Field(description="generated_tokens") + generated_logits: Any = Field(description="generated_logits") + finished_reason: Any = Field(description="finished_reason") + + +class CompileGenerations(Operator): + output_schema = CompileGenerationsOutput + + def can_operate(self, inp: Any): + if inp.get("in_generation") is False: + return True + return False + + def run(self, inference_state: InferenceState, **kwargs): + generated_tokens = inference_state.current_state.get("generated_tokens") + generated_logits = inference_state.current_state.get("generated_logits") + finished_reason = inference_state.current_state.get("finished_reason") + + if len(finished_reason) == 0: + finished_reason.append(FinishReason.LENGTH) + + generated_tokens = numpy.array([generated_tokens]) + generated_logits = numpy.concatenate(generated_logits, axis=1) + return { + "generated_tokens": generated_tokens, + "generated_logits": generated_logits, + "finished_reason": finished_reason, + } diff --git a/src/deepsparse/v2/text_generation/compile_logits.py b/src/deepsparse/v2/text_generation/compile_logits.py new file mode 100644 index 0000000000..21bd50e03e --- /dev/null +++ b/src/deepsparse/v2/text_generation/compile_logits.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["CompilePromptLogits"] + + +class CompilePromptLogits(Operator): + """ + Combine the prompt logits. Currently relying on the inference state to store the + prompt logits for each token or multi-token batch processed. This operator will + take prompt logits from each iteration run and update the inference state. 
+ """ + + def can_operate(self, inp: Any): + if inp.get("in_generation") is None: + return True + return False + + def run(self, logits, inference_state: InferenceState, **kwargs): + logit_type = "prompt_logits" + + if inference_state.current_state.get(logit_type) is not None: + current_logits = inference_state.current_state.get(logit_type).copy() + current_logits.append(logits) + else: + current_logits = [logits] + + state_update = {logit_type: current_logits} + return { + "kv_cache": kwargs.get("kv_cache"), + "tokens": kwargs.get("tokens"), + }, state_update diff --git a/src/deepsparse/v2/text_generation/generate_new_token.py b/src/deepsparse/v2/text_generation/generate_new_token.py new file mode 100644 index 0000000000..33ab546e39 --- /dev/null +++ b/src/deepsparse/v2/text_generation/generate_new_token.py @@ -0,0 +1,90 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Sequence, Union + +import transformers + +from deepsparse.transformers.pipelines.text_generation import FinishReason +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["GenerateNewTokenOperator"] + + +class GenerateNewTokenOperator(Operator): + def __init__( + self, tokenizer: transformers.PreTrainedTokenizerBase, force_max_tokens: bool + ): + self.force_max_tokens = force_max_tokens + self.tokenizer = tokenizer + + def can_operate(self, inp: Any): + if inp.get("in_generation"): + return True + return False + + def run(self, logits, kv_cache, inference_state: InferenceState, **kwargs): + token_generator = inference_state.current_state.get("token_generator") + token = token_generator.generate(logits=logits[0, -1, :]) + finish_reason = None + + callback = inference_state.current_state.get("callback") + stop = inference_state.current_state.get("stop") + + if token == self.tokenizer.eos_token_id and not self.force_max_tokens: + finish_reason = FinishReason.STOP + + if self._stop_token_generated(token, stop_tokens=stop): + print( + "Stop token %s generated. Stopping generation." + % self.tokenizer.decode(token) + ) + finish_reason = FinishReason.STOP + + if callback is not None and callback(token) is False: + print( + "callback %s returned False, stopping generation." 
+ % callback.__qualname__ + ) + finish_reason = FinishReason.CALLBACK + + max_tokens = inference_state.current_state.get("max_tokens") + if len(inference_state.current_state.get("generated_tokens")) + 1 == max_tokens: + finish_reason = inference_state.current_state.get("length_finish_reason") + + state_update = { + "token_generator": token_generator, + } + + new_generation = { + "logits": logits, + "new_token": token, + "finish_reason": finish_reason, + } + output = {"tokens": token_generator.tokens, "kv_cache": kv_cache} + output.update(new_generation) + return output, state_update + + def _stop_token_generated( + self, token, stop_tokens: Union[None, str, Sequence[str]] + ) -> bool: + if stop_tokens is None: + return False + + decoded_token = self.tokenizer.decode(token) + decoded_token = ( + decoded_token if decoded_token.isspace() else decoded_token.strip() + ) + return decoded_token in stop_tokens diff --git a/src/deepsparse/v2/text_generation/kv_cache_operator.py b/src/deepsparse/v2/text_generation/kv_cache_operator.py new file mode 100644 index 0000000000..0b232402b3 --- /dev/null +++ b/src/deepsparse/v2/text_generation/kv_cache_operator.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
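The _stop_token_generated helper above normalizes a decoded token before matching it against the user-provided stop sequences: pure-whitespace tokens are kept as-is, anything else is stripped. A minimal, self-contained sketch of that check, with a made-up stub standing in for a real transformers tokenizer:

from typing import Sequence, Union


class _StubTokenizer:
    # made-up stand-in: maps token ids straight to strings
    _vocab = {0: " stop", 1: "go", 2: "  "}

    def decode(self, token: int) -> str:
        return self._vocab[token]


def stop_token_generated(
    tokenizer, token: int, stop_tokens: Union[None, str, Sequence[str]]
) -> bool:
    if stop_tokens is None:
        return False
    decoded = tokenizer.decode(token)
    # keep pure-whitespace tokens as-is, strip everything else
    decoded = decoded if decoded.isspace() else decoded.strip()
    return decoded in stop_tokens


tok = _StubTokenizer()
assert stop_token_generated(tok, 0, ["stop"])  # " stop" -> "stop" matches
assert not stop_token_generated(tok, 1, ["stop"])  # "go" does not match
assert not stop_token_generated(tok, 2, ["stop"])  # whitespace token kept as-is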
+ +from typing import Any + +from pydantic import BaseModel, Field + +from deepsparse.transformers.utils import DecoderKVCache +from deepsparse.transformers.utils.helpers import ( + initialize_kv_cache_state, + prepends_bos_token, +) +from deepsparse.v2.operators import Operator + + +__all__ = ["KVCacheCreator"] + + +class KVCacheCreatorOutput(BaseModel): + kv_cache: Any = Field(description="KV Cache Created") # DecoderKVCache + + +class KVCacheCreatorInput(BaseModel): + cache_shape: Any = Field(description="shape") + kv_cache_data_type: Any = Field(description="data type") + output_names: Any = Field(description="output names") + + +class KVCacheCreator(Operator): + input_schema = KVCacheCreatorInput + output_schema = KVCacheCreatorOutput + + def __init__( + self, + tokenizer, + sequence_length: int, + prompt_sequence_length: int, + internal_kv_cache: bool, + ): + self.tokenizer = tokenizer + self.prompt_sequence_length = prompt_sequence_length + self.internal_kv_cache = internal_kv_cache + self.sequence_length = sequence_length + + def run(self, cache_shape, kv_cache_data_type: str, output_names: list, **kwargs): + kv_cache_state = initialize_kv_cache_state( + cache_shape=cache_shape, + kv_cache_data_type=kv_cache_data_type, + output_names=output_names, + length=self.sequence_length - self.prompt_sequence_length, + empty=bool(self.internal_kv_cache), + ) + + kv_cache = DecoderKVCache(self.internal_kv_cache) + kv_cache.setup( + state=kv_cache_state, + freeze_first_position=prepends_bos_token(self.tokenizer), + ) + return {"kv_cache": kv_cache} diff --git a/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py new file mode 100644 index 0000000000..9a885c2355 --- /dev/null +++ b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py @@ -0,0 +1,136 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from enum import Enum +from typing import Any + +import numpy + +from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import PipelineState + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["MultiEnginePrefill"] + + +class OnnxInputNames(Enum): + INPUT_IDS = "input_ids" + ATTN_MASK = "attention_mask" + CAUSAL_MASK = "causal_mask" + POSITIONS = "positions" + + +# NOTE: A possible clean-up could involve combining this Operator and the +# autoregressive_preprocess_operator + + +class MultiEnginePrefill(Operator): + def __init__(self, prompt_sequence_length, sequence_length): + """ + Prepare the tokens for the multi-token engine. This requires creating the + attention mask, positions, and causal mask. The output contains these three + arrays to be passed into the multi-token engine. 
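+
+        :param prompt_sequence_length: the number of prompt tokens processed by
+            the multi-token engine per forward pass
+        :param sequence_length: the maximum sequence length for the engine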
+ """ + self.prompt_sequence_length = prompt_sequence_length + self.sequence_length = sequence_length + self.cases = { + OnnxInputNames.ATTN_MASK.value: self._case_attn_mask, + OnnxInputNames.POSITIONS.value: self._case_positions, + } + _LOGGER.warn( + "This operator requires the PipelineState to be set-up with the " + "onnx_input_names_no_cache attribute set from the NLEngineOperator." + ) + + def can_operate(self, inp: Any): + """ + Can only run if the number of prompt tokens left to process is greater than + or equal to the self.prompt_sequence_length. + """ + kv_cache = inp.get("kv_cache") + tokens = inp.get("tokens") + + if len(tokens) < self.prompt_sequence_length: + return False + + if ( + len(tokens) - kv_cache.total_num_processed_tokens + >= self.prompt_sequence_length + ): + return True + return False + + def _case_attn_mask(self, num_total_processed_tokens: int): + # create an empty attention mask + engine_input = numpy.zeros((1, self.sequence_length), dtype=numpy.int64) + # calculate the number of entries in attention mask that should be set to 1 + num_attention_entries_to_unmask = min( + num_total_processed_tokens + self.prompt_sequence_length, + self.sequence_length, + ) + engine_input[:, -num_attention_entries_to_unmask:] = 1 + return engine_input + + def _case_positions(self, num_total_processed_tokens: int): + return ( + numpy.arange( + num_total_processed_tokens, + num_total_processed_tokens + self.prompt_sequence_length, + ) + .reshape(1, -1) + .astype(numpy.int64) + ) + + def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs): + kv_cache.set_capacity(self.sequence_length - self.prompt_sequence_length) + + onnx_input_names_no_cache = pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ) + + num_total_processed_tokens = kv_cache.total_num_processed_tokens + start = num_total_processed_tokens + end = start + self.prompt_sequence_length + token_batch = tokens[start:end] + + engine_inputs = [] + for name in onnx_input_names_no_cache: + if name == OnnxInputNames.INPUT_IDS.value: + engine_input = numpy.array([token_batch]) + elif ( + name == OnnxInputNames.ATTN_MASK.value + or name == OnnxInputNames.POSITIONS.value + ): + engine_input = self.cases[name](num_total_processed_tokens) + elif name == OnnxInputNames.CAUSAL_MASK.value: + continue + + engine_inputs.append(engine_input) + + if OnnxInputNames.CAUSAL_MASK.value in onnx_input_names_no_cache: + causal_mask = create_causal_mask( + input_ids=engine_inputs[0], + attention_mask=engine_inputs[1], + ) + engine_inputs.append(causal_mask) + + return { + "engine_inputs": engine_inputs, + "kv_cache": kv_cache, + "tokens": tokens, + } diff --git a/src/deepsparse/v2/text_generation/nl_engine_operator.py b/src/deepsparse/v2/text_generation/nl_engine_operator.py new file mode 100644 index 0000000000..0bd9098a40 --- /dev/null +++ b/src/deepsparse/v2/text_generation/nl_engine_operator.py @@ -0,0 +1,197 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +from typing import Any, List, Tuple + +from pydantic import BaseModel, Field + +from deepsparse.utils.onnx import ( + CACHE_INPUT_PREFIX, + overwrite_onnx_model_inputs_for_kv_cache_models, +) +from deepsparse.v2.operators.engine_operator import ( + DEEPSPARSE_ENGINE, + EngineOperator, + EngineOperatorInputs, +) + + +__all__ = ["NLEngineOperator"] + + +class NlEngineInput(BaseModel): + engine_inputs: List = Field(description="engine inputs") + kv_cache: Any = Field(description="kv_cache object") + tokens: List = Field(description="tokens") + in_generation: bool = Field(description="in_generation", default=None) + + +class NLEngineOperator(EngineOperator): + + """ + Operator for the NL Decoder Engine. This Operator inherits from the EngineOperator. + Specific updates to engine attributes are made through this operator, as well + as updating the kv_cache. This Operator is used for both the single-token and + multi-token case. + """ + + input_schema = NlEngineInput + output_schema = None + + def __init__( + self, + sequence_length: int, + input_ids_length: int, + internal_kv_cache: bool = False, + **kwargs, + ): + + self.kv_cache_data_type = None + ( + onnx_file_path, + output_indices_to_be_cached, + kv_cache_data_type, + ) = overwrite_onnx_model_inputs_for_kv_cache_models( + onnx_file_path=kwargs.get("model_path"), + batch_size=kwargs.get("batch_size", 1), + sequence_length=sequence_length, + input_ids_length=input_ids_length, + ) + + engine_kwargs = kwargs.get("engine_kwargs", {}) + if kwargs.get("engine_type", DEEPSPARSE_ENGINE) == DEEPSPARSE_ENGINE: + if "WAND_OPT_FLAGS" not in os.environ: + os.environ["WAND_OPT_FLAGS"] = "default,~pyramids" + + if any(output_indices_to_be_cached): + self.kv_cache_data_type = kv_cache_data_type + if ( + internal_kv_cache + and kwargs.get("engine_type", DEEPSPARSE_ENGINE) == DEEPSPARSE_ENGINE + ): + engine_kwargs["cached_outputs"] = output_indices_to_be_cached + + kwargs["engine_kwargs"] = engine_kwargs + kwargs["model_path"] = onnx_file_path + super().__init__(**kwargs) + + self.input_ids_length = input_ids_length + + def run(self, inp: NlEngineInput, **kwargs) -> Any: + engine_input = inp.engine_inputs + kv_cache = inp.kv_cache + + inputs = self._add_kv_cache_to_input(engine_input, kv_cache) + if bool(kv_cache.engine_internal_cache): + # conventionally, before dispatching + # inputs to the engine, we validate them + # if val_inp=True. However, in this case + # we want to pass the empty kv cache inputs + # (batch_size=0) to the engine. 
Therefore, + # we skip the validation + out = self.engine._eng_net.execute_list_out( + inputs, kv_cache.engine_internal_cache + ) + else: + # run the engine without the LIB.kv_cache object + out = ( + super() + .run(EngineOperatorInputs(engine_inputs=inputs), **kwargs) + .get("engine_outputs") + ) + + logits, *kv_cache_state = out + self._update_kv_cache( + kv_cache_state=kv_cache_state, + input_ids_len=self.input_ids_length, + kv_cache=kv_cache, + ) + + output = { + "logits": logits, + "kv_cache": kv_cache, + "tokens": inp.tokens, + "in_generation": inp.in_generation, + } + return output + + def _add_kv_cache_to_input(self, engine_input, kv_cache): + kv_cache_state = copy.copy(kv_cache.cached_inputs) + + for idx, input_name in enumerate(self.onnx_input_names_no_cache): + kv_cache_state[input_name] = engine_input[idx] + + new_inp = [kv_cache_state[name] for name in self.engine.input_names] + return new_inp + + def _update_kv_cache(self, kv_cache_state, input_ids_len, kv_cache): + if bool(kv_cache.engine_internal_cache): + kv_cache.total_num_processed_tokens += input_ids_len + return + + kv_cache_state = { + name: array + for name, array in zip(self.onnx_input_names_cached, kv_cache_state) + } + + kv_cache.update( + state=kv_cache_state, + input_ids_len=input_ids_len, + ) + + @property + def onnx_input_names_no_cache(self) -> List[str]: + """ + :return: The input names for the onnx model, excluding + the potential kv cache inputs + """ + return [ + name + for name in self.engine.input_names + if not name.startswith(CACHE_INPUT_PREFIX) + ] + + @property + def onnx_input_names_cached(self) -> List[str]: + """ + :return: The cached input names for the onnx model + """ + return [ + name + for name in self.engine.input_names + if name.startswith(CACHE_INPUT_PREFIX) + ] + + @property + def cache_shape(self) -> Tuple[int, int, int, int]: + """ + :return: The shape of the kv cache inputs + for the onnx model. The shape is + (batch_size, num_heads, sequence_length, hidden_size) + """ + cache_engine_input_index = next( + i + for i, name in enumerate(self.engine.input_names) + if CACHE_INPUT_PREFIX in name + ) + return self.engine.input_shapes[cache_engine_input_index] + + @property + def output_names(self) -> List[str]: + """ + :return: The output names for the onnx model + """ + return self.engine.output_names diff --git a/src/deepsparse/v2/text_generation/pipeline.py b/src/deepsparse/v2/text_generation/pipeline.py new file mode 100644 index 0000000000..fdb31f1c6c --- /dev/null +++ b/src/deepsparse/v2/text_generation/pipeline.py @@ -0,0 +1,180 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
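The onnx_input_names_no_cache and onnx_input_names_cached properties above partition the engine inputs purely by name prefix. A small illustration of that partitioning with made-up input names (the concrete prefix value below is an assumption; the real constant is imported from deepsparse.utils.onnx):

# hypothetical engine input names; the prefix value is an assumption
CACHE_INPUT_PREFIX = "past_key_values"

input_names = [
    "input_ids",
    "attention_mask",
    "causal_mask",
    "positions",
    "past_key_values.0.key",
    "past_key_values.0.value",
]

# mirrors onnx_input_names_no_cache / onnx_input_names_cached above
no_cache = [n for n in input_names if not n.startswith(CACHE_INPUT_PREFIX)]
cached = [n for n in input_names if n.startswith(CACHE_INPUT_PREFIX)]

assert no_cache == ["input_ids", "attention_mask", "causal_mask", "positions"]
assert cached == ["past_key_values.0.key", "past_key_values.0.value"]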
+ +from typing import Dict, Optional + +from deepsparse.transformers.helpers import setup_transformers_pipeline +from deepsparse.transformers.utils.helpers import process_generation_config +from deepsparse.v2.pipeline import Pipeline +from deepsparse.v2.routers import GraphRouter +from deepsparse.v2.schedulers import OperatorScheduler +from deepsparse.v2.text_generation import ( + AutoRegressiveOperatorPreprocess, + CompileGeneratedTokens, + CompileGenerations, + CompilePromptLogits, + GenerateNewTokenOperator, + KVCacheCreator, + MultiEnginePrefill, + NLEngineOperator, + PrepareforPrefill, + PrepareGeneration, + ProcessInputsTextGeneration, + ProcessOutputs, + TokenGeneratorOperator, +) +from deepsparse.v2.utils import PipelineState + + +class TextGenerationPipeline(Pipeline): + def __init__( + self, + model_path: str, + prompt_sequence_length: int = 16, + sequence_length: int = 1024, + internal_kv_cache: bool = True, + force_max_tokens: bool = False, + generation_config=None, + engine_kwargs: Optional[Dict] = None, + ): + ( + self.model_path, + self.config, + self.tokenizer, + engine_kwargs, + ) = setup_transformers_pipeline( + model_path, sequence_length, engine_kwargs=engine_kwargs + ) + + pipeline_state = PipelineState() + pipeline_state_vals = {} + + if internal_kv_cache and engine_kwargs.get("engine_type") == "onnxruntime": + internal_kv_cache = False + + single_engine_operator = NLEngineOperator( + sequence_length=sequence_length, + internal_kv_cache=internal_kv_cache, + input_ids_length=1, + **engine_kwargs, + ) + + multi_engine_operator = NLEngineOperator( + sequence_length=sequence_length, + internal_kv_cache=internal_kv_cache, + input_ids_length=prompt_sequence_length, + **engine_kwargs, + ) + + # NOTE: Currently using pipeline state. Can swap to simply pass in the + # attributes to the specific Operator that need them, as class attributes. + pipeline_state_vals[ + "onnx_input_names_no_cache" + ] = single_engine_operator.onnx_input_names_no_cache + pipeline_state_vals["cache_shape"] = single_engine_operator.cache_shape + pipeline_state_vals["output_names"] = single_engine_operator.output_names + pipeline_state_vals[ + "kv_cache_data_type" + ] = single_engine_operator.kv_cache_data_type + pipeline_state.create_state(pipeline_state_vals) + + process_inputs = ProcessInputsTextGeneration( + generation_config=process_generation_config(generation_config), + sequence_length=sequence_length, + tokenizer=self.tokenizer, + ) + + kv_cache_creator = KVCacheCreator( + sequence_length=sequence_length, + tokenizer=self.tokenizer, + prompt_sequence_length=prompt_sequence_length, + internal_kv_cache=internal_kv_cache, + ) + + # NOTE: Can also have the KVCacheCreator be initialized inside this Operator. + # Relies on pipeline state variables set-up above (can be swapped to be class + # attributes instead of using the state. 
+ engine_inputs_for_prefill = PrepareforPrefill(kv_cache_creator=kv_cache_creator) + + multi_engine_prefill = MultiEnginePrefill( + prompt_sequence_length=prompt_sequence_length, + sequence_length=sequence_length, + ) + compile_prompt_logits = CompilePromptLogits() + + autoregressive_preprocess = AutoRegressiveOperatorPreprocess( + sequence_length=sequence_length, + prompt_sequence_length=prompt_sequence_length, + ) + token_generator = TokenGeneratorOperator() + prep_for_generation = PrepareGeneration( + sequence_length=sequence_length, + prompt_sequence_length=prompt_sequence_length, + token_generator=token_generator, + ) + generate_new_token = GenerateNewTokenOperator( + tokenizer=self.tokenizer, force_max_tokens=force_max_tokens + ) + process_output = ProcessOutputs(tokenizer=self.tokenizer) + compile_generations = CompileGenerations() + compile_generated_tokens = CompileGeneratedTokens() + + ops = { + "process_input": process_inputs, + "single_engine": single_engine_operator, + "multi_engine": multi_engine_operator, + "kv_cache_creator": kv_cache_creator, + "prepare_prefill": engine_inputs_for_prefill, + "multi_engine_prefill": multi_engine_prefill, + "compile_logits": compile_prompt_logits, + "autoregressive_preprocess": autoregressive_preprocess, + "prep_for_generation": prep_for_generation, + "generate_new_token": generate_new_token, + "process_outputs": process_output, + "compile_generations": compile_generations, + "compile_generated_tokens": compile_generated_tokens, + } + + routes = { + "process_input": "prepare_prefill", + "prepare_prefill": ["multi_engine_prefill", "autoregressive_preprocess"], + "multi_engine_prefill": "multi_engine", + "multi_engine": "compile_logits", + "compile_logits": [ + "multi_engine_prefill", + "prep_for_generation", + "autoregressive_preprocess", + ], + "autoregressive_preprocess": "single_engine", + "single_engine": [ + "compile_logits", + "generate_new_token", + ], + "prep_for_generation": "autoregressive_preprocess", + "generate_new_token": "compile_generated_tokens", + "compile_generated_tokens": [ + "autoregressive_preprocess", + "compile_generations", + ], + "compile_generations": "process_outputs", + "process_outputs": "STOP", + } + + router = GraphRouter( + end_route="STOP", start_route="process_input", route=routes + ) + scheduler = [OperatorScheduler()] + super().__init__( + ops=ops, router=router, schedulers=scheduler, pipeline_state=pipeline_state + ) diff --git a/src/deepsparse/v2/text_generation/prep_for_generation.py b/src/deepsparse/v2/text_generation/prep_for_generation.py new file mode 100644 index 0000000000..544af43980 --- /dev/null +++ b/src/deepsparse/v2/text_generation/prep_for_generation.py @@ -0,0 +1,140 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
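The routes table above often maps one step to a list of possible successors (for example, compile_logits can hand off to the multi-token prefill, to generation prep, or back to autoregressive preprocessing). A rough sketch of the routing contract this relies on, assuming the router picks the first candidate whose can_operate accepts the current intermediate output (this is not the real GraphRouter implementation):

class _StubOperator:
    def __init__(self, accepts):
        self._accepts = accepts

    def can_operate(self, inp) -> bool:
        return self._accepts(inp)


def next_operator(current: str, routes: dict, ops: dict, intermediate) -> str:
    # when a step maps to a single name, follow it; when it maps to a list,
    # pick the first operator willing to accept the intermediate output
    candidates = routes[current]
    if isinstance(candidates, str):
        return candidates
    for name in candidates:
        if ops[name].can_operate(intermediate):
            return name
    raise RuntimeError(f"no operator can run after {current}")


ops = {
    "multi_engine_prefill": _StubOperator(lambda inp: not inp["prompt_done"]),
    "prep_for_generation": _StubOperator(lambda inp: inp["prompt_done"]),
    "autoregressive_preprocess": _StubOperator(lambda inp: False),
}
routes = {
    "compile_logits": [
        "multi_engine_prefill",
        "prep_for_generation",
        "autoregressive_preprocess",
    ]
}
assert (
    next_operator("compile_logits", routes, ops, {"prompt_done": True})
    == "prep_for_generation"
)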
+from typing import Any
+
+import numpy
+
+from deepsparse.transformers.pipelines.text_generation import FinishReason
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.text_generation import TokenGeneratorOperator
+from deepsparse.v2.utils import InferenceState
+
+
+__all__ = ["PrepareGeneration"]
+
+
+class PrepareGeneration(Operator):
+    def __init__(
+        self,
+        token_generator: TokenGeneratorOperator,
+        prompt_sequence_length: int,
+        sequence_length: int,
+    ):
+        self.prompt_sequence_length = prompt_sequence_length
+        self.sequence_length = sequence_length
+        self.token_generator_creator = token_generator
+
+    def can_operate(self, inp: Any):
+        kv_cache = inp.get("kv_cache")
+        tokens = inp.get("tokens")
+
+        # Generation can only start once every prompt token has been processed:
+        # at that point the number of processed tokens equals the number of
+        # prompt tokens, all prompt logits are accounted for, and the kv_cache
+        # has been updated for the single-token engine.
+        if len(tokens) == kv_cache.total_num_processed_tokens:
+            return True
+        return False
+
+    @staticmethod
+    def set_generated_length(
+        max_length: int,
+        prompt_tokens_length: int,
+        sequence_length: int,
+        prompt_sequence_length: int,
+        max_new_tokens: int,
+        finish_reason_choices: "FinishReason",  # noqa
+    ):
+        """
+        Determine the length of the generated tokens. The hard cap on the total
+        number of tokens is based on the sequence length. If max_length is
+        provided and is less than the sequence length, it will be used to cap
+        the total number of tokens generated. If it is not provided, the
+        max_new_tokens attribute will be used, also capped by the sequence length.
+
+        :param max_length: max_length attribute, provided as input during inference
+        :param prompt_tokens_length: the number of prompt tokens used as part of the
+            generated output
+        :param sequence_length: the sequence length used for the pipeline
+        :param prompt_sequence_length: the prompt sequence length used for the
+            pipeline
+        :param max_new_tokens: the max_new_tokens attribute, which may be provided
+            as part of the input during inference
+        :param finish_reason_choices: the FinishReason enum used to label why
+            generation stopped
+        :return: a tuple of the maximum number of tokens to generate and the
+            finish reason to report if that cap is reached
+        """
+        if max_length:
+            # if max_length provided, use that to cap total tokens generated
+            max_tokens = max_length
+            finish_reason = finish_reason_choices.LENGTH
+        else:
+            # if not provided, max tokens is based on max_new_tokens + prompt tokens
+            max_tokens = (
+                min(max_new_tokens, sequence_length - prompt_sequence_length)
+                + prompt_tokens_length
+            )
+            finish_reason = finish_reason_choices.MAX_NEW_TOKENS
+
+        # hard model/pipeline cap
+        return (
+            (sequence_length, finish_reason_choices.CAPACITY)
+            if sequence_length < max_tokens
+            else (max_tokens, finish_reason)
+        )
+
+    def run(
+        self, tokens: Any, kv_cache: Any, inference_state: InferenceState, **kwargs
+    ):
+        prompt_logits = inference_state.current_state.get("prompt_logits")
+        prompt_logits = numpy.concatenate(prompt_logits, axis=1)
+        # TODO: clean this up so that we don't have to keep writing current_state
+        # everywhere
+
+        generation_config = inference_state.current_state.get("generation_config")
+        include_prompt_logits = inference_state.current_state.get(
+            "include_prompt_logits"
+        )
+
+        token_generator_creator_output = self.token_generator_creator.run(
+            logits_shape=prompt_logits[0, -1, :].shape,
+            deterministic=not generation_config.do_sample,
+            sampling_temperature=generation_config.temperature,
+            tokens=tokens,
+            **inference_state.current_state,
+        )
+        token_generator = token_generator_creator_output.get("token_generator")
+        # the first generated token is sampled from the logits of the final
+        # prompt position
+        token_generator.generate(prompt_logits[0, -1, :])
+
+        max_tokens, length_finish_reason = PrepareGeneration.set_generated_length(
+            max_length=generation_config.max_length,
+            prompt_tokens_length=1,
+            max_new_tokens=generation_config.max_new_tokens,
+            sequence_length=self.sequence_length,
+            prompt_sequence_length=self.prompt_sequence_length,
+            finish_reason_choices=FinishReason,
+        )
+        state_update = {
+            "max_tokens": max_tokens,
+            "length_finish_reason": length_finish_reason,
+            "generated_tokens": [token_generator.tokens[-1]],
+            "generated_logits": [prompt_logits]
+            if include_prompt_logits
+            else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
+            "finished_reason": [],
+            "token_generator": token_generator,
+        }
+
+        output = {
+            "tokens": token_generator.tokens,
+            "kv_cache": kv_cache,
+            "in_generation": True,
+        }
+        return output, state_update
diff --git a/src/deepsparse/v2/text_generation/prep_for_prefill.py b/src/deepsparse/v2/text_generation/prep_for_prefill.py
new file mode 100644
index 0000000000..2f9eb15797
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/prep_for_prefill.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any
+
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.utils import PipelineState
+
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["PrepareforPrefill"]
+
+
+class PrepareforPrefill(Operator):
+    def __init__(self, kv_cache_creator: Operator):
+        """
+        Operator that runs before prefill. Responsible for creating the kv_cache
+        based on engine variables. Currently, this operator expects that the
+        kv_cache_creator is provided during initialization and then uses
+        pipeline_state to run the kv_cache_operator.
+        """
+        # NOTE: Alternatively, we can initialize the kv_cache_creator operator here,
+        # instead of at the pipeline level.
+        self.kv_cache_creator = kv_cache_creator
+
+        _LOGGER.warning(
+            "This operator requires the PipelineState to be set up with the "
+            "cache_shape, output_names, and kv_cache_data_type attributes from "
+            "the NLEngineOperator"
+        )
+
+    def run(self, tokens: Any, pipeline_state: PipelineState, **kwargs):
+        # NOTE: Can potentially just be class attributes instead of relying on
+        # pipeline state.
+        cache_shape = pipeline_state.current_state.get("cache_shape")
+        data_type = pipeline_state.current_state.get("kv_cache_data_type")
+        output_names = pipeline_state.current_state.get("output_names")
+
+        kv_cache = self.kv_cache_creator.run(
+            cache_shape=cache_shape,
+            kv_cache_data_type=data_type,
+            output_names=output_names,
+        ).get("kv_cache")
+        return {"tokens": tokens, "kv_cache": kv_cache}
diff --git a/src/deepsparse/v2/text_generation/process_inputs.py b/src/deepsparse/v2/text_generation/process_inputs.py
new file mode 100644
index 0000000000..e57e402983
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/process_inputs.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+from typing import Dict, Union
+
+import transformers
+
+from deepsparse.transformers.pipelines.text_generation import TextGenerationInput
+from deepsparse.transformers.utils.helpers import (
+    check_and_return_generation_config,
+    override_config,
+    repeat_inputs,
+)
+from deepsparse.v2.operators import Operator
+
+
+class GenerationDefaults:
+    num_return_sequences = 1
+    max_length = 100
+    max_new_tokens = None
+    output_scores = False
+    top_k = 0
+    top_p = 0.0
+    repetition_penalty = 0.0
+    do_sample = False
+    temperature = 1.0
+
+
+__all__ = ["ProcessInputsTextGeneration"]
+
+
+class ProcessInputsTextGeneration(Operator):
+    """
+    Input processing operator. Responsible for tokenizing the input, handling the
+    generation_config (if provided), updating the inference_state for later use,
+    and returning the tokens for prompt inference. The expected input is defined by
+    the input_schema, which for this operator is TextGenerationInput.
+    """
+
+    input_schema = TextGenerationInput
+
+    def __init__(
+        self,
+        tokenizer: transformers.PreTrainedTokenizerBase,
+        generation_config: Union[
+            str, pathlib.Path, Dict, transformers.GenerationConfig
+        ],
+        sequence_length: int,
+    ):
+        self.generation_config = generation_config
+        self.tokenizer = tokenizer
+        self.sequence_length = sequence_length
+
+    def run(self, inp: TextGenerationInput, **kwargs):
+        generation_config = check_and_return_generation_config(
+            self.generation_config, inp.generation_config, GenerationDefaults()
+        )
+
+        generation_config = override_config(inp.generation_kwargs, generation_config)
+
+        original_inputs = inp.sequences
+        if generation_config.num_return_sequences > 1:
+            if isinstance(inp.sequences, str):
+                inp.sequences = [inp.sequences]
+            inp.sequences = repeat_inputs(
+                inp.sequences, generation_config.num_return_sequences
+            )
+
+        if inp.fixed_sequences_length:
+            # to enforce a fixed sequence length, we need to
+            # truncate the input to the maximum sequence length
+            # and/or pad it to the maximum sequence length
+            truncate, padding = True, "max_length"
+        else:
+            # otherwise, we do not need to truncate the input
+            # and we can pad it to the longest sequence
+            # in the batch (so that the engine can process multiple inputs
+            # at once)
+            truncate, padding = False, "longest"
+
+        input_tokens = self.tokenizer(
+            inp.sequences,
+            return_tensors="np",
+            max_length=self.sequence_length,
+            padding=padding,
+            truncation=truncate,
+        )
+
+        input_ids = input_tokens["input_ids"]
+        attention_mask = input_tokens["attention_mask"]
+
+        inference_state_update = dict(
+            prompts=original_inputs,
+            streaming=inp.streaming,
+            generation_config=generation_config,
+            include_prompt_logits=inp.include_prompt_logits,
+            callback=inp.callback,
+            stop=inp.stop,
+            top_p=generation_config.top_p,
+            top_k=generation_config.top_k,
+            presence_penalty=inp.presence_penalty,
+            frequency_penalty=generation_config.repetition_penalty,
+        )
+
+        # TODO: move this step to prep_for_prefill
and add attention mask to the output + # this will allow us to split/join more easily when processing multiple prompts + # in parallel + tokens = input_ids[attention_mask.nonzero()].tolist() + return {"tokens": tokens}, inference_state_update diff --git a/src/deepsparse/v2/text_generation/process_outputs.py b/src/deepsparse/v2/text_generation/process_outputs.py new file mode 100644 index 0000000000..ca1cf78521 --- /dev/null +++ b/src/deepsparse/v2/text_generation/process_outputs.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +from typing import Optional + +import numpy + +from deepsparse.transformers.pipelines.text_generation import ( + FinishReason, + GeneratedText, + TextGenerationOutput, +) +from deepsparse.v2.operators import Operator +from deepsparse.v2.text_generation.compile_generations import CompileGenerationsOutput +from deepsparse.v2.utils import InferenceState + + +class ProcessOutputs(Operator): + output_schema = TextGenerationOutput + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def _create_generated_text_output( + self, + sequence: str, + finish_reason: Optional[FinishReason] = None, + logits: Optional[numpy.array] = None, + ): + if finish_reason: + return GeneratedText( + text=sequence, + score=logits, + finished=True, + finished_reason=finish_reason.value, + ) + return GeneratedText( + text=sequence, + score=logits, + finished=False, + ) + + def run( + self, inp: CompileGenerationsOutput, inference_state: InferenceState, **kwargs + ): + generation_config = inference_state.current_state.get("generation_config") + generated_tokens = inp.generated_tokens + generated_logits = ( + inp.generated_logits if generation_config.output_scores else None + ) + finished_reason = inp.finished_reason + sequences = self.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + ) + + finished_reason = [f for f in finished_reason if f] + + if generated_logits is not None: + generations = list( + map( + self._create_generated_text_output, + sequences, + finished_reason, + generated_logits, + ) + ) + else: + generations = list( + map(self._create_generated_text_output, sequences, finished_reason) + ) + outputs = dict( + created=datetime.datetime.now(), + prompts=inference_state.current_state.get("prompts"), + generations=generations, + ) + + return outputs diff --git a/src/deepsparse/v2/text_generation/token_generator.py b/src/deepsparse/v2/text_generation/token_generator.py new file mode 100644 index 0000000000..9148d71cc8 --- /dev/null +++ b/src/deepsparse/v2/text_generation/token_generator.py @@ -0,0 +1,30 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from deepsparse.transformers.utils.token_generator import TokenGenerator
+from deepsparse.v2.operators import Operator
+
+
+__all__ = ["TokenGeneratorOperator"]
+
+
+class TokenGeneratorOperator(Operator):
+    def run(self, logits_shape, deterministic, tokens, sampling_temperature, **kwargs):
+        token_generator = TokenGenerator(
+            logits_shape=logits_shape,
+            deterministic=deterministic,
+            tokens=tokens,
+            sampling_temperature=sampling_temperature,
+            **kwargs,
+        )
+        return {"token_generator": token_generator}
diff --git a/src/deepsparse/v2/utils/__init__.py b/src/deepsparse/v2/utils/__init__.py
new file mode 100644
index 0000000000..358405d7af
--- /dev/null
+++ b/src/deepsparse/v2/utils/__init__.py
@@ -0,0 +1,17 @@
+# flake8: noqa
+
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .state import *
+from .types import *
diff --git a/src/deepsparse/v2/utils/state.py b/src/deepsparse/v2/utils/state.py
new file mode 100644
index 0000000000..b54b890acf
--- /dev/null
+++ b/src/deepsparse/v2/utils/state.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from abc import ABC
+from typing import Any, Union
+
+
+__all__ = ["State", "PipelineState", "InferenceState"]
+
+
+class State(ABC):
+    """
+    Abstract class to store pipeline-level and inference-level state variables
+    which are generated by some Operator and required by some other Operator.
+    """
+
+    def __init__(self):
+        self._current_state = None
+
+    @property
+    def current_state(self):
+        return self._current_state
+
+
+class PipelineState(State):
+    """
+    Created during pipeline initialization. Pipeline state values are read-only
+    during inference.
+    """
+
+    def create_state(self, new_state: dict):
+        if self._current_state:
+            raise ValueError("State creation is only allowed during initialization.")
+        self._current_state = new_state
+
+
+class InferenceState(State):
+    """
+    Inference state, created during every inference run.
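+    Unlike PipelineState, inference state values may be overridden and updated
+    over the course of a run (see create_state, update_value, and update_state).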
+ """ + + def create_state(self, new_state: dict): + if self._current_state: + warnings.warn("Current state already exists, overriding.") + self._current_state = new_state + + def update_value(self, attribute: str, value: Union[str, int, list]): + if not self._current_state.get(attribute): + raise ValueError(f"{attribute} is not a valid state attribute") + self._current_state[attribute] = value + + def update_state(self, value: Any): + self._current_state.update(value) diff --git a/src/deepsparse/v2/utils/types.py b/src/deepsparse/v2/utils/types.py new file mode 100644 index 0000000000..3e4b974453 --- /dev/null +++ b/src/deepsparse/v2/utils/types.py @@ -0,0 +1,28 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Types to support deepsparse pipelines +""" + +from typing import Any, Dict, Union + +from pydantic import BaseModel + + +__all__ = ["OperatorSchema"] + + +# Operator inputs and outputs may either be a pydantic base model or a dict of kwargs +OperatorSchema = Union[BaseModel, Dict[str, Any]] diff --git a/tests/deepsparse/v2/__init__.py b/tests/deepsparse/v2/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/v2/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/deepsparse/v2/test_basic_pipeline.py b/tests/deepsparse/v2/test_basic_pipeline.py new file mode 100644 index 0000000000..bedddd537a --- /dev/null +++ b/tests/deepsparse/v2/test_basic_pipeline.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Simple example and test of a dummy pipeline +""" + +from typing import Dict + +from pydantic import BaseModel + +from deepsparse.v2 import Pipeline +from deepsparse.v2.operators import Operator +from deepsparse.v2.routers import LinearRouter +from deepsparse.v2.schedulers import OperatorScheduler + + +class IntSchema(BaseModel): + value: int + + +class AddOneOperator(Operator): + input_schema = IntSchema + output_schema = IntSchema + + def run(self, inp: IntSchema, **kwargs) -> Dict: + return {"value": inp.value + 1} + + +class AddTwoOperator(Operator): + input_schema = IntSchema + output_schema = IntSchema + + def run(self, inp: IntSchema, **kwargs) -> Dict: + return {"value": inp.value + 2} + + +AddThreePipeline = Pipeline( + ops=[AddOneOperator(), AddTwoOperator()], + router=LinearRouter(end_route=2), + schedulers=[OperatorScheduler()], +) + + +def test_run_simple_pipeline(): + pipeline_input = IntSchema(value=5) + pipeline_output = AddThreePipeline(pipeline_input) + + assert pipeline_output.value == 8 diff --git a/tests/deepsparse/v2/test_image_classification.py b/tests/deepsparse/v2/test_image_classification.py new file mode 100644 index 0000000000..03e2807454 --- /dev/null +++ b/tests/deepsparse/v2/test_image_classification.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy + +import pytest +from deepsparse.v2.image_classification import ImageClassificationPipeline +from deepsparse.v2.image_classification.preprocess_operator import ( + ImageClassificationInput, +) +from tests.deepsparse.pipelines.data_helpers import computer_vision + + +@pytest.fixture +def get_images(): + batch_size = 2 + images = computer_vision(batch_size=batch_size) + return images.get("images") + + +def test_image_classification(get_images): + model_path = ( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95-none" + ) + pipeline = ImageClassificationPipeline(model_path=model_path) + output = pipeline(ImageClassificationInput(images=get_images)) + assert output.labels == [[207], [670]] + assert numpy.allclose(output.scores, [[21.85], [17.33]], atol=0.01)
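Finally, a sketch of how the new TextGenerationPipeline is expected to be driven end to end, in the same style as the tests above. The SparseZoo stub and generation kwargs are illustrative assumptions; per setup_transformers_pipeline, a local deployment directory containing a model.onnx should work as well:

from deepsparse.transformers.pipelines.text_generation import TextGenerationInput
from deepsparse.v2.text_generation import TextGenerationPipeline

# illustrative model stub; swap in a local deployment directory if preferred
pipeline = TextGenerationPipeline(
    model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/"
    "bigpython_bigquery_thepile/base-none",
    prompt_sequence_length=16,
    sequence_length=1024,
)

output = pipeline(
    TextGenerationInput(
        sequences="def fibonacci(n):",
        generation_kwargs={"max_new_tokens": 32},
    )
)
print(output.generations[0].text)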