Skip to content

Commit ab42ece

Browse files
author
Tim Huff
committed
adding DetectorsConfig model
1 parent 240f124 commit ab42ece

File tree

3 files changed

+146
-19
lines changed

3 files changed

+146
-19
lines changed

src/groundlight/edge/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
from .config import (
22
DEFAULT,
33
DISABLED,
4-
EDGE_WITH_ESCALATION,
4+
EDGE_ANSWERS_WITH_ESCALATION,
55
NO_CLOUD,
6+
DetectorsConfig,
67
DetectorConfig,
8+
EdgeEndpointConfig,
79
EdgeInferenceConfig,
810
GlobalConfig,
9-
RootEdgeConfig,
11+
InferenceConfig,
1012
)
1113

1214
__all__ = [
1315
"DEFAULT",
1416
"DISABLED",
15-
"EDGE_WITH_ESCALATION",
17+
"EDGE_ANSWERS_WITH_ESCALATION",
1618
"NO_CLOUD",
19+
"DetectorsConfig",
1720
"DetectorConfig",
21+
"EdgeEndpointConfig",
1822
"EdgeInferenceConfig",
1923
"GlobalConfig",
20-
"RootEdgeConfig",
24+
"InferenceConfig",
2125
]

src/groundlight/edge/config.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Optional, Union
22

33
from model import Detector
44
from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -16,7 +16,7 @@ class GlobalConfig(BaseModel):
1616
)
1717

1818

19-
class EdgeInferenceConfig(BaseModel):
19+
class InferenceConfig(BaseModel):
2020
"""
2121
Configuration for edge inference on a specific detector.
2222
"""
@@ -27,7 +27,7 @@ class EdgeInferenceConfig(BaseModel):
2727
enabled: bool = Field( # TODO investigate and update the functionality of this option
2828
default=True, description="Whether the edge endpoint should accept image queries for this detector."
2929
)
30-
api_token: str | None = Field(
30+
api_token: Union[str, None] = Field(
3131
default=None, description="API token used to fetch the inference model for this detector."
3232
)
3333
always_return_edge_prediction: bool = Field(
@@ -74,13 +74,12 @@ class DetectorConfig(BaseModel):
7474
edge_inference_config: str = Field(..., description="Config for edge inference.")
7575

7676

77-
class RootEdgeConfig(BaseModel):
77+
class DetectorsConfig(BaseModel):
7878
"""
79-
Root configuration for edge inference.
79+
Detector and inference-config mappings for edge inference.
8080
"""
8181

82-
global_config: GlobalConfig = Field(default_factory=GlobalConfig)
83-
edge_inference_configs: dict[str, EdgeInferenceConfig] = Field(default_factory=dict)
82+
edge_inference_configs: dict[str, InferenceConfig] = Field(default_factory=dict)
8483
detectors: list[DetectorConfig] = Field(default_factory=list)
8584

8685
@model_validator(mode="after")
@@ -91,17 +90,17 @@ def validate_inference_configs(self):
9190
return self
9291

9392
def add_detector(
94-
self, detector: Union[str, Detector], edge_inference_config: Union[str, EdgeInferenceConfig]
93+
self, detector: Union[str, Detector], edge_inference_config: Union[str, InferenceConfig]
9594
) -> None:
9695
detector_id = detector.id if isinstance(detector, Detector) else detector
9796
if any(d.detector_id == detector_id for d in self.detectors):
9897
raise ValueError(f"A detector with ID '{detector_id}' already exists.")
99-
if isinstance(edge_inference_config, EdgeInferenceConfig):
98+
if isinstance(edge_inference_config, InferenceConfig):
10099
config = edge_inference_config
101100
existing = self.edge_inference_configs.get(config.name)
102101
if existing is None:
103102
self.edge_inference_configs[config.name] = config
104-
elif existing is not config:
103+
elif existing != config:
105104
raise ValueError(f"A different inference config named '{config.name}' is already registered.")
106105
config_name = config.name
107106
else:
@@ -119,16 +118,59 @@ def add_detector(
119118
)
120119

121120

121+
class EdgeEndpointConfig(BaseModel):
122+
"""
123+
Top-level edge endpoint configuration.
124+
"""
125+
126+
global_config: GlobalConfig = Field(default_factory=GlobalConfig)
127+
edge_inference_configs: dict[str, InferenceConfig] = Field(default_factory=dict)
128+
detectors: list[DetectorConfig] = Field(default_factory=list)
129+
130+
@model_validator(mode="after")
131+
def validate_inference_configs(self):
132+
DetectorsConfig(
133+
edge_inference_configs=self.edge_inference_configs,
134+
detectors=self.detectors,
135+
)
136+
return self
137+
138+
def add_detector(
139+
self, detector: Union[str, Detector], edge_inference_config: Union[str, InferenceConfig]
140+
) -> None:
141+
detectors_config = DetectorsConfig(
142+
edge_inference_configs=self.edge_inference_configs,
143+
detectors=self.detectors,
144+
)
145+
detectors_config.add_detector(detector, edge_inference_config)
146+
self.edge_inference_configs = detectors_config.edge_inference_configs
147+
self.detectors = detectors_config.detectors
148+
149+
@classmethod
150+
def from_detectors_config(
151+
cls, detectors_config: "DetectorsConfig", global_config: Optional[GlobalConfig] = None
152+
) -> "EdgeEndpointConfig":
153+
copied_config = detectors_config.model_copy(deep=True)
154+
return cls(
155+
global_config=global_config or GlobalConfig(),
156+
edge_inference_configs=copied_config.edge_inference_configs,
157+
detectors=copied_config.detectors,
158+
)
159+
160+
161+
EdgeInferenceConfig = InferenceConfig
162+
163+
122164
# Preset inference configs matching the standard edge-endpoint defaults.
123-
DEFAULT = EdgeInferenceConfig(name="default")
124-
EDGE_WITH_ESCALATION = EdgeInferenceConfig(
125-
name="edge_with_escalation",
165+
DEFAULT = InferenceConfig(name="default")
166+
EDGE_ANSWERS_WITH_ESCALATION = InferenceConfig(
167+
name="edge_answers_with_escalation",
126168
always_return_edge_prediction=True,
127169
min_time_between_escalations=2.0,
128170
)
129-
NO_CLOUD = EdgeInferenceConfig(
171+
NO_CLOUD = InferenceConfig(
130172
name="no_cloud",
131173
always_return_edge_prediction=True,
132174
disable_cloud_escalation=True,
133175
)
134-
DISABLED = EdgeInferenceConfig(name="disabled", enabled=False)
176+
DISABLED = InferenceConfig(name="disabled", enabled=False)

test/unit/test_edge_config.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import pytest
2+
3+
from groundlight.edge import (
4+
DEFAULT,
5+
EDGE_ANSWERS_WITH_ESCALATION,
6+
DetectorsConfig,
7+
EdgeEndpointConfig,
8+
InferenceConfig,
9+
NO_CLOUD,
10+
)
11+
12+
13+
def test_edge_endpoint_config_is_not_subclass_of_detectors_config():
14+
assert not issubclass(EdgeEndpointConfig, DetectorsConfig)
15+
16+
17+
def test_add_detector_allows_equivalent_named_inference_config():
18+
detectors_config = DetectorsConfig()
19+
detectors_config.add_detector(
20+
"det_1",
21+
InferenceConfig(
22+
name="custom_config",
23+
always_return_edge_prediction=True,
24+
min_time_between_escalations=0.5,
25+
),
26+
)
27+
detectors_config.add_detector(
28+
"det_2",
29+
InferenceConfig(
30+
name="custom_config",
31+
always_return_edge_prediction=True,
32+
min_time_between_escalations=0.5,
33+
),
34+
)
35+
36+
assert len(detectors_config.detectors) == 2
37+
assert list(detectors_config.edge_inference_configs.keys()) == ["custom_config"]
38+
39+
40+
def test_add_detector_rejects_different_named_inference_config():
41+
detectors_config = DetectorsConfig()
42+
detectors_config.add_detector("det_1", InferenceConfig(name="custom_config"))
43+
44+
with pytest.raises(ValueError, match="different inference config named 'custom_config'"):
45+
detectors_config.add_detector(
46+
"det_2",
47+
InferenceConfig(name="custom_config", always_return_edge_prediction=True),
48+
)
49+
50+
51+
def test_edge_endpoint_config_add_detector_delegates_to_detectors_logic():
52+
config = EdgeEndpointConfig()
53+
config.add_detector("det_1", NO_CLOUD)
54+
config.add_detector("det_2", EDGE_ANSWERS_WITH_ESCALATION)
55+
config.add_detector("det_3", DEFAULT)
56+
57+
assert [detector.detector_id for detector in config.detectors] == ["det_1", "det_2", "det_3"]
58+
assert set(config.edge_inference_configs.keys()) == {"no_cloud", "edge_answers_with_escalation", "default"}
59+
60+
61+
def test_from_detectors_config_copies_detector_data():
62+
detectors_config = DetectorsConfig()
63+
detectors_config.add_detector("det_1", DEFAULT)
64+
65+
config = EdgeEndpointConfig.from_detectors_config(detectors_config)
66+
detectors_config.add_detector("det_2", DEFAULT)
67+
68+
assert len(config.detectors) == 1
69+
assert len(detectors_config.detectors) == 2
70+
71+
72+
def test_inference_config_validation_errors():
73+
with pytest.raises(ValueError, match="disable_cloud_escalation"):
74+
InferenceConfig(name="bad", disable_cloud_escalation=True)
75+
76+
with pytest.raises(ValueError, match="cannot be less than 0.0"):
77+
InferenceConfig(
78+
name="bad_escalation_interval",
79+
always_return_edge_prediction=True,
80+
min_time_between_escalations=-1.0,
81+
)

0 commit comments

Comments
 (0)