Skip to content

Commit

Permalink
style: apply formatting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
winstxnhdw committed Jun 3, 2024
1 parent 413ca25 commit 7700365
Show file tree
Hide file tree
Showing 11 changed files with 19 additions and 34 deletions.
3 changes: 1 addition & 2 deletions capgen/transcriber/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ class Converter:
as_srt(segments: Iterable[Segment]) -> str:
converts transcription segments into a SRT file
"""

def __init__(self, segments: Iterable[Segment]):
self.segments = segments


def to_srt(self, segments: Iterable[Segment]) -> str:
"""
Summary
Expand All @@ -48,7 +48,6 @@ def to_srt(self, segments: Iterable[Segment]) -> str:
for segment in segments
)


def to_vtt(self, segments: Iterable[Segment]) -> str:
"""
Summary
Expand Down
1 change: 1 addition & 0 deletions capgen/types/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Arguments(NamedTuple):
output (str) : the output file path
cuda (bool) : whether to use CUDA for inference
"""

file: str | BinaryIO
caption: Literal['srt', 'vtt']
output: str
Expand Down
1 change: 1 addition & 0 deletions capgen/types/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TranscriberOptions(TypedDict, total=False):
number_of_threads (int) : the number of CPU threads
number_of_workers (int) : the number of workers
"""

device: Literal['auto', 'cpu', 'cuda']
number_of_threads: int
number_of_workers: int
5 changes: 2 additions & 3 deletions server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Framework(FastAPI):
initialise_routes()
dynamically initialise all routes
"""

def convert_delimiters(self, string: str, old: str, new: str) -> str:
"""
Summary
Expand All @@ -47,7 +48,6 @@ def convert_delimiters(self, string: str, old: str, new: str) -> str:
"""
return new.join(string.split(old))


def initialise_routes(self, api_directory: str):
"""
Summary
Expand All @@ -66,8 +66,7 @@ def initialise_routes(self, api_directory: str):
]

module_names = [
import_module(self.convert_delimiters(file_name[:-3], sep, '.')).__name__
for file_name in module_file_names
import_module(self.convert_delimiters(file_name[:-3], sep, '.')).__name__ for file_name in module_file_names
]

for module_name in module_names:
Expand Down
2 changes: 1 addition & 1 deletion server/api/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from fastapi import APIRouter

v1 = APIRouter(prefix='/v1', tags=["v1"])
v1 = APIRouter(prefix='/v1', tags=['v1'])
1 change: 1 addition & 0 deletions server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Config(BaseSettings):
worker_count (int) : the number of workers to use
use_cuda (bool) : whether to use CUDA for inference
"""

server_port: int = 49494
server_root_path: str = '/api'
worker_count: int = 1
Expand Down
5 changes: 1 addition & 4 deletions server/lifespans/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,4 @@ async def load_model():
-------
download and load the model
"""
await get_running_loop().run_in_executor(
None,
Transcriber.load
)
await get_running_loop().run_in_executor(None, Transcriber.load)
10 changes: 4 additions & 6 deletions server/middlewares/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ class LoggingMiddleware:
logger (Logger) : a custom logger
app (ASGIApp) : the ASGI application
"""
def __init__(self, app: ASGIApp):

def __init__(self, app: ASGIApp):
getLogger('uvicorn.access').setLevel(WARN)
self.logger = getLogger('custom.access')
self.logger.setLevel(INFO)
self.logger.addHandler(StreamHandler())
self.app = app


async def inner_send(self, message: Message, send: Send, status_code: list[int]):
"""
Summary
Expand All @@ -42,7 +41,6 @@ async def inner_send(self, message: Message, send: Send, status_code: list[int])

await send(message)


def inner_send_factory(self, send: Send, status_code: list[int]) -> Callable[[Message], Awaitable[None]]:
"""
Summary
Expand All @@ -60,7 +58,6 @@ def inner_send_factory(self, send: Send, status_code: list[int]) -> Callable[[Me
"""
return lambda message: self.inner_send(message, send, status_code)


async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope['type'] != 'http':
return await self.app(scope, receive, send)
Expand All @@ -76,12 +73,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
await self.app(scope, receive, self.inner_send_factory(send, status_code))

finally:
self.logger.info('[%s] [INFO] %d "%s %s" %s "%s" in %.4f ms',
self.logger.info(
'[%s] [INFO] %d "%s %s" %s "%s" in %.4f ms',
strftime('%Y-%m-%d %H:%M:%S %z'),
status_code[0],
scope['method'],
scope['path'],
client_ip,
user_agent,
(process_time() - start_process_time) * 1000
(process_time() - start_process_time) * 1000,
)
7 changes: 2 additions & 5 deletions server/schemas/v1/transcribed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,5 @@ class Transcribed(BaseModel):
----------
result (str) : the transcribed text in the chosen caption file format
"""
result: str = Field(examples=[
'1\n'
'00:00:00,000 --> 00:00:02,000\n'
'Hello world.'
])

result: str = Field(examples=['1\n' '00:00:00,000 --> 00:00:02,000\n' 'Hello world.'])
6 changes: 3 additions & 3 deletions server/typings/starlette/types.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class HttpStartMessage(TypedDict):
status (int) : the response status code
headers (list[tuple[bytes, bytes]]) : the request headers
"""

type: Literal['http.response.start']
status: int
headers: list[tuple[bytes, bytes]]


class HttpBodyMessage(TypedDict):
"""
Summary
Expand All @@ -36,10 +36,10 @@ class HttpBodyMessage(TypedDict):
type (Literal['http.response.body']) : the type of the message
body (bytes) : the message body
"""

type: Literal['http.response.body']
body: bytes


class Scope(TypedDict):
"""
Summary
Expand All @@ -62,6 +62,7 @@ class Scope(TypedDict):
query_string (bytes) : the query string parameters
app (Starlette) : the application object
"""

type: Literal['http', 'websocket', 'lifespan']
asgi: Mapping[str, str]
http_version: str
Expand All @@ -76,7 +77,6 @@ class Scope(TypedDict):
query_string: bytes
app: Starlette


type Message = HttpStartMessage | HttpBodyMessage
type Send = Callable[[Message], Awaitable[None]]
type Receive = Callable[[], Awaitable[Message]]
Expand Down
12 changes: 2 additions & 10 deletions tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,13 @@ def client() -> Generator[TestClient, None, None]:

def test_transcribe_srt(client: TestClient):
with open('tests/test.mp3', 'rb') as file:
response = client.post(
'/v1/transcribe',
files={ 'request': file },
params={ 'caption_format': 'srt' }
).json()
response = client.post('/v1/transcribe', files={'request': file}, params={'caption_format': 'srt'}).json()

assert response['result'] == '1\n00:00:00,000 --> 00:00:01,720\nHello there. My name is Bella.'


def test_transcribe_vtt(client: TestClient):
with open('tests/test.mp3', 'rb') as file:
response = client.post(
'/v1/transcribe',
files={ 'request': file },
params={ 'caption_format': 'vtt' }
).json()
response = client.post('/v1/transcribe', files={'request': file}, params={'caption_format': 'vtt'}).json()

assert response['result'] == 'WEBVTT\n\n00:00:00.000 --> 00:00:01.720\nHello there. My name is Bella.'

0 comments on commit 7700365

Please sign in to comment.