Skip to content

Commit f25780c

Browse files
committed
feat: add PrimingGroup and MLPipeline SDK methods
New Groundlight client methods: - list_detector_pipelines(detector) -> List[MLPipeline] - list_priming_groups() -> List[PrimingGroup] - create_priming_group(name, source_ml_pipeline_id, canonical_query, disable_shadow_pipelines) -> PrimingGroup - get_priming_group(priming_group_id) -> PrimingGroup - delete_priming_group(priming_group_id) New pydantic models in generated/model.py: MLPipeline, PrimingGroup, PaginatedMLPipelineList, PaginatedPrimingGroupList. PrimingGroups let users seed new detectors with a pre-trained model binary so they skip the cold-start period. Detectors created with priming_group_id (already supported in create_detector) will start with the primed model active.
1 parent 588406e commit f25780c

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

generated/model.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,3 +566,56 @@ class PaginatedRuleList(BaseModel):
566566
next: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=4"])
567567
previous: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=2"])
568568
results: List[Rule]
569+
570+
571+
class MLPipeline(BaseModel):
572+
"""
573+
An ML pipeline attached to a detector. Contains the pipeline configuration and model binary key.
574+
"""
575+
576+
id: str = Field(..., description="A unique ID for this pipeline.")
577+
pipeline_config: Optional[str] = Field(None, description="Pipeline configuration string.")
578+
cached_vizlogic_key: Optional[str] = Field(None, description="S3 key of the trained model binary.")
579+
is_active_pipeline: bool = Field(False, description="Whether this is the active (production) pipeline.")
580+
is_edge_pipeline: bool = Field(False, description="Whether this is an edge pipeline.")
581+
is_unclear_pipeline: bool = Field(False, description="Whether this is an unclear-handling pipeline.")
582+
is_oodd_pipeline: bool = Field(False, description="Whether this is an out-of-distribution detection pipeline.")
583+
is_enabled: bool = Field(True, description="Whether this pipeline is enabled.")
584+
created_at: Optional[datetime] = None
585+
trained_at: Optional[datetime] = None
586+
587+
588+
class PaginatedMLPipelineList(BaseModel):
589+
count: int = Field(..., examples=[123])
590+
next: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=4"])
591+
previous: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=2"])
592+
results: List[MLPipeline]
593+
594+
595+
class PrimingGroup(BaseModel):
596+
"""
597+
A PrimingGroup seeds new detectors with a pre-trained model binary so they start with a head start.
598+
"""
599+
600+
id: str = Field(..., description="A unique ID for this priming group.")
601+
name: str = Field(..., description="A short, descriptive name for the priming group.")
602+
canonical_query: Optional[str] = Field(None, description="Optional canonical query describing this priming group.")
603+
active_pipeline_config: Optional[str] = Field(None, description="Pipeline config used by detectors in this group.")
604+
active_pipeline_base_mlbinary_key: Optional[str] = Field(
605+
None, description="S3 key of the model binary that seeds new detectors in this group."
606+
)
607+
disable_shadow_pipelines: bool = Field(
608+
False,
609+
description=(
610+
"If True, new detectors in this group will not receive default shadow pipelines, "
611+
"guaranteeing the primed model stays active."
612+
),
613+
)
614+
created_at: Optional[datetime] = None
615+
616+
617+
class PaginatedPrimingGroupList(BaseModel):
618+
count: int = Field(..., examples=[123])
619+
next: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=4"])
620+
previous: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=2"])
621+
results: List[PrimingGroup]

src/groundlight/client.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from io import BufferedReader, BytesIO
88
from typing import Any, Callable, List, Optional, Tuple, Union
99

10+
import requests
11+
1012
from groundlight_openapi_client import Configuration
1113
from groundlight_openapi_client.api.detector_groups_api import DetectorGroupsApi
1214
from groundlight_openapi_client.api.detectors_api import DetectorsApi
@@ -33,9 +35,11 @@
3335
Detector,
3436
DetectorGroup,
3537
ImageQuery,
38+
MLPipeline,
3639
ModeEnum,
3740
PaginatedDetectorList,
3841
PaginatedImageQueryList,
42+
PrimingGroup,
3943
)
4044
from urllib3.exceptions import InsecureRequestWarning
4145
from urllib3.util.retry import Retry
@@ -1852,3 +1856,154 @@ def create_bounding_box_detector( # noqa: PLR0913 # pylint: disable=too-many-ar
18521856
detector_creation_input.mode_configuration = mode_config
18531857
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
18541858
return Detector.parse_obj(obj.to_dict())
1859+
1860+
# ---------------------------------------------------------------------------
1861+
# ML Pipeline methods
1862+
# ---------------------------------------------------------------------------
1863+
1864+
def list_detector_pipelines(self, detector: Union[str, Detector]) -> List[MLPipeline]:
1865+
"""
1866+
Lists all ML pipelines associated with a given detector.
1867+
1868+
Each detector can have multiple pipelines (active, edge, shadow, etc.). This method returns
1869+
all of them, which is useful when selecting a source pipeline to seed a new PrimingGroup.
1870+
1871+
**Example usage**::
1872+
1873+
gl = Groundlight()
1874+
detector = gl.get_detector("det_abc123")
1875+
pipelines = gl.list_detector_pipelines(detector)
1876+
for p in pipelines:
1877+
if p.is_active_pipeline:
1878+
print(f"Active pipeline: {p.id}, config={p.pipeline_config}")
1879+
1880+
:param detector: A Detector object or detector ID string.
1881+
:return: A list of MLPipeline objects for this detector.
1882+
"""
1883+
detector_id = detector.id if isinstance(detector, Detector) else detector
1884+
url = f"{self.api_client.configuration.host}/v1/detectors/{detector_id}/pipelines"
1885+
response = requests.get(url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl)
1886+
if response.status_code == 404:
1887+
raise NotFoundError(f"Detector '{detector_id}' not found.")
1888+
response.raise_for_status()
1889+
data = response.json()
1890+
return [MLPipeline(**item) for item in data.get("results", [])]
1891+
1892+
# ---------------------------------------------------------------------------
1893+
# PrimingGroup methods
1894+
# ---------------------------------------------------------------------------
1895+
1896+
def list_priming_groups(self) -> List[PrimingGroup]:
1897+
"""
1898+
Lists all PrimingGroups owned by the authenticated user's account.
1899+
1900+
PrimingGroups let you seed new detectors with a pre-trained model so they start with a
1901+
meaningful head start instead of a blank slate.
1902+
1903+
**Example usage**::
1904+
1905+
gl = Groundlight()
1906+
groups = gl.list_priming_groups()
1907+
for g in groups:
1908+
print(f"{g.name}: {g.id}")
1909+
1910+
:return: A list of PrimingGroup objects.
1911+
"""
1912+
url = f"{self.api_client.configuration.host}/v1/priming-groups"
1913+
response = requests.get(url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl)
1914+
response.raise_for_status()
1915+
data = response.json()
1916+
return [PrimingGroup(**item) for item in data.get("results", [])]
1917+
1918+
def create_priming_group(
1919+
self,
1920+
name: str,
1921+
source_ml_pipeline_id: str,
1922+
canonical_query: Optional[str] = None,
1923+
disable_shadow_pipelines: bool = False,
1924+
) -> PrimingGroup:
1925+
"""
1926+
Creates a new PrimingGroup seeded from an existing ML pipeline.
1927+
1928+
The trained model binary from the source pipeline is copied into the new PrimingGroup.
1929+
Detectors subsequently created with this PrimingGroup's ID will start with that model
1930+
already loaded, bypassing the cold-start period.
1931+
1932+
**Example usage**::
1933+
1934+
gl = Groundlight()
1935+
detector = gl.get_detector("det_abc123")
1936+
pipelines = gl.list_detector_pipelines(detector)
1937+
active = next(p for p in pipelines if p.is_active_pipeline)
1938+
1939+
priming_group = gl.create_priming_group(
1940+
name="door-detector-primer",
1941+
source_ml_pipeline_id=active.id,
1942+
canonical_query="Is the door open?",
1943+
disable_shadow_pipelines=True,
1944+
)
1945+
print(f"Created priming group: {priming_group.id}")
1946+
1947+
:param name: A short, descriptive name for the priming group.
1948+
:param source_ml_pipeline_id: The ID of an MLPipeline whose trained model will seed this group.
1949+
The pipeline must belong to a detector in your account.
1950+
:param canonical_query: An optional description of the visual question this group answers.
1951+
:param disable_shadow_pipelines: If True, detectors created in this group will not receive
1952+
default shadow pipelines, ensuring the primed model stays active.
1953+
:return: The created PrimingGroup object.
1954+
"""
1955+
url = f"{self.api_client.configuration.host}/v1/priming-groups"
1956+
payload: dict = {
1957+
"name": name,
1958+
"source_ml_pipeline_id": source_ml_pipeline_id,
1959+
"disable_shadow_pipelines": disable_shadow_pipelines,
1960+
}
1961+
if canonical_query is not None:
1962+
payload["canonical_query"] = canonical_query
1963+
response = requests.post(
1964+
url, json=payload, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl
1965+
)
1966+
response.raise_for_status()
1967+
return PrimingGroup(**response.json())
1968+
1969+
def get_priming_group(self, priming_group_id: str) -> PrimingGroup:
1970+
"""
1971+
Retrieves a PrimingGroup by ID.
1972+
1973+
**Example usage**::
1974+
1975+
gl = Groundlight()
1976+
pg = gl.get_priming_group("pgp_abc123")
1977+
print(f"Priming group name: {pg.name}")
1978+
1979+
:param priming_group_id: The ID of the PrimingGroup to retrieve.
1980+
:return: The PrimingGroup object.
1981+
"""
1982+
url = f"{self.api_client.configuration.host}/v1/priming-groups/{priming_group_id}"
1983+
response = requests.get(url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl)
1984+
if response.status_code == 404:
1985+
raise NotFoundError(f"PrimingGroup '{priming_group_id}' not found.")
1986+
response.raise_for_status()
1987+
return PrimingGroup(**response.json())
1988+
1989+
def delete_priming_group(self, priming_group_id: str) -> None:
1990+
"""
1991+
Deletes (soft-deletes) a PrimingGroup owned by the authenticated user.
1992+
1993+
This does not delete any detectors that were created using this priming group —
1994+
it only removes the priming group itself. Detectors already created remain unaffected.
1995+
1996+
**Example usage**::
1997+
1998+
gl = Groundlight()
1999+
gl.delete_priming_group("pgp_abc123")
2000+
2001+
:param priming_group_id: The ID of the PrimingGroup to delete.
2002+
"""
2003+
url = f"{self.api_client.configuration.host}/v1/priming-groups/{priming_group_id}"
2004+
response = requests.delete(
2005+
url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl
2006+
)
2007+
if response.status_code == 404:
2008+
raise NotFoundError(f"PrimingGroup '{priming_group_id}' not found.")
2009+
response.raise_for_status()

0 commit comments

Comments
 (0)