|
| 1 | +import hashlib |
| 2 | +import logging |
| 3 | +from collections import defaultdict |
| 4 | +from concurrent.futures import ThreadPoolExecutor |
| 5 | +from datetime import datetime |
| 6 | +from functools import partial |
| 7 | +from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union |
| 8 | + |
| 9 | +import supervision as sv |
| 10 | +from fastapi import BackgroundTasks |
| 11 | +from pydantic import BaseModel, ConfigDict, Field, field_validator |
| 12 | + |
| 13 | +from inference.core.cache.base import BaseCache |
| 14 | +from inference.core.env import DEVICE_ID |
| 15 | +from inference.core.managers.metrics import get_system_info |
| 16 | +from inference.core.roboflow_api import ( |
| 17 | + get_roboflow_workspace, |
| 18 | + send_inference_results_to_model_monitoring, |
| 19 | +) |
| 20 | +from inference.core.workflows.execution_engine.constants import ( |
| 21 | + CLASS_NAME_KEY, |
| 22 | + INFERENCE_ID_KEY, |
| 23 | + PREDICTION_TYPE_KEY, |
| 24 | +) |
| 25 | +from inference.core.workflows.execution_engine.entities.base import OutputDefinition |
| 26 | +from inference.core.workflows.execution_engine.entities.types import ( |
| 27 | + BOOLEAN_KIND, |
| 28 | + CLASSIFICATION_PREDICTION_KIND, |
| 29 | + INSTANCE_SEGMENTATION_PREDICTION_KIND, |
| 30 | + KEYPOINT_DETECTION_PREDICTION_KIND, |
| 31 | + OBJECT_DETECTION_PREDICTION_KIND, |
| 32 | + STRING_KIND, |
| 33 | + Selector, |
| 34 | +) |
| 35 | +from inference.core.workflows.prototypes.block import ( |
| 36 | + BlockResult, |
| 37 | + WorkflowBlock, |
| 38 | + WorkflowBlockManifest, |
| 39 | +) |
| 40 | + |
SHORT_DESCRIPTION = "Periodically report an aggregated sample of inference results to Roboflow Model Monitoring"

# User-facing description rendered in the Workflows UI / docs.
# Grammar fixes applied in the "Limitations" section (was: "The block is should
# not be relied on", "This block do not have ability").
LONG_DESCRIPTION = """
This block 📊 **transforms inference data reporting** to a whole new level by
periodically aggregating and sending a curated sample of predictions to
**[Roboflow Model Monitoring](https://docs.roboflow.com/deploy/model-monitoring)**.

#### ✨ Key Features
* **Effortless Aggregation:** Collects and organizes predictions in-memory, ensuring only the most relevant
and confident predictions are reported.

* **Customizable Reporting Intervals:** Choose how frequently (in seconds) data should be sent—ensuring
optimal balance between granularity and resource efficiency.

* **Debug-Friendly Mode:** Fine-tune operations by enabling or disabling asynchronous background execution.

#### 🔍 Why Use This Block?

This block is a game-changer for projects relying on video processing in Workflows.
With its aggregation process, it identifies the most confident predictions across classes and sends
them at regular intervals in small messages to Roboflow backend - ensuring that video processing
performance is impacted to the least extent.

Perfect for:

* Monitoring production line performance in real-time 🏭.

* Debugging and validating your model’s performance over time ⏱️.

* Providing actionable insights from inference workflows with minimal overhead 🔧.


#### 🚨 Limitations

* This block should not be relied on when running a Workflow in an `inference` server or via HTTP request to the
Roboflow hosted platform, as the internal state is not persisted in a memory that would be accessible for all
requests to the server, causing aggregation to **only have the scope of a single request**. We will solve that
problem in future releases if it proves to be a serious limitation for clients.

* This block does not have the ability to separate aggregations for multiple videos processed by `InferencePipeline` -
effectively aggregating data for **all video feeds connected to a single process running `InferencePipeline`**.
"""
| 83 | + |
| 84 | + |
class BlockManifest(WorkflowBlockManifest):
    """Manifest of the Model Monitoring Inference Aggregator sink block.

    Declares the block identity, its inputs (predictions to aggregate,
    reporting frequency, session key, fire-and-forget flag) and its outputs
    (`error_status`, `message`).
    """

    model_config = ConfigDict(
        json_schema_extra={
            "name": "Model Monitoring Inference Aggregator",
            "version": "v1",
            "short_description": SHORT_DESCRIPTION,
            "long_description": LONG_DESCRIPTION,
            "license": "Apache-2.0",
            "block_type": "sink",
        }
    )
    type: Literal["roboflow_core/model_monitoring_inference_aggregator@v1"]
    predictions: Selector(
        kind=[
            OBJECT_DETECTION_PREDICTION_KIND,
            INSTANCE_SEGMENTATION_PREDICTION_KIND,
            KEYPOINT_DETECTION_PREDICTION_KIND,
            CLASSIFICATION_PREDICTION_KIND,
        ]
    ) = Field(
        # Fixed copy-pasted description - this field takes the predictions to
        # aggregate and report; it does not extract properties from anything.
        description="Model predictions to aggregate and report to Roboflow Model Monitoring",
        examples=["$steps.my_step.predictions"],
    )
    # NOTE(review): the selector kind is STRING_KIND although the resolved
    # value is used as an int - confirm this is intended.
    frequency: Union[
        int,
        Selector(kind=[STRING_KIND]),
    ] = Field(
        default=5,
        description="Frequency of reporting (in seconds). For example, if 5 is provided, the "
        "block will report an aggregated sample of predictions every 5 seconds.",
        # Examples were string-typed ("3", "5") while the declared literal
        # type is int - use ints so the examples validate against the schema.
        examples=[3, 5],
    )
    unique_aggregator_key: str = Field(
        description="Unique key used internally to track the session of inference results reporting. "
        "Must be unique for each step in your Workflow.",
        examples=["session-1v73kdhfse"],
        json_schema_extra={"hidden": True},
    )
    fire_and_forget: Union[bool, Selector(kind=[BOOLEAN_KIND])] = Field(
        default=True,
        description="Boolean flag dictating if sink is supposed to be executed in the background, "
        "not waiting on status of registration before end of workflow run. Use `True` if best-effort "
        "registration is needed, use `False` while debugging and if error handling is needed",
        examples=[True],
    )

    @field_validator("frequency")
    @classmethod
    def ensure_frequency_is_correct(cls, value: Any) -> Any:
        """Reject literal int frequencies below 1 second (selectors pass through)."""
        if isinstance(value, int) and value < 1:
            raise ValueError("`frequency` cannot be lower than 1.")
        return value

    @classmethod
    def describe_outputs(cls) -> List[OutputDefinition]:
        """Outputs: error flag of the report attempt and a status message."""
        return [
            OutputDefinition(name="error_status", kind=[BOOLEAN_KIND]),
            OutputDefinition(name="message", kind=[STRING_KIND]),
        ]

    @classmethod
    def get_execution_engine_compatibility(cls) -> Optional[str]:
        return ">=1.3.0,<2.0.0"
| 148 | + |
| 149 | + |
class ParsedPrediction(BaseModel):
    """Flattened, per-class view of a single model prediction, in the shape
    reported to Roboflow Model Monitoring."""

    # Predicted class label (empty string when absent in the raw payload).
    class_name: str
    # Model confidence for this prediction; defaults to 0.0 when missing.
    confidence: float
    # Identifier of the inference request that produced this prediction.
    inference_id: str
    # Prediction/model type taken from the raw payload's prediction-type key.
    model_type: str
| 155 | + |
| 156 | + |
class PredictionsAggregator:
    """In-memory buffer of raw predictions with per-class consolidation.

    Raw payloads (sv.Detections or dicts) are collected between reports;
    on flush they are parsed and reduced to the single highest-confidence
    prediction per class name.
    """

    def __init__(self):
        # Raw payloads collected since the last flush.
        self._raw_predictions: List[Union[sv.Detections, dict]] = []

    def collect(self, value: Union[sv.Detections, dict]) -> None:
        """Buffer a single raw prediction payload."""
        # TODO: push into global state, otherwise for HTTP server use,
        # state would at most have 1 prediction!!!
        self._raw_predictions.append(value)

    def get_and_flush(self) -> List[ParsedPrediction]:
        """Return the consolidated sample and reset the internal buffer."""
        predictions = self._consolidate()
        self._raw_predictions = []
        return predictions

    def _consolidate(self) -> List[ParsedPrediction]:
        """Pick the most confident prediction for each class name."""
        formatted_predictions: List[ParsedPrediction] = []
        for raw in self._raw_predictions:
            formatted_predictions.extend(format_predictions_for_model_monitoring(raw))
        class_groups: Dict[str, List[ParsedPrediction]] = defaultdict(list)
        for prediction in formatted_predictions:
            class_groups[prediction.class_name].append(prediction)
        # max() is O(k) per class (was an O(k log k) sort just to take the
        # first element) and does not mutate the grouped lists; ties resolve
        # to the first-seen prediction, same as a stable descending sort.
        return [
            max(group, key=lambda p: p.confidence)
            for group in class_groups.values()
        ]
| 185 | + |
| 186 | + |
class ModelMonitoringInferenceAggregatorBlockV1(WorkflowBlock):
    """Workflow sink that periodically reports an aggregated sample of
    predictions to Roboflow Model Monitoring.

    Predictions are buffered in a `PredictionsAggregator` between runs; a
    consolidated sample (best prediction per class) is sent whenever the
    configured reporting interval has elapsed.
    """

    def __init__(
        self,
        cache: BaseCache,
        api_key: Optional[str],
        background_tasks: Optional[BackgroundTasks],
        thread_pool_executor: Optional[ThreadPoolExecutor],
    ):
        """
        Args:
            cache: Shared cache used to persist the time of the last report.
            api_key: Roboflow API key - required, as reporting targets the
                Roboflow backend.
            background_tasks: FastAPI background tasks (HTTP-server context),
                preferred vehicle for fire-and-forget execution.
            thread_pool_executor: Executor used for fire-and-forget execution
                when background tasks are unavailable.

        Raises:
            ValueError: If `api_key` is None.
        """
        if api_key is None:
            raise ValueError(
                "ModelMonitoringInferenceAggregator block cannot run without Roboflow API key. "
                "If you do not know how to get API key - visit "
                "https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key to learn how to "
                "retrieve one."
            )
        self._api_key = api_key
        self._cache = cache
        self._background_tasks = background_tasks
        self._thread_pool_executor = thread_pool_executor
        self._predictions_aggregator = PredictionsAggregator()

    @classmethod
    def get_init_parameters(cls) -> List[str]:
        return ["api_key", "cache", "background_tasks", "thread_pool_executor"]

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        return BlockManifest

    def run(
        self,
        fire_and_forget: bool,
        predictions: Union[sv.Detections, dict],
        frequency: int,
        unique_aggregator_key: str,
    ) -> BlockResult:
        """Collect `predictions`; when the reporting interval elapsed, send
        the aggregated sample (in the background when `fire_and_forget`).

        Returns a dict with `error_status` and `message` outputs.
        """
        # Cache key scoping the "last report time" marker to this aggregator
        # session (`unique_aggregator_key` is unique per workflow step).
        self._last_report_time_cache_key = f"workflows:steps_cache:roboflow_core/model_monitoring_inference_aggregator@v1:{unique_aggregator_key}:last_report_time"
        if predictions:
            self._predictions_aggregator.collect(predictions)
        if not self._is_in_reporting_range(frequency):
            return {
                "error_status": False,
                "message": "Not in reporting range, skipping report. (Ok)",
            }
        preds = self._predictions_aggregator.get_and_flush()
        registration_task = partial(
            send_to_model_monitoring_request,
            cache=self._cache,
            last_report_time_cache_key=self._last_report_time_cache_key,
            api_key=self._api_key,
            predictions=preds,
        )
        error_status = False
        message = "Reporting happens in the background task"
        if fire_and_forget and self._background_tasks:
            self._background_tasks.add_task(registration_task)
        elif fire_and_forget and self._thread_pool_executor:
            self._thread_pool_executor.submit(registration_task)
        else:
            # Synchronous (debug) path - surface the task's own status.
            error_status, message = registration_task()
        # Mark the report attempt immediately so subsequent runs do not
        # schedule duplicate reports while a background task is in flight.
        self._cache.set(self._last_report_time_cache_key, datetime.now().isoformat())
        return {
            "error_status": error_status,
            "message": message,
        }

    def _is_in_reporting_range(self, frequency: int) -> bool:
        """Return True when at least `frequency` seconds passed since the last report.

        On the first call the marker is initialized to "now", so the first
        actual report happens only after one full interval has elapsed.
        """
        now = datetime.now()
        last_report_time_str = self._cache.get(self._last_report_time_cache_key)
        if last_report_time_str is None:
            # Removed a leftover dead read-back of the value just written.
            self._cache.set(self._last_report_time_cache_key, now.isoformat())
            last_report_time = now
        else:
            last_report_time = datetime.fromisoformat(last_report_time_str)
        time_elapsed = int((now - last_report_time).total_seconds())
        return time_elapsed >= int(frequency)
| 265 | + |
| 266 | + |
# TODO: maybe make this a helper or decorator, it's used in multiple places
def get_workspace_name(
    api_key: str,
    cache: BaseCache,
) -> str:
    """Resolve the workspace name for `api_key`, caching the answer.

    The API key is stored hashed (never in plain text) in the cache key;
    cache entries expire after 15 minutes.
    """
    digest = hashlib.md5(api_key.encode("utf-8")).hexdigest()
    cache_key = f"workflows:api_key_to_workspace:{digest}"
    if workspace_name := cache.get(cache_key):
        return workspace_name
    workspace_name = get_roboflow_workspace(api_key=api_key)
    cache.set(key=cache_key, value=workspace_name, expire=900)
    return workspace_name
| 280 | + |
| 281 | + |
def send_to_model_monitoring_request(
    cache: BaseCache,
    last_report_time_cache_key: str,
    api_key: str,
    predictions: List[ParsedPrediction],
) -> Tuple[bool, str]:
    """Send an aggregated predictions sample to Roboflow Model Monitoring.

    Designed to run as a fire-and-forget background task, so it never
    raises - all failures are converted into the returned status tuple.

    Returns:
        Tuple of (error_status, message); error_status is True on failure.
    """
    try:
        # Workspace resolution may hit the Roboflow API - it must live inside
        # the try block so its failures are reported as an error tuple instead
        # of escaping an un-awaited background task.
        workspace_id = get_workspace_name(api_key=api_key, cache=cache)
        inference_data = {
            "timestamp": datetime.now().isoformat(),
            "source": "workflow",
            "source_info": "ModelMonitoringInferenceAggregatorBlockV1",
            "inference_results": [],
            "device_id": DEVICE_ID,
        }
        system_info = get_system_info()
        if system_info:
            for key, value in system_info.items():
                inference_data[key] = value
        # Assigned after the system-info merge so predictions always win.
        inference_data["inference_results"] = [p.model_dump() for p in predictions]
        send_inference_results_to_model_monitoring(
            api_key, workspace_id, inference_data
        )
        # Record a successful report so frequency gating restarts from now.
        cache.set(last_report_time_cache_key, datetime.now().isoformat())
        return (
            False,
            "Data sent successfully",
        )
    except Exception as error:
        logging.warning(f"Could not upload inference data. Reason: {error}")
        return (
            True,
            f"Error while uploading inference data. Error type: {type(error)}. Details: {error}",
        )
| 316 | + |
| 317 | + |
def format_predictions_for_model_monitoring(
    predictions: Union[sv.Detections, dict],
) -> List[ParsedPrediction]:
    """Flatten a raw prediction payload into `ParsedPrediction` records.

    Supports `sv.Detections` (detection / segmentation / keypoint outputs)
    and dict payloads whose "predictions" entry is either a list (one record
    per class) or a dict keyed by class name. Unknown shapes yield [].
    """
    parsed: List[ParsedPrediction] = []
    if isinstance(predictions, sv.Detections):
        # Iterating Detections yields per-detection tuples; only the
        # confidence and the data dict are needed here.
        for _, _, confidence, _, _, data in predictions:
            parsed.append(
                ParsedPrediction(
                    class_name=data.get("class_name", ""),
                    confidence=0.0 if confidence is None else confidence,
                    inference_id=data.get(INFERENCE_ID_KEY, ""),
                    model_type=data.get(PREDICTION_TYPE_KEY, ""),
                )
            )
        return parsed
    if isinstance(predictions, dict):
        raw_entries = predictions.get("predictions", [])
        prediction_type = predictions.get(PREDICTION_TYPE_KEY, "")
        inference_id = predictions.get(INFERENCE_ID_KEY, "")
        if isinstance(raw_entries, list):
            parsed.extend(
                ParsedPrediction(
                    class_name=entry.get(CLASS_NAME_KEY, ""),
                    confidence=entry.get("confidence", 0.0),
                    inference_id=inference_id,
                    model_type=prediction_type,
                )
                for entry in raw_entries
            )
        elif isinstance(raw_entries, dict):
            parsed.extend(
                ParsedPrediction(
                    class_name=name,
                    confidence=details.get("confidence", 0.0),
                    inference_id=inference_id,
                    model_type=prediction_type,
                )
                for name, details in raw_entries.items()
            )
    return parsed
0 commit comments