From f5fe888ef6f9793359bb0081db480a484f271cc5 Mon Sep 17 00:00:00 2001 From: Yeonsik Seo Date: Thu, 8 May 2025 09:20:19 +0900 Subject: [PATCH] Add vendor RBLN (#1) * Adding vendor type RBLN(Rebellions) NPU & document & PyTest for added feature. --------- Co-authored-by: rebel-jonghewk <142865404+rebel-jonghewk@users.noreply.github.com> Co-authored-by: Sungho Shin <87514200+rebel-shshin@users.noreply.github.com> --- docs/hardware_support/hardware_support.rst | 1 + docs/hardware_support/rbln_support.md | 47 +++++++ .../serve/device/AcceleratorVendor.java | 1 + .../org/pytorch/serve/device/SystemInfo.java | 6 +- .../pytorch/serve/device/utils/RblnUtil.java | 108 +++++++++++++++ test/pytest/test_data/rbln_compile/compile.py | 13 ++ .../test_data/rbln_compile/config.properties | 10 ++ .../test_data/rbln_compile/rbln_handler.py | 102 ++++++++++++++ test/pytest/test_rbln_serving.py | 129 ++++++++++++++++++ ts/metrics/system_metrics.py | 49 +++++++ 10 files changed, 465 insertions(+), 1 deletion(-) create mode 100644 docs/hardware_support/rbln_support.md create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/utils/RblnUtil.java create mode 100644 test/pytest/test_data/rbln_compile/compile.py create mode 100644 test/pytest/test_data/rbln_compile/config.properties create mode 100644 test/pytest/test_data/rbln_compile/rbln_handler.py create mode 100644 test/pytest/test_rbln_serving.py diff --git a/docs/hardware_support/hardware_support.rst b/docs/hardware_support/hardware_support.rst index 267525fc65..4f4b34f759 100644 --- a/docs/hardware_support/hardware_support.rst +++ b/docs/hardware_support/hardware_support.rst @@ -6,3 +6,4 @@ linux_aarch64 nvidia_mps Intel Extension for PyTorch + rbln_support diff --git a/docs/hardware_support/rbln_support.md b/docs/hardware_support/rbln_support.md new file mode 100644 index 0000000000..e43c0d9062 --- /dev/null +++ b/docs/hardware_support/rbln_support.md @@ -0,0 +1,47 @@ + ⚠️ Notice: Limited Maintenance + +This project is no longer actively maintained. While existing releases remain available, there are no planned updates, bug fixes, new features, or security patches. Users should be aware that vulnerabilities may not be addressed. + +# Rebellions Support + +RBLN (Rebellions) NPUs are fully compatible with TorchServe. Rebellions provides documentation and tutorials to help you easily get started using the RBLN NPU with TorchServe. + +- [Rebellions TorchServe supports document](https://docs.rbln.ai/software/model_serving/torchserve/torchserve.html) + +## Support Matrix +For details on supported features and compatibility of the RBLN NPU, see the `Support Matrix` below: +- [Support Matrix](https://docs.rbln.ai/supports/version_matrix.html) + +## Installation +Please refer to the Installation Guide in the Rebellions documentation for instructions on installing the Driver and RBLN SDK. +- [Installation Guide](https://docs.rbln.ai/getting_started/installation_guide.html) + +## TorchServe with RBLN NPUs + +To start properly, please refer to the Rebellions' TorchServe Tutorials. + +Tutorials are available for various models including `Image Classification (ResNet50)`, `Object Detection (YOLOv8)`, and `LLM (Llama3-8B)`. + +- [Tutorial - ResNet50](https://docs.rbln.ai/software/model_serving/torchserve/tutorial/resnet50.html) +- [Tutorial - YOLOv8](https://docs.rbln.ai/software/model_serving/torchserve/tutorial/yolov8.html) +- [Tutorial - Llama3-8B](https://docs.rbln.ai/software/model_serving/torchserve/tutorial/llama3-8B.html) + +## Docker + +Please refer to the `Docker Support` documentation. + +- [Docker Support](https://docs.rbln.ai/software/system_management/docker.html) + +## Multi Devices + +For monitoring NPU statistics and utilizing multiple RBLN NPUs, please refer to the `Device Management` documentation. + +- [Device Management](https://docs.rbln.ai/software/system_management/device_management.html) + +## Contact + +We are always interested in improving the utilization of the RBLN NPU and providing technical support. + +Please contact us through the following page: + +- [Contact Us](https://docs.rbln.ai/supports/contact_us.html) diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java b/frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java index 22fd1f5d68..eaf09dba8e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java +++ b/frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java @@ -5,5 +5,6 @@ public enum AcceleratorVendor { NVIDIA, INTEL, APPLE, + RBLN, UNKNOWN } diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java b/frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java index a26f85ef93..f43557cf85 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java +++ b/frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java @@ -63,6 +63,8 @@ private IAcceleratorUtility createAcceleratorUtility() { return new XpuUtil(); case APPLE: return new AppleUtil(); + case RBLN: + return new RblnUtil(); default: return null; } @@ -107,7 +109,9 @@ public static AcceleratorVendor detectVendorType() { return AcceleratorVendor.INTEL; } else if (isCommandAvailable("system_profiler")) { return AcceleratorVendor.APPLE; - } else { + } else if (isCommandAvailable("rbln-stat")) { + return AcceleratorVendor.RBLN; + }else { return AcceleratorVendor.UNKNOWN; } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/RblnUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/RblnUtil.java new file mode 100644 index 0000000000..f4d820b522 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/RblnUtil.java @@ -0,0 +1,108 @@ +package org.pytorch.serve.device.utils; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; +import org.pytorch.serve.device.interfaces.IJsonSmiParser; + +public class RblnUtil implements IAcceleratorUtility, IJsonSmiParser { + + @Override + public String getGpuEnvVariableName() { + return "RBLN_DEVICES"; + } + + @Override + public String[] getUtilizationSmiCommand() { + return new String[] { + "rbln-stat", "-j" + }; + } + + @Override + public ArrayList getAvailableAccelerators( + LinkedHashSet availableAcceleratorIds) { + String jsonOutput = IAcceleratorUtility.callSMI(getUtilizationSmiCommand()); + JsonObject rootObject = JsonParser.parseString(jsonOutput).getAsJsonObject(); + return jsonOutputToAccelerators(rootObject, availableAcceleratorIds); + } + + @Override + public ArrayList smiOutputToUpdatedAccelerators( + String smiOutput, LinkedHashSet parsedGpuIds) { + JsonObject rootObject = JsonParser.parseString(smiOutput).getAsJsonObject(); + return jsonOutputToAccelerators(rootObject, parsedGpuIds); + } + + @Override + public Accelerator jsonObjectToAccelerator(JsonObject gpuObject) { + String model = gpuObject.get("name").getAsString(); + if (!model.startsWith("RBLN")) { + return null; + } + int npuId = gpuObject.get("npu").getAsInt(); + float npuUtil = gpuObject.get("util").getAsFloat(); + long memoryTotal = gpuObject.getAsJsonObject("memory").get("total").getAsLong(); + long memoryUsed = gpuObject.getAsJsonObject("memory").get("used").getAsLong(); + + Accelerator accelerator = new Accelerator(model, AcceleratorVendor.RBLN, npuId); + + // Set additional information + accelerator.setUsagePercentage(npuUtil); + accelerator.setMemoryUtilizationPercentage((memoryUsed==0)?0f:(memoryUsed/(float)memoryTotal)); + accelerator.setMemoryUtilizationMegabytes((int)(memoryUsed/1024/1024)); + + return accelerator; + } + + @Override + public Integer extractAcceleratorId(JsonObject cardObject) { + Integer npuId = cardObject.get("npu").getAsInt(); + return npuId; + } + + @Override + public List extractAccelerators(JsonElement rootObject) { + List accelerators = new ArrayList<>(); + JsonArray devicesArray = + rootObject + .getAsJsonObject() + .get("devices") + .getAsJsonArray(); + + for (JsonElement elem : devicesArray){ + accelerators.add(elem.getAsJsonObject()); + } + + return accelerators; + } + + public ArrayList jsonOutputToAccelerators( + JsonObject rootObject, LinkedHashSet parsedAcceleratorIds) { + + ArrayList accelerators = new ArrayList<>(); + List acceleratorObjects = extractAccelerators(rootObject); + + int i=0; + for (JsonObject acceleratorObject : acceleratorObjects) { + Integer acceleratorId = extractAcceleratorId(acceleratorObject); + if (acceleratorId != null + && (parsedAcceleratorIds.isEmpty() + || parsedAcceleratorIds.contains(acceleratorId))) { + Accelerator accelerator = jsonObjectToAccelerator(acceleratorObject); + accelerators.add(accelerator); + } + i++; + } + + return accelerators; + } +} diff --git a/test/pytest/test_data/rbln_compile/compile.py b/test/pytest/test_data/rbln_compile/compile.py new file mode 100644 index 0000000000..cba911a4c5 --- /dev/null +++ b/test/pytest/test_data/rbln_compile/compile.py @@ -0,0 +1,13 @@ +import rebel +import torch +from torchvision.models import ResNet50_Weights, resnet50 + +weights = ResNet50_Weights.DEFAULT +model = resnet50(weights=weights) +model.eval() + +compiled_model = rebel.compile_from_torch( + model, + [("input", [1, 3, 224, 224], torch.float32)], +) +compiled_model.save("resnet50.rbln") diff --git a/test/pytest/test_data/rbln_compile/config.properties b/test/pytest/test_data/rbln_compile/config.properties new file mode 100644 index 0000000000..d40c3d64ef --- /dev/null +++ b/test/pytest/test_data/rbln_compile/config.properties @@ -0,0 +1,10 @@ +default_workers_per_model:1 + +models={\ + "resnet50":{\ + "1.0":{\ + "marName": "resnet50.mar",\ + "responseTimeout": 120\ + }\ + }\ +} diff --git a/test/pytest/test_data/rbln_compile/rbln_handler.py b/test/pytest/test_data/rbln_compile/rbln_handler.py new file mode 100644 index 0000000000..4610997428 --- /dev/null +++ b/test/pytest/test_data/rbln_compile/rbln_handler.py @@ -0,0 +1,102 @@ +# resnet50_handler.py +# + +import io +import os + +import PIL.Image as Image +import rebel # RBLN Runtime +import torch +from torchvision.models import ResNet50_Weights + +from ts.torch_handler.base_handler import BaseHandler + + +class Resnet50Handler(BaseHandler): + def __init__(self): + self._context = None + self.initialized = False + self.model = None + self.weights = None + + def initialize(self, context): + """ + Initialize model. This will be called during model loading time + :param context: Initial context contains model server system properties. + :return: + """ + self._context = context + # load the model, refer 'custom handler class' above for details + model_dir = context.system_properties.get("model_dir") + serialized_file = context.manifest["model"].get("serializedFile") + model_path = os.path.join(model_dir, serialized_file) + if not os.path.isfile(model_path): + raise RuntimeError( + f"[RBLN ERROR] File not found at the specified model_path({model_path})." + ) + + self.module = rebel.Runtime(model_path, tensor_type="pt") + self.weights = ResNet50_Weights.DEFAULT + self.initialized = True + + def preprocess(self, data): + """ + Transform raw input into model input data. + :param batch: list of raw requests, should match batch size + :return: list of preprocessed model input data + """ + input_data = data[0].get("data") + if input_data is None: + input_data = data[0].get("body") + assert input_data is not None, print( + "[RBLN][ERROR] Data not found with client request." + ) + if not isinstance(input_data, (bytes, bytearray)): + raise ValueError("[RBLN][ERROR] Preprocessed data is not binary data.") + + try: + image = Image.open(io.BytesIO(input_data)) + except Exception as e: + raise ValueError(f"[RBLN][ERROR]Invalid image data: {e}") + prep = self.weights.transforms() + batch = prep(image).unsqueeze(0) + preprocessed_data = batch.numpy() + + return torch.from_numpy(preprocessed_data) + + def inference(self, model_input): + """ + Internal inference methods + :param model_input: transformed model input data + :return: list of inference output in NDArray + """ + + model_output = self.module.run(model_input) + + return model_output + + def postprocess(self, inference_output): + """ + Return inference result. + :param inference_output: list of inference output + :return: list of predict results + """ + score, class_id = torch.topk(inference_output, 1, dim=1) + category_name = self.weights.meta["categories"][class_id] + return category_name + + def handle(self, data, context): + """ + Invoke by TorchServe for prediction request. + Do pre-processing of data, prediction using model and postprocessing of prediciton output + :param data: Input data for prediction + :param context: Initial context contains model server system properties. + :return: prediction output + """ + model_input = self.preprocess(data) + model_output = self.inference(model_input) + category_name = self.postprocess(model_output) + + print("[RBLN][INFO] Top1 category: ", category_name) + + return [{"result": category_name}] diff --git a/test/pytest/test_rbln_serving.py b/test/pytest/test_rbln_serving.py new file mode 100644 index 0000000000..9018b61919 --- /dev/null +++ b/test/pytest/test_rbln_serving.py @@ -0,0 +1,129 @@ +import glob +import importlib.util +import json +import os +import subprocess +import time +from pathlib import Path + +import pytest + +try: + import rebel # nopycln: import + + RBLN_AVAILABLE = True +except ImportError: + RBLN_AVAILABLE = False + +REQUIRED_PKG = ["torchvision", "torch"] + +CURR_FILE_PATH = Path(__file__).parent +RBLN_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "rbln_compile") + +COMPILE_FILE = os.path.join(RBLN_TEST_DATA_DIR, "compile.py") +HANDLER_FILE = os.path.join(RBLN_TEST_DATA_DIR, "rbln_handler.py") +CONFIG_PROPERTIES = os.path.join(RBLN_TEST_DATA_DIR, "config.properties") +SERIALIZED_FILE = os.path.join(RBLN_TEST_DATA_DIR, "resnet50.rbln") +MODEL_STORE_DIR = os.path.join(RBLN_TEST_DATA_DIR, "model_store") +MODEL_NAME = "resnet50" + + +@pytest.fixture(scope="session", autouse=True) +def install_pkgs(): + for package_name in REQUIRED_PKG: + if importlib.util.find_spec(package_name) is None: + print(f"Installing missing package: {package_name}") + try: + subprocess.check_call( + [sys.executable, "-m", "pip", "install", package_name] + ) + except subprocess.CalledProcessError as e: + pytest.fail(f"Fail to install package {package_name}") + print(f"Installing missing package: {package_name} - Installed.") + + +def ensure_package(import_name, package_name): + if importlib.util.find_spec(import_name) is None: + print(f"Fail to find RBLN package : {package_name}") + return False + return True + + +@pytest.mark.skipif(RBLN_AVAILABLE == False, reason='"rebel-compiler" is not installed') +class TestTorchRbln: + def teardown_class(self): + subprocess.run("torchserve --stop", shell=True, check=True) + time.sleep(10) + + def test_rbln_sdk_packages(self): + assert ensure_package("rebel", "rebel-compiler") == True + + def test_archive_model_artifact(self): + assert len(glob.glob(COMPILE_FILE)) == 1 + assert len(glob.glob(HANDLER_FILE)) == 1 + assert len(glob.glob(CONFIG_PROPERTIES)) == 1 + + subprocess.run( + f"cd {RBLN_TEST_DATA_DIR} && python3 {COMPILE_FILE}", shell=True, check=True + ) + subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True) + + assert len(glob.glob(SERIALIZED_FILE)) == 1 + + subprocess.run( + f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --handler {HANDLER_FILE} --serialized-file {SERIALIZED_FILE} --export-path {MODEL_STORE_DIR} -f", + shell=True, + check=True, + ) + assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1 + + def test_start_torchserve(self): + subprocess.run( + f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR} --ts-config {CONFIG_PROPERTIES} --disable-token-auth", + shell=True, + check=True, + ) + time.sleep(10) + assert len(glob.glob("logs/access_log.log")) == 1 + assert len(glob.glob("logs/model_log.log")) == 1 + assert len(glob.glob("logs/ts_log.log")) == 1 + + def test_server_status(self): + result = subprocess.run( + "curl http://localhost:8080/ping", + shell=True, + capture_output=True, + check=True, + ) + expected_server_status_str = '{"status": "Healthy"}' + expected_server_status = json.loads(expected_server_status_str) + assert json.loads(result.stdout) == expected_server_status + + def test_registered_model(self): + result = subprocess.run( + "curl http://localhost:8081/models", + shell=True, + capture_output=True, + check=True, + ) + expected_registered_model_str = ( + '{"models": [{"modelName": "resnet50", "modelUrl": "resnet50.mar"}]}' + ) + expected_registered_model = json.loads(expected_registered_model_str) + assert json.loads(result.stdout) == expected_registered_model + + def test_serve_inference(self): + if not Path(f"{RBLN_TEST_DATA_DIR}/tabby.jpg").exists(): + subprocess.run( + f"cd {RBLN_TEST_DATA_DIR} && wget https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/tabby.jpg", + shell=True, + ) + result = subprocess.run( + f'cd {RBLN_TEST_DATA_DIR} && curl -X POST "http://127.0.0.1:8080/predictions/resnet50" -H "Content-Type: application/octet-stream" --data-binary @./tabby.jpg', + shell=True, + capture_output=True, + check=True, + ) + expected_result_str = '{"result":"tabby"}' + expected_result = json.loads(expected_result_str) + assert json.loads(result.stdout) == expected_result diff --git a/ts/metrics/system_metrics.py b/ts/metrics/system_metrics.py index 6c2becfcda..26e8dac892 100644 --- a/ts/metrics/system_metrics.py +++ b/ts/metrics/system_metrics.py @@ -51,6 +51,14 @@ def disk_available(): system_metrics.append(Metric("DiskAvailable", data, "GB", dimension)) +def is_rbln_supported(): + try: + import rebel # nopycln: import + except ImportError: + return False + return True + + def collect_gpu_metrics(num_of_gpus): """ Collect GPU metrics. Supports NVIDIA and AMD GPUs. @@ -102,6 +110,47 @@ def collect_gpu_metrics(num_of_gpus): mem_used = 0 gpu_mem_utilization = 0 gpu_utilization = None + elif is_rbln_supported(): + import json + import shutil + import subprocess + + mem_used = 0 + gpu_mem_utilization = 0 + gpu_utilization = 0 + rbln_stat = shutil.which("rbln-stat") + if rbln_stat: + try: + stat = subprocess.run( + [rbln_stat, "-j"], capture_output=True, text=True, check=True + ) + data = json.loads(stat.stdout) + devices = data.get("devices", []) + for device in devices: + npu = device.get("npu") + if device.get("npu") == str(gpu_index): + mem_used = int(device.get("memory", {}).get("used", "0")) + gpu_mem_utilization = round( + ( + mem_used + / float(device.get("memory", {}).get("total", "0")) + ) + * 100, + 2, + ) # Percentage + mem_used = mem_used // 1024**2 # Megabyte + gpu_utilization = float( + device.get("util", "0") + ) # Percentage + except Exception as e: + logging.error(f'Could not utilize "rbln-stat". {e}') + mem_used = 0 + gpu_mem_utilization = 0 + gpu_utilization = 0 + else: + mem_used = 0 + gpu_mem_utilization = 0 + gpu_utilization = 0 dimension_gpu = [ Dimension("Level", "Host"),