Skip to content

Commit

Permalink
Add unique aggregator key
Browse files: browse the repository at this point in the history
  • Loading branch information
robiscoding committed Nov 21, 2024
1 parent 38cee36 commit 4417c2f
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from typing import Any, List, Literal, Optional, Tuple, Type, Union

import supervision as sv
from fastapi import BackgroundTasks
from pydantic import ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from inference.core.cache.base import BaseCache
from inference.core.env import DEVICE_ID
Expand All @@ -18,6 +17,11 @@
get_roboflow_workspace,
send_inference_results_to_model_monitoring,
)
from inference.core.workflows.execution_engine.constants import (
CLASS_NAME_KEY,
INFERENCE_ID_KEY,
PREDICTION_TYPE_KEY,
)
from inference.core.workflows.execution_engine.entities.base import OutputDefinition
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
Expand All @@ -39,7 +43,8 @@
LONG_DESCRIPTION = """
This block periodically reports an aggregated sample of inference results to Roboflow Model Monitoring.
It aggregates predictions in memory between reports and then sends a representative sample of predictions at a regular interval specified by the `frequency` parameter.
It aggregates predictions in memory between reports and then sends a representative sample of predictions at a regular interval specified by the `frequency` parameter. It
creates a representative sample by selecting the prediction with the highest confidence for each class predicted.
This is particularly useful when using InferencePipeline, which doesn't automatically report results to Model Monitoring.
Expand Down Expand Up @@ -78,6 +83,11 @@ class BlockManifest(WorkflowBlockManifest):
description="Frequency of reporting (in seconds). For example, if 5 is provided, the block will report an aggregated sample of predictions every 5 seconds.",
examples=["3", "5"],
)
unique_aggregator_key: str = Field(
description="Unique key used internally to track the session of inference results reporting.",
examples=["session-1v73kdhfse"],
json_schema_extra={"hidden": True},
)
fire_and_forget: Union[bool, Selector(kind=[BOOLEAN_KIND])] = Field(
default=True,
description="Boolean flag dictating if sink is supposed to be executed in the background, "
Expand Down Expand Up @@ -105,8 +115,7 @@ def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.3.0,<2.0.0"


@dataclass
class Prediction(object):
class ParsedPrediction(BaseModel):
class_name: str
confidence: float
inference_id: str
Expand All @@ -122,11 +131,10 @@ def __init__(self):
def collect(self, value: Union[sv.Detections, dict]) -> None:
self._raw_predictions.append(value)

def _consolidate(self) -> List[Prediction]:
def _consolidate(self) -> List[ParsedPrediction]:
formatted_predictions = []
for detections in self._raw_predictions:
f = format_sv_detections_for_model_monitoring(detections)
formatted_predictions.extend(f)
for p in self._raw_predictions:
formatted_predictions.extend(format_predictions_for_model_monitoring(p))

class_groups = defaultdict(list)
for prediction in formatted_predictions:
Expand All @@ -136,11 +144,11 @@ def _consolidate(self) -> List[Prediction]:
representative_predictions = []
for class_name, predictions in class_groups.items():
predictions.sort(key=lambda x: x["confidence"], reverse=True)
representative_predictions.append(predictions[0])
representative_predictions.append(ParsedPrediction(**predictions[0]))

return representative_predictions

def get_and_flush(self) -> List[Prediction]:
def get_and_flush(self) -> List[ParsedPrediction]:
predictions = self._consolidate()
self._raw_predictions = []
return predictions
Expand All @@ -166,7 +174,6 @@ def __init__(
self._cache = cache
self._background_tasks = background_tasks
self._thread_pool_executor = thread_pool_executor
self._last_report_time_cache_key = "roboflow_model_monitoring_last_report_time"
self._predictions_aggregator = PredictionsAggregator()

@classmethod
Expand All @@ -182,7 +189,9 @@ def run(
fire_and_forget: bool,
predictions: Union[sv.Detections, dict],
frequency: int,
unique_aggregator_key: str,
) -> BlockResult:
self._last_report_time_cache_key = f"workflows:steps_cache:roboflow_core/model_monitoring_inference_aggregator@v1:{unique_aggregator_key}:last_report_time"
if predictions:
self._predictions_aggregator.collect(predictions)
if not self._is_in_reporting_range(frequency):
Expand Down Expand Up @@ -217,10 +226,12 @@ def _is_in_reporting_range(self, frequency: int) -> bool:
last_report_time_str = self._cache.get(self._last_report_time_cache_key)
if last_report_time_str is None:
self._cache.set(self._last_report_time_cache_key, now.isoformat())
v = self._cache.get(self._last_report_time_cache_key)
last_report_time = now
else:
last_report_time = datetime.fromisoformat(last_report_time_str)
time_elapsed = int((now - last_report_time).total_seconds())

return time_elapsed >= int(frequency)


Expand All @@ -243,7 +254,7 @@ def send_to_model_monitoring_request(
cache: BaseCache,
last_report_time_cache_key: str,
api_key: str,
predictions: List[Prediction],
predictions: List[ParsedPrediction],
) -> Tuple[bool, str]:
workspace_id = get_workspace_name(api_key=api_key, cache=cache)
try:
Expand All @@ -259,7 +270,7 @@ def send_to_model_monitoring_request(
if system_info:
for key, value in system_info.items():
inference_data[key] = value
inference_data["inference_results"] = predictions
inference_data["inference_results"] = [p.model_dump() for p in predictions]
send_inference_results_to_model_monitoring(
api_key, workspace_id, inference_data
)
Expand All @@ -276,42 +287,40 @@ def send_to_model_monitoring_request(
)


def format_sv_detections_for_model_monitoring(
detections: Union[sv.Detections, dict],
) -> List[Prediction]:
def format_predictions_for_model_monitoring(
predictions: Union[sv.Detections, dict],
) -> List[ParsedPrediction]:
results = []
if isinstance(detections, sv.Detections):
num_detections = len(detections.data.get("detection_id", []))
for i in range(num_detections):
prediction = Prediction(
class_name=detections.data.get("class_name", [""])[i],
confidence=(
detections.confidence[i]
if detections.confidence is not None
else 0.0
),
inference_id=detections.data.get("inference_id", [""])[i],
model_type=detections.data.get("prediction_type", [""])[i],
if isinstance(predictions, sv.Detections):
for detection in predictions:
_, _, confidence, _, _, data = detection
prediction = ParsedPrediction(
class_name=data.get("class_name", ""),
confidence=(confidence if confidence is not None else 0.0),
inference_id=data.get(INFERENCE_ID_KEY, ""),
model_type=data.get(PREDICTION_TYPE_KEY, ""),
)
results.append(prediction.__dict__)
elif isinstance(detections, dict):
predictions = detections.get("predictions", [])
if isinstance(predictions, list):
for prediction in predictions:
pred_instance = Prediction(
class_name=prediction.get("class", ""),
confidence=prediction.get("confidence", 0.0),
inference_id=detections.get("inference_id", ""),
model_type=detections.get("prediction_type", ""),
elif isinstance(predictions, dict):
detections = predictions.get("predictions", [])
prediction_type = predictions.get(PREDICTION_TYPE_KEY, "")
inference_id = predictions.get(INFERENCE_ID_KEY, "")
if isinstance(detections, list):
for d in detections:
pred_instance = ParsedPrediction(
class_name=d.get(CLASS_NAME_KEY, ""),
confidence=d.get("confidence", 0.0),
inference_id=inference_id,
model_type=prediction_type,
)
results.append(pred_instance.__dict__)
elif isinstance(predictions, dict):
for class_name, details in predictions.items():
pred_instance = Prediction(
elif isinstance(detections, dict):
for class_name, details in detections.items():
pred_instance = ParsedPrediction(
class_name=class_name,
confidence=details.get("confidence", 0.0),
inference_id=detections.get("inference_id", ""),
model_type=detections.get("prediction_type", ""),
inference_id=inference_id,
model_type=prediction_type,
)
results.append(pred_instance.__dict__)
return results
Loading

0 comments on commit 4417c2f

Please sign in to comment.