diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 21774e6507..e3dba112ba 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -151,9 +151,9 @@ jobs: - test-go - test-python - test-integration - if: startsWith(github.ref, 'refs/tags/') permissions: - contents: write + contents: read + id-token: write runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -169,9 +169,27 @@ jobs: - uses: actions/setup-go@v5 with: go-version-file: go.mod - - uses: goreleaser/goreleaser-action@v6 + - id: build + uses: goreleaser/goreleaser-action@v6 with: version: '~> v2' - args: release --clean - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + args: build --clean --snapshot --id cog + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + workload_identity_provider: 'projects/1025538909507/locations/global/workloadIdentityPools/github/providers/github-actions' + service_account: 'pipelines-beta-publish@replicate-production.iam.gserviceaccount.com' + - name: Upload release artifacts + uses: google-github-actions/upload-cloud-storage@v2 + with: + path: dist/go + destination: replicate-pipelines-beta/releases/${{ fromJSON(steps.build.outputs.metadata).version }} + parent: false + predefinedAcl: publicRead + - name: Upload release artifacts (latest) + uses: google-github-actions/upload-cloud-storage@v2 + with: + path: dist/go + destination: replicate-pipelines-beta/releases/latest + parent: false + predefinedAcl: publicRead diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index c3c325c54b..0f1eda0b83 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -5,6 +5,9 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "net/url" "os" "os/signal" "path/filepath" @@ -341,24 +344,65 @@ func writeOutput(outputPath string, output []byte) error { } func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error { - dataurlObj, err := dataurl.DecodeString(outputString) - if err != nil { - return fmt.Errorf("Failed to decode dataurl: %w", err) + var output []byte + var contentType string + + if httpURL, ok := getHTTPURL(outputString); ok { + resp, err := http.Get(httpURL.String()) + if err != nil { + return fmt.Errorf("Failed to fetch URL: %w", err) + } + defer resp.Body.Close() + + output, err = io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("Failed to read response: %w", err) + } + contentType = resp.Header.Get("Content-Type") + contentType = useExtensionIfUnknownContentType(contentType, output, outputString) + + } else { + dataurlObj, err := dataurl.DecodeString(outputString) + if err != nil { + return fmt.Errorf("Failed to decode dataurl: %w", err) + } + output = dataurlObj.Data + contentType = dataurlObj.ContentType() } - output := dataurlObj.Data if addExtension { - extension := mime.ExtensionByType(dataurlObj.ContentType()) - if extension != "" { - outputPath += extension + if ext := mime.ExtensionByType(contentType); ext != "" { + outputPath += ext } } - if err := writeOutput(outputPath, output); err != nil { - return err + return writeOutput(outputPath, output) +} + +func getHTTPURL(str string) (*url.URL, bool) { + u, err := url.Parse(str) + if err == nil && (u.Scheme == "http" || u.Scheme == "https") { + return u, true } + return nil, false +} - return nil +func useExtensionIfUnknownContentType(contentType string, content []byte, filename string) string { + // If contentType is empty or application/octet-string, first attempt to get the + // content type from the file extension, and if that fails, try to guess it from + // the content itself. + + if contentType == "" || contentType == "application/octet-stream" { + if ext := filepath.Ext(filename); ext != "" { + if mimeType := mime.TypeByExtension(ext); mimeType != "" { + return mimeType + } + } + if detected := http.DetectContentType(content); detected != "" { + return detected + } + } + return contentType } func parseInputFlags(inputs []string) (predict.Inputs, error) { diff --git a/pyproject.toml b/pyproject.toml index f07e978c26..9214565553 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,9 @@ dependencies = [ "structlog>=20,<25", "typing_extensions>=4.4.0", "uvicorn[standard]>=0.12,<1", + + # TODO(andreas): re-implement replicate functionality in pure python + "replicate>=1.0.4", ] dynamic = ["version"] diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 72f1399cd0..6c267be0b9 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from .base_predictor import BasePredictor +from .include import include from .mimetypes_ext import install_mime_extensions from .server.scope import current_scope, emit_metric from .types import ( @@ -36,4 +37,5 @@ "Input", "Path", "Secret", + "include", ] diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index 018c5bd6af..ebf56a929c 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -39,6 +39,10 @@ "title": "Output File Prefix", "type": "string" }, + "run_token": { + "title": "Run Token", + "type": "string" + }, "webhook": { "format": "uri", "maxLength": 65536, diff --git a/python/cog/include.py b/python/cog/include.py new file mode 100644 index 0000000000..cf3103f7d5 --- /dev/null +++ b/python/cog/include.py @@ -0,0 +1,99 @@ +import os +import sys +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Tuple + +import replicate +from replicate.exceptions import ModelError +from replicate.model import Model +from replicate.prediction import Prediction +from replicate.run import _has_output_iterator_array_type +from replicate.version import Version + +from cog.server.scope import current_scope + + +def _find_api_token() -> str: + token = os.environ.get("REPLICATE_API_TOKEN") + if token: + print("Using Replicate API token from environment", file=sys.stderr) + return token + + token = current_scope()._run_token + + if not token: + raise ValueError("No run token found") + + return token + + +@dataclass +class Run: + prediction: Prediction + version: Version + + def wait(self) -> Any: + self.prediction.wait() + + if self.prediction.status == "failed": + raise ModelError(self.prediction) + + if _has_output_iterator_array_type(self.version): + return "".join(self.prediction.output) + + return self.prediction.output + + def logs(self) -> Optional[str]: + self.prediction.reload() + + return self.prediction.logs + + +@dataclass +class Function: + function_ref: str + + def _client(self) -> replicate.Client: + return replicate.Client(api_token=_find_api_token()) + + def _split_function_ref(self) -> Tuple[str, str, Optional[str]]: + owner, name = self.function_ref.split("/") + name, version = name.split(":") if ":" in name else (name, None) + return owner, name, version + + def _model(self) -> Model: + client = self._client() + model_owner, model_name, _ = self._split_function_ref() + return client.models.get(f"{model_owner}/{model_name}") + + def _version(self) -> Version: + client = self._client() + model_owner, model_name, model_version = self._split_function_ref() + model = client.models.get(f"{model_owner}/{model_name}") + version = ( + model.versions.get(model_version) if model_version else model.latest_version + ) + return version + + def __call__(self, **inputs: Dict[str, Any]) -> Any: + run = self.start(**inputs) + return run.wait() + + def start(self, **inputs: Dict[str, Any]) -> Run: + version = self._version() + prediction = self._client().predictions.create(version=version, input=inputs) + print(f"Running {self.function_ref}: https://replicate.com/p/{prediction.id}") + + return Run(prediction, version) + + @property + def default_example(self) -> Optional[Prediction]: + return self._model().default_example + + @property + def openapi_schema(self) -> dict[Any, Any]: + return self._version().openapi_schema + + +def include(function_ref: str) -> Callable[..., Any]: + return Function(function_ref) diff --git a/python/cog/schema.py b/python/cog/schema.py index c8624da8bf..ae32f4ed3c 100644 --- a/python/cog/schema.py +++ b/python/cog/schema.py @@ -77,6 +77,8 @@ class PredictionRequest(PredictionBaseModel): default=WebhookEvent.default_events(), ) + run_token: Optional[str] = None + @classmethod def with_types(cls, input_type: Type[Any]) -> Any: # [compat] Input is implicitly optional -- previous versions of the diff --git a/python/cog/server/eventtypes.py b/python/cog/server/eventtypes.py index 93c7384142..c266f2b0c3 100644 --- a/python/cog/server/eventtypes.py +++ b/python/cog/server/eventtypes.py @@ -69,3 +69,4 @@ class Envelope: Done, ] tag: Optional[str] = None + run_token: Optional[str] = None diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index a911a705ef..1ee993984c 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -113,7 +113,9 @@ def predict( payload = prediction.input.copy() sid = self._worker.subscribe(task.handle_event, tag=tag) - task.track(self._worker.predict(payload, tag=tag)) + task.track( + self._worker.predict(payload, tag=tag, run_token=prediction.run_token) + ) task.add_done_callback(self._task_done_callback(tag, sid)) return task diff --git a/python/cog/server/scope.py b/python/cog/server/scope.py index d326cf4b46..158d3c70f3 100644 --- a/python/cog/server/scope.py +++ b/python/cog/server/scope.py @@ -11,6 +11,8 @@ @frozen class Scope: record_metric: Callable[[str, Union[float, int]], None] + + _run_token: Optional[str] = None _tag: Optional[str] = None diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 40f8ff0b5a..2e0e907ab6 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -135,7 +135,10 @@ def setup(self) -> "Future[Done]": return self._setup_result def predict( - self, payload: Dict[str, Any], tag: Optional[str] = None + self, + payload: Dict[str, Any], + tag: Optional[str] = None, + run_token: Optional[str] = None, ) -> "Future[Done]": # TODO: tag is Optional, but it's required when in concurrent mode and # basically unnecessary in sequential mode. Should we have a separate @@ -158,11 +161,17 @@ def predict( result = Future() self._predictions_in_flight[tag] = PredictionState(tag, payload, result) - self._prediction_start_pool.submit(self._start_prediction(tag, payload)) + self._prediction_start_pool.submit( + self._start_prediction(payload=payload, tag=tag, run_token=run_token) + ) return result def _start_prediction( - self, tag: Optional[str], payload: Dict[str, Any] + self, + *, + payload: Dict[str, Any], + tag: Optional[str] = None, + run_token: Optional[str] = None, ) -> Callable[[], None]: def start_prediction() -> None: try: @@ -215,6 +224,7 @@ def start_prediction() -> None: Envelope( event=PredictionInput(payload=payload), tag=tag, + run_token=run_token, ) ) except Exception as e: @@ -403,7 +413,12 @@ def __init__( self._cancelable = False self._max_concurrency = max_concurrency - # for synchronous predictors only! async predictors use current_scope()._tag instead + # For synchronous predictors only! async predictors use + # current_scope()._tag instead. + # + # This is unfortunately necessary because StreamRedirector executes + # another thread which doesn't have access to the scope defined in this + # thread when it executes the stream_write_hook. self._sync_tag: Optional[str] = None self._has_async_predictor = is_async @@ -584,7 +599,13 @@ def _loop( elif isinstance(e.event, Shutdown): break elif isinstance(e.event, PredictionInput): - self._predict(e.tag, e.event.payload, predict, redirector) + self._predict( + payload=e.event.payload, + predict=predict, + redirector=redirector, + tag=e.tag, + run_token=e.run_token, + ) else: print(f"Got unexpected event: {e.event}", file=sys.stderr) @@ -618,19 +639,30 @@ async def _aloop( break elif isinstance(e.event, PredictionInput): tasks[e.tag] = tg.create_task( - self._apredict(e.tag, e.event.payload, predict, redirector) + self._apredict( + payload=e.event.payload, + predict=predict, + redirector=redirector, + tag=e.tag, + run_token=e.run_token, + ) ) else: print(f"Got unexpected event: {e.event}", file=sys.stderr) def _predict( self, - tag: Optional[str], + *, payload: Dict[str, Any], predict: Callable[..., Any], redirector: StreamRedirector, + tag: Optional[str] = None, + run_token: Optional[str] = None, ) -> None: - with self._handle_predict_error(redirector, tag=tag): + with evolve_scope( + tag=tag, + run_token=run_token, + ), self._handle_predict_error(redirector, tag=tag): result = predict(**payload) if result: @@ -676,12 +708,20 @@ def _predict( async def _apredict( self, - tag: Optional[str], + *, payload: Dict[str, Any], predict: Callable[..., Any], redirector: SimpleStreamRedirector, + tag: Optional[str] = None, + run_token: Optional[str] = None, ) -> None: - with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag): + with evolve_scope( + tag=tag, + run_token=run_token, + ), self._handle_predict_error( + redirector, + tag=tag, + ): future_result = predict(**payload) if future_result: @@ -768,7 +808,7 @@ def _handle_setup_error( def _handle_predict_error( self, redirector: Union[SimpleStreamRedirector, StreamRedirector], - tag: Optional[str], + tag: Optional[str] = None, ) -> Iterator[None]: done = Done() send_done = True diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index 5e27357213..6a0aebc436 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -71,7 +71,7 @@ def run_setup(self, events): if isinstance(event, Done): self._setup_future.set_result(event) - def predict(self, payload, tag=None): + def predict(self, payload, tag=None, run_token=None): assert tag not in self._predict_futures or self._predict_futures[tag].done() self.last_prediction_payload = payload self._predict_futures[tag] = Future() diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index 6da50dea59..f92ba6ee7a 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -364,6 +364,7 @@ def test_cog_install_base_image(docker_image): ) +@pytest.mark.skip(reason="Not testing this on pipelines beta branch") def test_pip_freeze(docker_image): project_dir = Path(__file__).parent / "fixtures/path-project" subprocess.run(