Commit 3ca6205

philandstuff and aron authored
End-to-end support for concurrent async models (#2066)
This builds on the work in #2057 and wires it up end-to-end. We can now support async models with a configured max concurrency, and submit multiple predictions to them concurrently.

We only support Python 3.11 and later for async models; this is so that we can use asyncio.TaskGroup to keep track of multiple predictions in flight and ensure they all complete when shutting down.

The cog HTTP server was already async, but at one point it called wait() on a concurrent.futures.Future, which blocked the event loop and therefore prevented concurrent prediction requests (when not using prefer-async, which is how the tests run). I have updated this code to wait on asyncio.wrap_future(fut) instead, which does not block the event loop. As part of this I have updated the training endpoints to also be asynchronous.

We now have three places in the code that keep track of how many predictions are in flight: PredictionRunner, Worker, and _ChildWorker all do their own bookkeeping. I'm not sure this is the best design, but it works.

The code is now an uneasy mix of threaded and asyncio code. This is evident in the usage of threading.Lock, which wouldn't be needed if we were 100% async (and I'm not sure it's actually needed currently; I just added it to be safe).

Co-authored-by: Aron Carroll <[email protected]>
1 parent e181041 commit 3ca6205
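For context, here is a minimal, self-contained sketch (not taken from the cog source) of the asyncio.TaskGroup pattern the commit message describes, and the reason for the Python 3.11 floor: every in-flight prediction becomes a task, and leaving the TaskGroup block only happens once every task has completed or been cancelled together.

    import asyncio

    async def run_prediction(n: int) -> str:
        await asyncio.sleep(0.1)  # stands in for real prediction work
        return f"prediction-{n}"

    async def serve_until_shutdown() -> None:
        # asyncio.TaskGroup is new in Python 3.11, hence the version floor
        # for models with concurrency.max > 1.
        async with asyncio.TaskGroup() as tg:
            tasks = [tg.create_task(run_prediction(n)) for n in range(3)]
        # Exiting the block guarantees every task has finished, which is
        # the clean-shutdown property the commit relies on.
        print([task.result() for task in tasks])

    asyncio.run(serve_until_shutdown())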

File tree

19 files changed (+502, -172 lines)

pkg/cli/build.go

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ func buildCommand(cmd *cobra.Command, args []string) error {
 		imageName = config.DockerImageName(projectDir)
 	}

-	err = config.ValidateModelPythonVersion(cfg.Build.PythonVersion)
+	err = config.ValidateModelPythonVersion(cfg)
 	if err != nil {
 		return err
 	}

pkg/config/config.go

Lines changed: 20 additions & 8 deletions
@@ -30,9 +30,10 @@ var (
 // TODO(andreas): suggest valid torchvision versions (e.g. if the user wants to use 0.8.0, suggest 0.8.1)

 const (
-	MinimumMajorPythonVersion int = 3
-	MinimumMinorPythonVersion int = 8
-	MinimumMajorCudaVersion   int = 11
+	MinimumMajorPythonVersion               int = 3
+	MinimumMinorPythonVersion               int = 8
+	MinimumMinorPythonVersionForConcurrency int = 11
+	MinimumMajorCudaVersion                 int = 11
 )

 type RunItem struct {
@@ -58,16 +59,21 @@ type Build struct {
 	pythonRequirementsContent []string
 }

+type Concurrency struct {
+	Max int `json:"max,omitempty" yaml:"max"`
+}
+
 type Example struct {
 	Input  map[string]string `json:"input" yaml:"input"`
 	Output string            `json:"output" yaml:"output"`
 }

 type Config struct {
-	Build   *Build `json:"build" yaml:"build"`
-	Image   string `json:"image,omitempty" yaml:"image"`
-	Predict string `json:"predict,omitempty" yaml:"predict"`
-	Train   string `json:"train,omitempty" yaml:"train"`
+	Build       *Build       `json:"build" yaml:"build"`
+	Image       string       `json:"image,omitempty" yaml:"image"`
+	Predict     string       `json:"predict,omitempty" yaml:"predict"`
+	Train       string       `json:"train,omitempty" yaml:"train"`
+	Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"`
 }

 func DefaultConfig() *Config {
@@ -244,7 +250,9 @@ func splitPythonVersion(version string) (major int, minor int, err error) {
 	return major, minor, nil
 }

-func ValidateModelPythonVersion(version string) error {
+func ValidateModelPythonVersion(cfg *Config) error {
+	version := cfg.Build.PythonVersion
+
 	// we check for minimum supported here
 	major, minor, err := splitPythonVersion(version)
 	if err != nil {
@@ -255,6 +263,10 @@ func ValidateModelPythonVersion(version string) error {
 		return fmt.Errorf("minimum supported Python version is %d.%d. requested %s",
 			MinimumMajorPythonVersion, MinimumMinorPythonVersion, version)
 	}
+	if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency {
+		return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s",
+			MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version)
+	}
 	return nil
 }

pkg/config/config_test.go

Lines changed: 45 additions & 35 deletions
@@ -13,47 +13,68 @@ import (

 func TestValidateModelPythonVersion(t *testing.T) {
 	testCases := []struct {
-		name        string
-		input       string
-		expectedErr bool
+		name           string
+		pythonVersion  string
+		concurrencyMax int
+		expectedErr    string
 	}{
 		{
-			name:        "ValidVersion",
-			input:       "3.12",
-			expectedErr: false,
+			name:          "ValidVersion",
+			pythonVersion: "3.12",
 		},
 		{
-			name:        "MinimumVersion",
-			input:       "3.8",
-			expectedErr: false,
+			name:          "MinimumVersion",
+			pythonVersion: "3.8",
 		},
 		{
-			name:        "FullyQualifiedVersion",
-			input:       "3.12.1",
-			expectedErr: false,
+			name:           "MinimumVersionForConcurrency",
+			pythonVersion:  "3.11",
+			concurrencyMax: 5,
 		},
 		{
-			name:        "InvalidFormat",
-			input:       "3-12",
-			expectedErr: true,
+			name:           "TooOldForConcurrency",
+			pythonVersion:  "3.8",
+			concurrencyMax: 5,
+			expectedErr:    "when concurrency.max is set, minimum supported Python version is 3.11. requested 3.8",
 		},
 		{
-			name:        "InvalidMissingMinor",
-			input:       "3",
-			expectedErr: true,
+			name:          "FullyQualifiedVersion",
+			pythonVersion: "3.12.1",
 		},
 		{
-			name:        "LessThanMinimum",
-			input:       "3.7",
-			expectedErr: true,
+			name:          "InvalidFormat",
+			pythonVersion: "3-12",
+			expectedErr:   "invalid Python version format: missing minor version in 3-12",
+		},
+		{
+			name:          "InvalidMissingMinor",
+			pythonVersion: "3",
+			expectedErr:   "invalid Python version format: missing minor version in 3",
+		},
+		{
+			name:          "LessThanMinimum",
+			pythonVersion: "3.7",
+			expectedErr:   "minimum supported Python version is 3.8. requested 3.7",
 		},
 	}

 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			err := ValidateModelPythonVersion(tc.input)
-			if tc.expectedErr {
-				require.Error(t, err)
+			cfg := &Config{
+				Build: &Build{
+					PythonVersion: tc.pythonVersion,
+				},
+			}
+			if tc.concurrencyMax != 0 {
+				// the Concurrency key is optional, only populate it if
+				// concurrencyMax is a non-default value
+				cfg.Concurrency = &Concurrency{
+					Max: tc.concurrencyMax,
+				}
+			}
+			err := ValidateModelPythonVersion(cfg)
+			if tc.expectedErr != "" {
+				require.ErrorContains(t, err, tc.expectedErr)
 			} else {
 				require.NoError(t, err)
 			}
@@ -649,17 +670,6 @@ func TestBlankBuild(t *testing.T) {
 	require.Equal(t, false, config.Build.GPU)
 }

-func TestModelPythonVersionValidation(t *testing.T) {
-	err := ValidateModelPythonVersion("3.8")
-	require.NoError(t, err)
-	err = ValidateModelPythonVersion("3.8.1")
-	require.NoError(t, err)
-	err = ValidateModelPythonVersion("3.7")
-	require.Equal(t, "minimum supported Python version is 3.8. requested 3.7", err.Error())
-	err = ValidateModelPythonVersion("3.7.1")
-	require.Equal(t, "minimum supported Python version is 3.8. requested 3.7.1", err.Error())
-}
-
 func TestSplitPinnedPythonRequirement(t *testing.T) {
 	testCases := []struct {
 		input string

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@ tests = [
     "numpy",
     "pillow",
     "pytest",
+    "pytest-asyncio",
     "pytest-httpserver",
     "pytest-timeout",
     "pytest-xdist",
@@ -70,6 +71,9 @@ reportUnusedExpression = "warning"
 [tool.pyright.defineConstant]
 PYDANTIC_V2 = true

+[tool.pytest.ini_options]
+asyncio_default_fixture_loop_scope = "function"
+
 [tool.setuptools]
 include-package-data = false
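pytest-asyncio is what lets the suite define coroutine tests directly, and the asyncio_default_fixture_loop_scope option pins async fixtures to a fresh event loop per test function. A hypothetical test in that style (names are illustrative, not from the cog suite):

    import asyncio

    import pytest

    @pytest.mark.asyncio
    async def test_concurrent_predictions():
        async def fake_predict(n: int) -> int:
            await asyncio.sleep(0.01)  # simulate prediction latency
            return n * 2

        # Submit several "predictions" concurrently and gather the results.
        results = await asyncio.gather(*(fake_predict(n) for n in range(4)))
        assert results == [0, 2, 4, 6]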

python/cog/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 from .mimetypes_ext import install_mime_extensions
 from .server.scope import current_scope, emit_metric
 from .types import (
+    AsyncConcatenateIterator,
     ConcatenateIterator,
     ExperimentalFeatureWarning,
     File,
@@ -26,6 +27,7 @@
     "__version__",
     "current_scope",
     "emit_metric",
+    "AsyncConcatenateIterator",
     "BaseModel",
     "BasePredictor",
     "ConcatenateIterator",

python/cog/config.py

Lines changed: 7 additions & 0 deletions
@@ -33,6 +33,7 @@
 COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP"
 COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP"
 COG_GPU_ENV_VAR = "COG_GPU"
+COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY"
 PREDICT_METHOD_NAME = "predict"
 TRAIN_METHOD_NAME = "train"

@@ -101,6 +102,12 @@ def requires_gpu(self) -> bool:
         """Whether this cog requires the use of a GPU."""
         return bool(self._cog_config.get("build", {}).get("gpu", False))

+    @property
+    @env_property(COG_MAX_CONCURRENCY_ENV_VAR)
+    def max_concurrency(self) -> int:
+        """The maximum concurrency of predictions supported by this model. Defaults to 1."""
+        return int(self._cog_config.get("concurrency", {}).get("max", 1))
+
     def _predictor_code(
         self,
         module_path: str,
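The @env_property decorator suggests COG_MAX_CONCURRENCY overrides the value parsed from cog.yaml. A minimal sketch of that precedence, assuming the environment variable wins; the real decorator lives elsewhere in the codebase:

    import os
    from typing import Any, Dict

    def max_concurrency(cog_config: Dict[str, Any]) -> int:
        # Environment override first, then concurrency.max from the
        # parsed config, then the default of 1.
        env_value = os.environ.get("COG_MAX_CONCURRENCY")
        if env_value is not None:
            return int(env_value)
        return int(cog_config.get("concurrency", {}).get("max", 1))

    print(max_concurrency({"concurrency": {"max": 4}}))  # 4, unless the env var is set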

python/cog/server/helpers.py

Lines changed: 5 additions & 6 deletions
@@ -33,7 +33,7 @@ def __init__(
         callback: Callable[[str, str], None],
         tee: bool = False,
     ) -> None:
-        super().__init__(buffer, line_buffering=True)
+        super().__init__(buffer)

         self._callback = callback
         self._tee = tee
@@ -44,11 +44,10 @@ def write(self, s: str) -> int:
             self._buffer.append(s)
             if self._tee:
                 super().write(s)
-            else:
-                # If we're not teeing, we have to handle automatic flush on
-                # newline. When `tee` is true, this is handled by the write method.
-                if "\n" in s or "\r" in s:
-                    self.flush()
+
+        if "\n" in s or "\r" in s:
+            self.flush()
+
         return length

     def flush(self) -> None:
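This change makes the stream flush on "\n" or "\r" in both tee and non-tee modes, rather than relying on line_buffering in one of them. A simplified, self-contained sketch of that behaviour (not the real helper class):

    import io
    from typing import Callable, List

    class LineFlushingWriter(io.TextIOWrapper):
        def __init__(self, buffer: io.BufferedIOBase, callback: Callable[[str], None]) -> None:
            super().__init__(buffer)
            self._callback = callback
            self._pending: List[str] = []

        def write(self, s: str) -> int:
            self._pending.append(s)
            # Flush on newline or carriage return regardless of mode, so
            # output such as "\r"-based progress bars is delivered promptly.
            if "\n" in s or "\r" in s:
                self.flush()
            return len(s)

        def flush(self) -> None:
            if self._pending:
                self._callback("".join(self._pending))
                self._pending.clear()

    w = LineFlushingWriter(io.BytesIO(), print)
    w.write("progress: 50%\r")  # flushed immediately
    w.write("done\n")           # flushed immediately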

python/cog/server/http.py

Lines changed: 12 additions & 10 deletions
@@ -165,9 +165,11 @@ async def start_shutdown() -> Any:
         return app

     worker = make_worker(
-        predictor_ref=cog_config.get_predictor_ref(mode=mode), is_async=is_async
+        predictor_ref=cog_config.get_predictor_ref(mode=mode),
+        is_async=is_async,
+        max_concurrency=cog_config.max_concurrency,
     )
-    runner = PredictionRunner(worker=worker)
+    runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency)

     class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
         pass
@@ -219,7 +221,7 @@ class TrainingRequest(
         response_model=TrainingResponse,
         response_model_exclude_unset=True,
     )
-    def train(
+    async def train(
         request: TrainingRequest = Body(default=None),
         prefer: Optional[str] = Header(default=None),
         traceparent: Optional[str] = Header(
@@ -232,7 +234,7 @@ def train(
         respond_async = prefer == "respond-async"

         with trace_context(make_trace_context(traceparent, tracestate)):
-            return _predict(
+            return await _predict(
                 request=request,
                 response_type=TrainingResponse,
                 respond_async=respond_async,
@@ -243,7 +245,7 @@ def train(
         response_model=TrainingResponse,
         response_model_exclude_unset=True,
     )
-    def train_idempotent(
+    async def train_idempotent(
         training_id: str = Path(..., title="Training ID"),
         request: TrainingRequest = Body(..., title="Training Request"),
         prefer: Optional[str] = Header(default=None),
@@ -280,7 +282,7 @@ def train_idempotent(
         respond_async = prefer == "respond-async"

         with trace_context(make_trace_context(traceparent, tracestate)):
-            return _predict(
+            return await _predict(
                 request=request,
                 response_type=TrainingResponse,
                 respond_async=respond_async,
@@ -359,7 +361,7 @@ async def predict(
         respond_async = prefer == "respond-async"

         with trace_context(make_trace_context(traceparent, tracestate)):
-            return _predict(
+            return await _predict(
                 request=request,
                 response_type=PredictionResponse,
                 respond_async=respond_async,
@@ -407,13 +409,13 @@ async def predict_idempotent(
         respond_async = prefer == "respond-async"

         with trace_context(make_trace_context(traceparent, tracestate)):
-            return _predict(
+            return await _predict(
                 request=request,
                 response_type=PredictionResponse,
                 respond_async=respond_async,
             )

-    def _predict(
+    async def _predict(
         *,
         request: Optional[PredictionRequest],
         response_type: Type[schema.PredictionResponse],
@@ -455,7 +457,7 @@ def _predict(
         )

     # Otherwise, wait for the prediction to complete...
-    predict_task.wait()
+    await predict_task.wait_async()

     # ...and return the result.
     if PYDANTIC_V2:
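The endpoint conversions matter because awaiting lets the event loop keep serving other requests while a prediction runs; a synchronous wait() on a concurrent.futures.Future held it up. A standalone FastAPI sketch of the same idea — assuming wait_async() wraps the underlying future the way asyncio.wrap_future does; the endpoint and helper names are illustrative, not cog's actual server:

    import asyncio
    import concurrent.futures
    import time

    from fastapi import FastAPI

    app = FastAPI()
    pool = concurrent.futures.ThreadPoolExecutor()

    def slow_prediction() -> str:
        time.sleep(1.0)  # stands in for model inference in a worker thread
        return "done"

    @app.post("/predictions")
    async def predict() -> dict:
        fut = pool.submit(slow_prediction)
        # Wrapping the concurrent.futures.Future makes it awaitable, so
        # the event loop stays free to accept concurrent prediction
        # requests while this one runs.
        result = await asyncio.wrap_future(fut)
        return {"status": "succeeded", "output": result}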
