fix: ensure SDK generation succeeds (#260)
This commit ensures that no Optional types are used in the FastAPI models. This change is necessary because the oapi-codegen tool (see issue #373) and Speakeasy do not yet support Optional types.
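
For illustration only, here is a minimal sketch (not code from this repository) contrasting the removed nullable-tuple field with its replacement, using a Chunk-like model; the printed schemas assume Pydantic v2 and its model_json_schema helper:

from typing import Optional, Tuple

from pydantic import BaseModel, Field


class ChunkOld(BaseModel):
    # Old shape: nullable tuple members are rendered as anyOf [number, 'null']
    # entries in the OpenAPI schema, which the SDK generators cannot consume.
    timestamp: Tuple[Optional[float], Optional[float]] = Field(
        ..., description="The timestamp of the chunk."
    )


class ChunkNew(BaseModel):
    # New shape: an untyped tuple is rendered as a plain array (items: {}).
    timestamp: Tuple = Field(..., description="The timestamp of the chunk.")


print(ChunkOld.model_json_schema()["properties"]["timestamp"])
print(ChunkNew.model_json_schema()["properties"]["timestamp"])
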
rickstaa authored Nov 6, 2024
1 parent acf9b15 commit efae373
Showing 5 changed files with 89 additions and 102 deletions.
18 changes: 9 additions & 9 deletions runner/app/pipelines/audio_to_text.py
@@ -29,10 +29,11 @@ def list(cls):
         return [model.value for model in cls]
 
     @classmethod
-    def get(cls, model_id: str) -> Enum | None:
-        """Return the enum or None if the model ID is not found."""
+    def from_value(cls, value: str) -> Enum | None:
+        """Return the enum member corresponding to the given value, or None if not
+        found."""
         try:
-            return cls(model_id)
+            return cls(value)
         except ValueError:
             return None
 
@@ -45,6 +46,7 @@ class ModelConfig:
         torch.float16 if torch.cuda.is_available() else torch.float32
     )
     chunk_length_s: int = 30
+    batch_size: int = 16
 
 
 MODEL_CONFIGS = {
@@ -63,10 +65,8 @@ def __init__(self, model_id: str):
         torch_device = get_torch_device()
 
         # Get model specific configuration parameters.
-        model_enum = ModelName.get(model_id)
-        self._model_cfg = (
-            ModelConfig() if model_enum is None else MODEL_CONFIGS[model_enum]
-        )
+        model_enum = ModelName.from_value(model_id)
+        self._model_cfg: ModelConfig = MODEL_CONFIGS.get(model_enum, ModelConfig())
         kwargs["torch_dtype"] = self._model_cfg.torch_dtype
         logger.info(
             "AudioToText loading '%s' on device '%s' with '%s' variant",
@@ -110,7 +110,7 @@ def __call__(self, audio: UploadFile, duration: float, **kwargs) -> List[File]:
 
         # Adjust batch size and chunk length based on timestamps and duration.
         # NOTE: Done to prevent CUDA OOM errors for large audio files.
-        kwargs["batch_size"] = 16
+        kwargs["batch_size"] = self._model_cfg.batch_size
         kwargs["chunk_length_s"] = self._model_cfg.chunk_length_s
         if kwargs["return_timestamps"] == "word":
             if duration > 3600:
@@ -123,7 +123,7 @@ def __call__(self, audio: UploadFile, duration: float, **kwargs) -> List[File]:
             kwargs.pop("batch_size", None)
             kwargs.pop("chunk_length_s", None)
             inference_mode = "sequential"
-        else:
+        else:
             inference_mode = f"chunked (batch_size={kwargs['batch_size']}, chunk_length_s={kwargs['chunk_length_s']})"
         logger.info(
             f"AudioToTextPipeline: Starting inference mode={inference_mode} with duration={duration}"
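
Below, a rough usage sketch of the refactored lookup in this file; the enum member and tuned values are placeholders rather than the actual MODEL_CONFIGS entries:

from dataclasses import dataclass, field
from enum import Enum

import torch


class ModelName(Enum):
    WHISPER_LARGE_V3 = "openai/whisper-large-v3"  # placeholder member

    @classmethod
    def from_value(cls, value: str) -> "ModelName | None":
        """Return the enum member corresponding to the given value, or None."""
        try:
            return cls(value)
        except ValueError:
            return None


@dataclass
class ModelConfig:
    torch_dtype: torch.dtype = field(
        default_factory=lambda: torch.float16
        if torch.cuda.is_available()
        else torch.float32
    )
    chunk_length_s: int = 30
    batch_size: int = 16


MODEL_CONFIGS = {ModelName.WHISPER_LARGE_V3: ModelConfig(batch_size=8)}

# A known model ID resolves to its tuned config; anything else (including the
# None returned by from_value) falls back to the defaults via dict.get.
known = MODEL_CONFIGS.get(ModelName.from_value("openai/whisper-large-v3"), ModelConfig())
unknown = MODEL_CONFIGS.get(ModelName.from_value("some/unknown-model"), ModelConfig())
print(known.batch_size, unknown.batch_size)  # 8 16
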
13 changes: 8 additions & 5 deletions runner/app/routes/utils.py
@@ -61,9 +61,7 @@ class MasksResponse(BaseModel):
 class Chunk(BaseModel):
     """A chunk of text with a timestamp."""
 
-    timestamp: Tuple[Optional[float], Optional[float]] = Field(
-        ..., description="The timestamp of the chunk."
-    )
+    timestamp: Tuple = Field(..., description="The timestamp of the chunk.")
     text: str = Field(..., description="The text of the chunk.")
 
 
@@ -84,11 +82,16 @@ class ImageToTextResponse(BaseModel):
 
     text: str = Field(..., description="The generated text.")
 
+
 class LiveVideoToVideoResponse(BaseModel):
     """Response model for live video-to-video generation."""
 
-    subscribe_url: str = Field(..., description="Source URL of the incoming stream to subscribe to")
-    publish_url: str = Field(..., description="Destination URL of the outgoing stream to publish to")
+    subscribe_url: str = Field(
+        ..., description="Source URL of the incoming stream to subscribe to"
+    )
+    publish_url: str = Field(
+        ..., description="Destination URL of the outgoing stream to publish to"
+    )
 
 
 class APIError(BaseModel):
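
A behavioural note on the Chunk change above, shown as an assumed check rather than repository code: dropping Optional only changes the published schema; a bare Tuple still validates a trailing None end-timestamp at runtime, as word-level timestamps can produce:

from typing import Tuple

from pydantic import BaseModel, Field


class Chunk(BaseModel):
    """A chunk of text with a timestamp."""

    timestamp: Tuple = Field(..., description="The timestamp of the chunk.")
    text: str = Field(..., description="The text of the chunk.")


# A (start, None) pair still validates, even though the schema no longer
# declares nullable members (assumes Pydantic v2).
chunk = Chunk(timestamp=(12.3, None), text="hello")
print(chunk.model_dump())
print(Chunk.model_json_schema()["properties"]["timestamp"])  # no anyOf / 'null'
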
10 changes: 1 addition & 9 deletions runner/gateway.openapi.yaml
@@ -843,16 +843,8 @@ components:
     Chunk:
       properties:
         timestamp:
-          prefixItems:
-          - anyOf:
-            - type: number
-            - type: 'null'
-          - anyOf:
-            - type: number
-            - type: 'null'
+          items: {}
           type: array
-          maxItems: 2
-          minItems: 2
           title: Timestamp
           description: The timestamp of the chunk.
         text:
10 changes: 1 addition & 9 deletions runner/openapi.yaml
@@ -852,16 +852,8 @@ components:
     Chunk:
      properties:
         timestamp:
-          prefixItems:
-          - anyOf:
-            - type: number
-            - type: 'null'
-          - anyOf:
-            - type: number
-            - type: 'null'
+          items: {}
           type: array
-          maxItems: 2
-          minItems: 2
           title: Timestamp
           description: The timestamp of the chunk.
         text:
140 changes: 70 additions & 70 deletions worker/runner.gen.go

Some generated files are not rendered by default.
