Skip to content

Commit

Permalink
Compute Client: allow passing in catalog client to retrieve results (…
Browse files Browse the repository at this point in the history
…#12167)

GitOrigin-RevId: cd906ed6dd3df467368e72294d69478a5621b2e6
  • Loading branch information
tkrause authored and Descartes Labs Build committed Sep 18, 2023
1 parent 2b0986f commit a0fc015
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 54 deletions.
9 changes: 8 additions & 1 deletion descarteslabs/core/common/client/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy
import inspect
import json
from typing import TYPE_CHECKING, Generic, Iterator, List, TypeVar, Union

Expand Down Expand Up @@ -88,10 +89,16 @@ def __iter__(self: AnySearch) -> Iterator[T]:
>>> search = Function.search().filter(Function.status == "success")
>>> list(search) # doctest: +SKIP
"""
accepts_client = (
"client" in inspect.signature(self._document.__init__).parameters
)
documents = self._client.iter_pages(self._url, params=self._serialize())

for document in documents:
yield self._document(**document, saved=True)
if accepts_client:
yield self._document(**document, client=self._client, saved=True)
else:
yield self._document(**document, saved=True)

def collect(self: AnySearch, **kwargs) -> Collection[T]:
"""
Expand Down
7 changes: 6 additions & 1 deletion descarteslabs/core/compute/compute_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,23 @@
from descarteslabs.auth import Auth
from descarteslabs.config import get_settings

from ..catalog import CatalogClient
from ..client.services.service import ApiService
from ..common.http.service import DefaultClientMixin


class ComputeClient(ApiService, DefaultClientMixin):
def __init__(self, url=None, auth=None, retries=None):
def __init__(self, url=None, auth=None, catalog_client=None, retries=None):
if auth is None:
auth = Auth.get_default_auth()

if catalog_client is None:
catalog_client = CatalogClient(auth=auth)

if url is None:
url = get_settings().compute_url

self.catalog_client = catalog_client
super().__init__(url, auth=auth, retries=retries)

def iter_log_lines(self, url: str, timestamps: bool = True) -> Iterator[str]:
Expand Down
54 changes: 28 additions & 26 deletions descarteslabs/core/compute/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def __init__(
maximum_concurrency: int = None,
timeout: int = None,
retry_count: int = None,
client: ComputeClient = None,
**extra,
): # check to see if we need more validation here (type conversions)
"""
Expand Down Expand Up @@ -241,6 +242,7 @@ def __init__(
Job <job id>: "pending"
"""

self._client = client or ComputeClient.get_default_client()
self._function = function
self._requirements = requirements
self._include_data = include_data
Expand Down Expand Up @@ -674,13 +676,16 @@ def _data_globs_to_paths(self) -> List[str]:
return data_files

@classmethod
def get(cls, id: str, **params):
def get(cls, id: str, client: ComputeClient = None, **params):
"""Get Function by id.
Parameters
----------
id : str
Id of function to get.
client: ComputeClient, None
If set, the result will be retrieved using the configured client.
Otherwise, the default client will be used.
include : List[str], optional
List of additional attributes to include in the response.
Allowed values are:
Expand All @@ -693,12 +698,14 @@ def get(cls, id: str, **params):
>>> fn = Function.get(<func_id>)
<Function name="test_name" image=test_image cpus=1 memory=16 maximum_concurrency=5 timeout=3 retries=1
"""
client = ComputeClient.get_default_client()
client = client or ComputeClient.get_default_client()
response = client.session.get(f"/functions/{id}", params=params)
return cls(**response.json(), saved=True)
return cls(**response.json(), client=client, saved=True)

@classmethod
def list(cls, page_size: int = 100, **params) -> Search["Function"]:
def list(
cls, page_size: int = 100, client: ComputeClient = None, **params
) -> Search["Function"]:
"""Lists all Functions for a user.
If you would like to filter Functions, use :py:meth:`Function.search`.
Expand All @@ -707,6 +714,9 @@ def list(cls, page_size: int = 100, **params) -> Search["Function"]:
----------
page_size : int, default=100
Maximum number of results per page.
client: ComputeClient, None
If set, the result will be retrieved using the configured client.
Otherwise, the default client will be used.
include : List[str], optional
List of additional attributes to include in the response.
Allowed values are:
Expand All @@ -719,10 +729,10 @@ def list(cls, page_size: int = 100, **params) -> Search["Function"]:
>>> fn = Function.list()
"""
params = {"page_size": page_size, **params}
return cls.search().param(**params)
return cls.search(client=client).param(**params)

@classmethod
def search(cls) -> Search["Function"]:
def search(cls, client: ComputeClient = None) -> Search["Function"]:
"""Creates a search for Functions.
The search is lazy and will be executed when the search is iterated over or
Expand All @@ -740,10 +750,11 @@ def search(cls) -> Search["Function"]:
... )
Collection([Function <fn-id1>: building, Function <fn-id2>: awaiting_bundle])
"""
return Search(Function, ComputeClient.get_default_client(), url="/functions")
client = client or ComputeClient.get_default_client()
return Search(Function, client, url="/functions")

@classmethod
def update_credentials(cls):
def update_credentials(cls, client: ComputeClient = None):
"""Updates the credentials for the Functions and Jobs run by this user.
These credentials are used by other Descarteslabs services.
Expand All @@ -755,18 +766,17 @@ def update_credentials(cls):
-----
Credentials are automatically updated when a new Function is created.
"""
client = ComputeClient.get_default_client()
client = client or ComputeClient.get_default_client()
client.set_credentials()

@property
def jobs(self) -> JobSearch:
"""Returns all the Jobs for the Function."""
return Job.search().filter(Job.function_id == self.id)
return Job.search(client=self._client).filter(Job.function_id == self.id)

def build_log(self):
"""Retrieves the build log for the Function."""
client = ComputeClient.get_default_client()
response = client.session.get(f"/functions/{self.id}/log")
response = self._client.session.get(f"/functions/{self.id}/log")

print(gzip.decompress(response.content).decode())

Expand All @@ -775,12 +785,10 @@ def delete(self):
if self.state == DocumentState.NEW:
raise ValueError("Cannot delete a Function that has not been saved")

client = ComputeClient.get_default_client()

for job in self.jobs:
job.delete()

client.session.delete(f"/functions/{self.id}")
self._client.session.delete(f"/functions/{self.id}")
self._deleted = True

def save(self):
Expand Down Expand Up @@ -821,13 +829,11 @@ def save(self):
# Document already exists on the server without changes locally
return

client = ComputeClient.get_default_client()

if self.state == DocumentState.NEW:
self.update_credentials()

code_bundle_path = self._bundle()
response = client.session.post(
response = self._client.session.post(
"/functions", json=self.to_dict(exclude_none=True)
)
response_json = response.json()
Expand All @@ -843,10 +849,10 @@ def save(self):
s3_client.session.put(upload_url, data=code_bundle, headers=headers)

# Complete the upload with compute
response = client.session.post(f"/functions/{self.id}/bundle")
response = self._client.session.post(f"/functions/{self.id}/bundle")
self._load_from_remote(response.json())
elif self.state == DocumentState.MODIFIED:
response = client.session.patch(
response = self._client.session.patch(
f"/functions/{self.id}", json=self.to_dict(only_modified=True)
)
self._load_from_remote(response.json())
Expand Down Expand Up @@ -896,8 +902,6 @@ def map(
tags : List[str], optional
A list of tags to apply to all jobs submitted.
"""
client = ComputeClient.get_default_client()

# save in case the function doesn't exist yet
self.save()

Expand All @@ -918,7 +922,7 @@ def map(
if tags:
payload["tags"] = tags

response = client.session.post("/jobs/bulk", json=payload)
response = self._client.session.post("/jobs/bulk", json=payload)
return [Job(**data, saved=True) for data in response.json()]

def rerun(self):
Expand All @@ -927,14 +931,12 @@ def rerun(self):

def refresh(self):
"""Updates the Function instance with data from the server."""
client = ComputeClient.get_default_client()

if self.job_statistics:
params = {"include": ["job.statistics"]}
else:
params = {}

response = client.session.get(f"/functions/{self.id}", params=params)
response = self._client.session.get(f"/functions/{self.id}", params=params)
self._load_from_remote(response.json())

def iter_results(self, cast_type: Type[Serializable] = None):
Expand Down
Loading

0 comments on commit a0fc015

Please sign in to comment.