1- from typing import Union
1+ from typing import Optional , Union
22
33from model import Detector
44from pydantic import BaseModel , ConfigDict , Field , model_validator
@@ -16,7 +16,7 @@ class GlobalConfig(BaseModel):
1616 )
1717
1818
19- class EdgeInferenceConfig (BaseModel ):
19+ class InferenceConfig (BaseModel ):
2020 """
2121 Configuration for edge inference on a specific detector.
2222 """
@@ -27,7 +27,7 @@ class EdgeInferenceConfig(BaseModel):
2727 enabled : bool = Field ( # TODO investigate and update the functionality of this option
2828 default = True , description = "Whether the edge endpoint should accept image queries for this detector."
2929 )
30- api_token : str | None = Field (
30+ api_token : Union [ str , None ] = Field (
3131 default = None , description = "API token used to fetch the inference model for this detector."
3232 )
3333 always_return_edge_prediction : bool = Field (
@@ -74,13 +74,12 @@ class DetectorConfig(BaseModel):
7474 edge_inference_config : str = Field (..., description = "Config for edge inference." )
7575
7676
77- class RootEdgeConfig (BaseModel ):
77+ class DetectorsConfig (BaseModel ):
7878 """
79- Root configuration for edge inference.
79+ Detector and inference-config mappings for edge inference.
8080 """
8181
82- global_config : GlobalConfig = Field (default_factory = GlobalConfig )
83- edge_inference_configs : dict [str , EdgeInferenceConfig ] = Field (default_factory = dict )
82+ edge_inference_configs : dict [str , InferenceConfig ] = Field (default_factory = dict )
8483 detectors : list [DetectorConfig ] = Field (default_factory = list )
8584
8685 @model_validator (mode = "after" )
@@ -91,17 +90,17 @@ def validate_inference_configs(self):
9190 return self
9291
9392 def add_detector (
94- self , detector : Union [str , Detector ], edge_inference_config : Union [str , EdgeInferenceConfig ]
93+ self , detector : Union [str , Detector ], edge_inference_config : Union [str , InferenceConfig ]
9594 ) -> None :
9695 detector_id = detector .id if isinstance (detector , Detector ) else detector
9796 if any (d .detector_id == detector_id for d in self .detectors ):
9897 raise ValueError (f"A detector with ID '{ detector_id } ' already exists." )
99- if isinstance (edge_inference_config , EdgeInferenceConfig ):
98+ if isinstance (edge_inference_config , InferenceConfig ):
10099 config = edge_inference_config
101100 existing = self .edge_inference_configs .get (config .name )
102101 if existing is None :
103102 self .edge_inference_configs [config .name ] = config
104- elif existing is not config :
103+ elif existing != config :
105104 raise ValueError (f"A different inference config named '{ config .name } ' is already registered." )
106105 config_name = config .name
107106 else :
@@ -119,16 +118,59 @@ def add_detector(
119118 )
120119
121120
121+ class EdgeEndpointConfig (BaseModel ):
122+ """
123+ Top-level edge endpoint configuration.
124+ """
125+
126+ global_config : GlobalConfig = Field (default_factory = GlobalConfig )
127+ edge_inference_configs : dict [str , InferenceConfig ] = Field (default_factory = dict )
128+ detectors : list [DetectorConfig ] = Field (default_factory = list )
129+
130+ @model_validator (mode = "after" )
131+ def validate_inference_configs (self ):
132+ DetectorsConfig (
133+ edge_inference_configs = self .edge_inference_configs ,
134+ detectors = self .detectors ,
135+ )
136+ return self
137+
138+ def add_detector (
139+ self , detector : Union [str , Detector ], edge_inference_config : Union [str , InferenceConfig ]
140+ ) -> None :
141+ detectors_config = DetectorsConfig (
142+ edge_inference_configs = self .edge_inference_configs ,
143+ detectors = self .detectors ,
144+ )
145+ detectors_config .add_detector (detector , edge_inference_config )
146+ self .edge_inference_configs = detectors_config .edge_inference_configs
147+ self .detectors = detectors_config .detectors
148+
149+ @classmethod
150+ def from_detectors_config (
151+ cls , detectors_config : "DetectorsConfig" , global_config : Optional [GlobalConfig ] = None
152+ ) -> "EdgeEndpointConfig" :
153+ copied_config = detectors_config .model_copy (deep = True )
154+ return cls (
155+ global_config = global_config or GlobalConfig (),
156+ edge_inference_configs = copied_config .edge_inference_configs ,
157+ detectors = copied_config .detectors ,
158+ )
159+
160+
161+ EdgeInferenceConfig = InferenceConfig
162+
163+
122164# Preset inference configs matching the standard edge-endpoint defaults.
123- DEFAULT = EdgeInferenceConfig (name = "default" )
124- EDGE_WITH_ESCALATION = EdgeInferenceConfig (
125- name = "edge_with_escalation " ,
165+ DEFAULT = InferenceConfig (name = "default" )
166+ EDGE_ANSWERS_WITH_ESCALATION = InferenceConfig (
167+ name = "edge_answers_with_escalation " ,
126168 always_return_edge_prediction = True ,
127169 min_time_between_escalations = 2.0 ,
128170)
129- NO_CLOUD = EdgeInferenceConfig (
171+ NO_CLOUD = InferenceConfig (
130172 name = "no_cloud" ,
131173 always_return_edge_prediction = True ,
132174 disable_cloud_escalation = True ,
133175)
134- DISABLED = EdgeInferenceConfig (name = "disabled" , enabled = False )
176+ DISABLED = InferenceConfig (name = "disabled" , enabled = False )
0 commit comments