Skip to content

Commit

Permalink
Merge branch 'main' into feature/inference_cli_with_workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
PawelPeczek-Roboflow authored Nov 28, 2024
2 parents d9f4359 + ff98987 commit fdac101
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
3 changes: 2 additions & 1 deletion inference/core/entities/requests/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


class WorkflowInferenceRequest(BaseModel):
api_key: str = Field(
api_key: Optional[str] = Field(
default=None,
description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
)
inputs: Dict[str, Any] = Field(
Expand Down
33 changes: 22 additions & 11 deletions inference/core/roboflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,15 @@ def get_roboflow_labeling_jobs(


def get_workflow_cache_file(
workspace_id: WorkspaceID, workflow_id: str, api_key: str
workspace_id: WorkspaceID, workflow_id: str, api_key: Optional[str]
) -> str:
sanitized_workspace_id = sanitize_path_segment(workspace_id)
sanitized_workflow_id = sanitize_path_segment(workflow_id)
api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest()
api_key_hash = (
hashlib.md5(api_key.encode("utf-8")).hexdigest()
if api_key is not None
else "None"
)
prefix = os.path.abspath(os.path.join(MODEL_CACHE_DIR, "workflow"))
result = os.path.abspath(
os.path.join(
Expand All @@ -414,7 +418,7 @@ def get_workflow_cache_file(


def cache_workflow_response(
workspace_id: WorkspaceID, workflow_id: str, api_key: str, response: dict
workspace_id: WorkspaceID, workflow_id: str, api_key: Optional[str], response: dict
):
workflow_cache_file = get_workflow_cache_file(
workspace_id=workspace_id,
Expand All @@ -431,7 +435,7 @@ def cache_workflow_response(
def delete_cached_workflow_response_if_exists(
workspace_id: WorkspaceID,
workflow_id: str,
api_key: str,
api_key: Optional[str],
) -> None:
workflow_cache_file = get_workflow_cache_file(
workspace_id=workspace_id,
Expand All @@ -445,7 +449,7 @@ def delete_cached_workflow_response_if_exists(
def load_cached_workflow_response(
workspace_id: WorkspaceID,
workflow_id: str,
api_key: str,
api_key: Optional[str],
) -> Optional[dict]:
workflow_cache_file = get_workflow_cache_file(
workspace_id=workspace_id,
Expand All @@ -467,7 +471,7 @@ def load_cached_workflow_response(

@wrap_roboflow_api_errors()
def get_workflow_specification(
api_key: str,
api_key: Optional[str],
workspace_id: WorkspaceID,
workflow_id: str,
use_cache: bool = True,
Expand All @@ -483,9 +487,12 @@ def get_workflow_specification(
)
if cached_entry:
return cached_entry
params = []
if api_key is not None:
params.append(("api_key", api_key))
api_url = _add_params_to_url(
url=f"{API_BASE_URL}/{workspace_id}/workflows/{workflow_id}",
params=[("api_key", api_key)],
params=params,
)
try:
response = _get_from_url(url=api_url)
Expand Down Expand Up @@ -533,7 +540,7 @@ def get_workflow_specification(


def _retrieve_workflow_specification_from_ephemeral_cache(
api_key: str,
api_key: Optional[str],
workspace_id: WorkspaceID,
workflow_id: str,
ephemeral_cache: BaseCache,
Expand All @@ -547,7 +554,7 @@ def _retrieve_workflow_specification_from_ephemeral_cache(


def _cache_workflow_specification_in_ephemeral_cache(
api_key: str,
api_key: Optional[str],
workspace_id: WorkspaceID,
workflow_id: str,
specification: dict,
Expand All @@ -566,11 +573,15 @@ def _cache_workflow_specification_in_ephemeral_cache(


def _prepare_workflow_response_cache_key(
api_key: str,
api_key: Optional[str],
workspace_id: WorkspaceID,
workflow_id: str,
) -> str:
api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest()
api_key_hash = (
hashlib.md5(api_key.encode("utf-8")).hexdigest()
if api_key is not None
else "None"
)
return f"workflow_definition:{workspace_id}:{workflow_id}:{api_key_hash}"


Expand Down

0 comments on commit fdac101

Please sign in to comment.