|
7 | 7 | from io import BufferedReader, BytesIO |
8 | 8 | from typing import Any, Callable, List, Optional, Tuple, Union |
9 | 9 |
|
| 10 | +import requests |
| 11 | + |
10 | 12 | from groundlight_openapi_client import Configuration |
11 | 13 | from groundlight_openapi_client.api.detector_groups_api import DetectorGroupsApi |
12 | 14 | from groundlight_openapi_client.api.detectors_api import DetectorsApi |
|
33 | 35 | Detector, |
34 | 36 | DetectorGroup, |
35 | 37 | ImageQuery, |
| 38 | + MLPipeline, |
36 | 39 | ModeEnum, |
37 | 40 | PaginatedDetectorList, |
38 | 41 | PaginatedImageQueryList, |
| 42 | + PrimingGroup, |
39 | 43 | ) |
40 | 44 | from urllib3.exceptions import InsecureRequestWarning |
41 | 45 | from urllib3.util.retry import Retry |
@@ -1852,3 +1856,154 @@ def create_bounding_box_detector( # noqa: PLR0913 # pylint: disable=too-many-ar |
1852 | 1856 | detector_creation_input.mode_configuration = mode_config |
1853 | 1857 | obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT) |
1854 | 1858 | return Detector.parse_obj(obj.to_dict()) |
| 1859 | + |
| 1860 | + # --------------------------------------------------------------------------- |
| 1861 | + # ML Pipeline methods |
| 1862 | + # --------------------------------------------------------------------------- |
| 1863 | + |
| 1864 | + def list_detector_pipelines(self, detector: Union[str, Detector]) -> List[MLPipeline]: |
| 1865 | + """ |
| 1866 | + Lists all ML pipelines associated with a given detector. |
| 1867 | +
|
| 1868 | + Each detector can have multiple pipelines (active, edge, shadow, etc.). This method returns |
| 1869 | + all of them, which is useful when selecting a source pipeline to seed a new PrimingGroup. |
| 1870 | +
|
| 1871 | + **Example usage**:: |
| 1872 | +
|
| 1873 | + gl = Groundlight() |
| 1874 | + detector = gl.get_detector("det_abc123") |
| 1875 | + pipelines = gl.list_detector_pipelines(detector) |
| 1876 | + for p in pipelines: |
| 1877 | + if p.is_active_pipeline: |
| 1878 | + print(f"Active pipeline: {p.id}, config={p.pipeline_config}") |
| 1879 | +
|
| 1880 | + :param detector: A Detector object or detector ID string. |
| 1881 | + :return: A list of MLPipeline objects for this detector. |
| 1882 | + """ |
| 1883 | + detector_id = detector.id if isinstance(detector, Detector) else detector |
| 1884 | + url = f"{self.api_client.configuration.host}/v1/detectors/{detector_id}/pipelines" |
| 1885 | + response = requests.get(url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl) |
| 1886 | + if response.status_code == 404: |
| 1887 | + raise NotFoundError(f"Detector '{detector_id}' not found.") |
| 1888 | + response.raise_for_status() |
| 1889 | + data = response.json() |
| 1890 | + return [MLPipeline(**item) for item in data.get("results", [])] |
| 1891 | + |
| 1892 | + # --------------------------------------------------------------------------- |
| 1893 | + # PrimingGroup methods |
| 1894 | + # --------------------------------------------------------------------------- |
| 1895 | + |
| 1896 | + def list_priming_groups(self) -> List[PrimingGroup]: |
| 1897 | + """ |
| 1898 | + Lists all PrimingGroups owned by the authenticated user's account. |
| 1899 | +
|
| 1900 | + PrimingGroups let you seed new detectors with a pre-trained model so they start with a |
| 1901 | + meaningful head start instead of a blank slate. |
| 1902 | +
|
| 1903 | + **Example usage**:: |
| 1904 | +
|
| 1905 | + gl = Groundlight() |
| 1906 | + groups = gl.list_priming_groups() |
| 1907 | + for g in groups: |
| 1908 | + print(f"{g.name}: {g.id}") |
| 1909 | +
|
| 1910 | + :return: A list of PrimingGroup objects. |
| 1911 | + """ |
| 1912 | + url = f"{self.api_client.configuration.host}/v1/priming-groups" |
| 1913 | + response = requests.get(url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl) |
| 1914 | + response.raise_for_status() |
| 1915 | + data = response.json() |
| 1916 | + return [PrimingGroup(**item) for item in data.get("results", [])] |
| 1917 | + |
| 1918 | + def create_priming_group( |
| 1919 | + self, |
| 1920 | + name: str, |
| 1921 | + source_ml_pipeline_id: str, |
| 1922 | + canonical_query: Optional[str] = None, |
| 1923 | + disable_shadow_pipelines: bool = False, |
| 1924 | + ) -> PrimingGroup: |
| 1925 | + """ |
| 1926 | + Creates a new PrimingGroup seeded from an existing ML pipeline. |
| 1927 | +
|
| 1928 | + The trained model binary from the source pipeline is copied into the new PrimingGroup. |
| 1929 | + Detectors subsequently created with this PrimingGroup's ID will start with that model |
| 1930 | + already loaded, bypassing the cold-start period. |
| 1931 | +
|
| 1932 | + **Example usage**:: |
| 1933 | +
|
| 1934 | + gl = Groundlight() |
| 1935 | + detector = gl.get_detector("det_abc123") |
| 1936 | + pipelines = gl.list_detector_pipelines(detector) |
| 1937 | + active = next(p for p in pipelines if p.is_active_pipeline) |
| 1938 | +
|
| 1939 | + priming_group = gl.create_priming_group( |
| 1940 | + name="door-detector-primer", |
| 1941 | + source_ml_pipeline_id=active.id, |
| 1942 | + canonical_query="Is the door open?", |
| 1943 | + disable_shadow_pipelines=True, |
| 1944 | + ) |
| 1945 | + print(f"Created priming group: {priming_group.id}") |
| 1946 | +
|
| 1947 | + :param name: A short, descriptive name for the priming group. |
| 1948 | + :param source_ml_pipeline_id: The ID of an MLPipeline whose trained model will seed this group. |
| 1949 | + The pipeline must belong to a detector in your account. |
| 1950 | + :param canonical_query: An optional description of the visual question this group answers. |
| 1951 | + :param disable_shadow_pipelines: If True, detectors created in this group will not receive |
| 1952 | + default shadow pipelines, ensuring the primed model stays active. |
| 1953 | + :return: The created PrimingGroup object. |
| 1954 | + """ |
| 1955 | + url = f"{self.api_client.configuration.host}/v1/priming-groups" |
| 1956 | + payload: dict = { |
| 1957 | + "name": name, |
| 1958 | + "source_ml_pipeline_id": source_ml_pipeline_id, |
| 1959 | + "disable_shadow_pipelines": disable_shadow_pipelines, |
| 1960 | + } |
| 1961 | + if canonical_query is not None: |
| 1962 | + payload["canonical_query"] = canonical_query |
| 1963 | + response = requests.post( |
| 1964 | + url, json=payload, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl |
| 1965 | + ) |
| 1966 | + response.raise_for_status() |
| 1967 | + return PrimingGroup(**response.json()) |
| 1968 | + |
| 1969 | + def get_priming_group(self, priming_group_id: str) -> PrimingGroup: |
| 1970 | + """ |
| 1971 | + Retrieves a PrimingGroup by ID. |
| 1972 | +
|
| 1973 | + **Example usage**:: |
| 1974 | +
|
| 1975 | + gl = Groundlight() |
| 1976 | + pg = gl.get_priming_group("pgp_abc123") |
| 1977 | + print(f"Priming group name: {pg.name}") |
| 1978 | +
|
| 1979 | + :param priming_group_id: The ID of the PrimingGroup to retrieve. |
| 1980 | + :return: The PrimingGroup object. |
| 1981 | + """ |
| 1982 | + url = f"{self.api_client.configuration.host}/v1/priming-groups/{priming_group_id}" |
| 1983 | + response = requests.get(url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl) |
| 1984 | + if response.status_code == 404: |
| 1985 | + raise NotFoundError(f"PrimingGroup '{priming_group_id}' not found.") |
| 1986 | + response.raise_for_status() |
| 1987 | + return PrimingGroup(**response.json()) |
| 1988 | + |
| 1989 | + def delete_priming_group(self, priming_group_id: str) -> None: |
| 1990 | + """ |
| 1991 | + Deletes (soft-deletes) a PrimingGroup owned by the authenticated user. |
| 1992 | +
|
| 1993 | + This does not delete any detectors that were created using this priming group — |
| 1994 | + it only removes the priming group itself. Detectors already created remain unaffected. |
| 1995 | +
|
| 1996 | + **Example usage**:: |
| 1997 | +
|
| 1998 | + gl = Groundlight() |
| 1999 | + gl.delete_priming_group("pgp_abc123") |
| 2000 | +
|
| 2001 | + :param priming_group_id: The ID of the PrimingGroup to delete. |
| 2002 | + """ |
| 2003 | + url = f"{self.api_client.configuration.host}/v1/priming-groups/{priming_group_id}" |
| 2004 | + response = requests.delete( |
| 2005 | + url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl |
| 2006 | + ) |
| 2007 | + if response.status_code == 404: |
| 2008 | + raise NotFoundError(f"PrimingGroup '{priming_group_id}' not found.") |
| 2009 | + response.raise_for_status() |
0 commit comments