diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7cc26fd..0fad0aa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,4 @@ name: CI - on: push: branches: @@ -9,20 +8,16 @@ on: - main schedule: - cron: "0 6 * * 1/2" # Every other day 6AM UTC - concurrency: group: ci-${{ github.head_ref || github.run_id }} cancel-in-progress: true - env: LINES: 200 COLUMNS: 200 - # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun defaults: run: shell: bash --noprofile --norc -exo pipefail {0} - jobs: unit-tests: strategy: @@ -32,12 +27,10 @@ jobs: python-version: ["3.8"] exclude: - os: "windows-latest" - if: "!github.event.repository.fork" # Don't run on fork repository name: python${{ matrix.python-version }} integration tests (${{ matrix.os }}) runs-on: ${{ matrix.os }} timeout-minutes: 90 - steps: - name: Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/scheduled-jobs.yml b/.github/workflows/scheduled-jobs.yml index e2e4c72..567831e 100644 --- a/.github/workflows/scheduled-jobs.yml +++ b/.github/workflows/scheduled-jobs.yml @@ -39,7 +39,6 @@ jobs: git submodule update --remote git add extern && git commit -S -m "cron: update submodules to latest commits [generated]" || true git push - run-formatter: if: github.repository_owner == 'aarnphm' runs-on: ubuntu-latest diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 1097da8..52bc747 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -1,9 +1,7 @@ name: style-check - concurrency: group: style-check-${{ github.head_ref || github.run_id }} cancel-in-progress: true - on: push: branches: @@ -13,16 +11,13 @@ on: - main schedule: - cron: "0 0 * * 1/2" # Every other day 12AM UTC - env: LINES: 200 COLUMNS: 200 - # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun defaults: run: shell: bash --noprofile --norc -exo pipefail {0} - jobs: lint: runs-on: ubuntu-latest @@ -57,4 +52,5 @@ jobs: run: npm install -g npm@^7 pyright - name: Type check if: ${{ github.event_name == 'pull_request' }} - run: git diff --name-only --diff-filter=AM "origin/$GITHUB_BASE_REF" -z -- '*.py{,i}' | xargs -0 --no-run-if-empty pyright + run: git diff --name-only --diff-filter=AM "origin/$GITHUB_BASE_REF" -z -- + '*.py{,i}' | xargs -0 --no-run-if-empty pyright diff --git a/.github/workflows/update-nixpkgs.yml b/.github/workflows/update-nixpkgs.yml index cd08fe4..59c476c 100644 --- a/.github/workflows/update-nixpkgs.yml +++ b/.github/workflows/update-nixpkgs.yml @@ -4,7 +4,6 @@ on: schedule: # run this every day at 12:00AM UTC - cron: "0 0 1/7 * *" - jobs: niv-updater: name: "Create PR for niv-managed dependencies" diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 3758286..a7cbf06 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -1,5 +1,4 @@ name: wheels - on: workflow_dispatch: push: @@ -8,20 +7,16 @@ on: pull_request: branches: - main - concurrency: group: wheels-${{ github.head_ref || github.run_id }} cancel-in-progress: true - env: LINES: 200 COLUMNS: 200 - # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun defaults: run: shell: bash --noprofile --norc -exo pipefail {0} - jobs: build-sdist: name: Build source distribution @@ -43,19 +38,63 @@ jobs: pip install build python -m build --sdist + - name: Test built sdist + if: ${{ github.event_name == 'pull_request' }} + run: | + python -m venv venv + source venv/bin/activate + pip install dist/*.tar.gz && python -c "import whispercpp as w;print(dir(w.api)); print(dir(w.audio));" - name: Upload to PyPI - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} run: | pip install twine twine upload dist/* --repository pypi --verbose - + build-arm-wheels: + name: Build ARM wheels for ${{ matrix.python[1] }}-${{ matrix.platform[0] }} + runs-on: ${{ matrix.platform[1] }} + if: ${{ failure() }} # Disable this for now. We will tackle this later. + timeout-minutes: 90 + strategy: + fail-fast: false + matrix: + python: + - ["cp38", "3.8"] + - ["cp39", "3.9"] + - ["cp310", "3.10"] + - ["cp311", "3.11"] + platform: + - [manylinux_aarch64, ubuntu-latest] + - [macosx_arm64, macos-latest] + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 # fetch all tags and branches + - name: Setup CI + uses: ./.github/actions/setup-repo + with: + python-version: ${{ matrix.python[1] }} + - name: Set up QEMU + if: runner.os == 'Linux' + uses: docker/setup-qemu-action@v2 + with: + platforms: all + - name: Build wheels + uses: pypa/cibuildwheel@v2.12.0 + env: + CIBW_BUILD: ${{ matrix.python[0] }}-${{ matrix.platform[0] }} + - uses: actions/upload-artifact@v3 + with: + name: ${{ matrix.python[0] }}-${{ startsWith(matrix.platform[0], 'macosx') + && 'macosx' || matrix.platform[0] }} + path: ./wheelhouse/*.whl build-wheel: name: Build wheels for python${{ matrix.python-version }} (${{ matrix.os }}) runs-on: ${{ matrix.os }} - timeout-minutes: 20 + timeout-minutes: 90 if: github.repository_owner == 'aarnphm' # Don't run on fork repository strategy: fail-fast: false @@ -65,7 +104,6 @@ jobs: exclude: - os: "windows-latest" - python-version: "3.11.2" - steps: - name: Checkout uses: actions/checkout@v3 @@ -80,14 +118,17 @@ jobs: uses: docker/setup-qemu-action@v2 with: platforms: all - - name: Set up Clang [Linux] - if: runner.os == 'Linux' - uses: egor-tensin/setup-clang@v1 - name: Running update requirements run: bazel run pypi_update - name: Building wheels if: github.event_name == 'pull_request' run: bazel build whispercpp_wheel + - name: Test built wheel + if: ${{ github.event_name == 'pull_request' }} + run: | + python -m venv venv + source venv/bin/activate + pip install $(bazel info bazel-bin)/*.whl && python -c "import whispercpp as w;print(dir(w.api)); print(dir(w.audio));" - name: Retrieving versions if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') id: get-info diff --git a/.gitignore b/.gitignore index f6af5dc..b3f9e9b 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # bazel generated files bazel-* + +# bun lock file +bun.lockb diff --git a/BUILD.bazel b/BUILD.bazel index f1a3efb..5b1f12b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,5 +1,5 @@ load("@rules_python//python:versions.bzl", "gen_python_config_settings") -load("@com_github_bazelbuild_buildtools//buildifier:def.bzl", "buildifier") +load("@com_github_bazelbuild_buildtools//buildifier:def.bzl", "buildifier", "buildifier_test") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@rules_cc//cc:defs.bzl", "cc_library") load("@bazel_skylib//rules:write_file.bzl", "write_file") @@ -19,10 +19,24 @@ exports_files([ "LICENSE", "README.md", "README.rst", + "yarn.lock", ]) buildifier( - name = "buildifier", + name = "buildfmt", +) + +buildifier_test( + name = "buildcheck", + srcs = glob([ + "**/*.bzl", + "**/*.bazel", + ]), +) + +alias( + name = "pyright", + actual = "@npm//:node_modules/pyright/index.js", ) alias( @@ -49,7 +63,7 @@ COPTS = [ }) cc_library( - name = "audio", + name = "audio_lib", srcs = ["//src/whispercpp:audio.cc"], hdrs = [ "//src/whispercpp:audio.h", @@ -67,7 +81,7 @@ cc_library( ], }), deps = [ - ":context", + ":context_lib", "@com_github_ggerganov_whisper//:common", "@com_github_libsdl_sdl2//:SDL", "@com_github_libsdl_sdl2//:include", @@ -76,7 +90,7 @@ cc_library( ) cc_library( - name = "context", + name = "context_lib", srcs = [ "//src/whispercpp:context.cc", "//src/whispercpp:context.h", @@ -88,27 +102,40 @@ cc_library( copts = COPTS, defines = ["BAZEL_BUILD"], deps = [ + "@com_github_ggerganov_whisper//:common", "@com_github_ggerganov_whisper//:whisper", "@pybind11", ], ) pybind_extension( - name = "api", + name = "audio_cpp2py_export", srcs = [ - "//src/whispercpp:api_export.cc", - "//src/whispercpp:api_export.h", + "//src/whispercpp:audio.cc", "//src/whispercpp:audio.h", "//src/whispercpp:context.h", - "@com_github_ggerganov_whisper//:examples/common.h", "@com_github_ggerganov_whisper//:whisper.h", ], copts = COPTS, defines = ["BAZEL_BUILD"], deps = [ - ":audio", - ":context", + ":audio_lib", + ":context_lib", + ], +) + +pybind_extension( + name = "api_cpp2py_export", + srcs = [ + "//src/whispercpp:api_cpp2py_export.cc", + "//src/whispercpp:api_cpp2py_export.h", + "//src/whispercpp:context.h", + "@com_github_ggerganov_whisper//:examples/common.h", + "@com_github_ggerganov_whisper//:whisper.h", ], + copts = COPTS, + defines = ["BAZEL_BUILD"], + deps = [":context_lib"], ) write_file( @@ -117,14 +144,18 @@ write_file( content = [ "#!/usr/bin/env bash", "cd $BUILD_WORKSPACE_DIRECTORY", - "cp -fv bazel-bin/api.so src/whispercpp/api.so", + "cp -fv bazel-bin/api_cpp2py_export.so src/whispercpp/api_cpp2py_export.so", + "cp -fv bazel-bin/audio_cpp2py_export.so src/whispercpp/audio_cpp2py_export.so", ], ) sh_binary( name = "extensions", srcs = [":gen_extensions"], - data = [":api.so"], + data = [ + ":api_cpp2py_export.so", + ":audio_cpp2py_export.so", + ], ) # public exports @@ -195,8 +226,5 @@ py_wheel( ":ci": "{BUILD_EMBED_LABEL}", }), visibility = ["//:__subpackages__"], - deps = [ - ":api.so", - ":whispercpp_pkg", - ], + deps = [":whispercpp_pkg"], ) diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 0d91375..c459d71 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -9,12 +9,27 @@ Start new shell: nix-shell ``` -We are using [bazel](https://bazel.build/) as a build system. I also provided a `tools/bazel` script to make it easier to use. +We are using [bazel](https://bazel.build/) as a build system. I also provided a +`tools/bazel` script to make it easier to use. Build the extension: ```bash -./tools/bazel run //:extensions +./tools/bazel run extensions +``` + +To run all format, it is most convenient to use treefmt: + +```bash +nix-shell --command treefmt +``` + +Otherwise run `black`, `isort`, and `ruff` + +Whispercpp also use `pyright` for Python type check. To run it do: + +```bash +./tools/bazel run //:pyright ``` ### Testing @@ -22,7 +37,9 @@ Build the extension: Run tests: ```bash -./tools/bazel test //tests/... +./tools/bazel test tests/... examples/... ``` -> NOTE: Make sure to include the `extern/whispercpp`, `extern/pybind11/include`, and ```$(python3-config --prefix)/include/python3``` in your `CPLUS_INCLUDE_PATH` so that `clangd` can find the headers in your editor. +> NOTE: Make sure to include the `extern/whispercpp`, `extern/pybind11/include`, +> and `$(python3-config --prefix)/include/python3` in your `CPLUS_INCLUDE_PATH` +> so that `clangd` can find the headers in your editor. diff --git a/Makefile b/Makefile index 79490ba..4c72f59 100644 --- a/Makefile +++ b/Makefile @@ -29,13 +29,13 @@ endif context.o: src/whispercpp/context.cc src/whispercpp/context.h $(CXX) $(CXXFLAGS) -o src/whispercpp/context.o -c src/whispercpp/context.cc -api_export.o: src/whispercpp/api_export.cc - $(CXX) $(CXXFLAGS) -o src/whispercpp/api_export.o -c src/whispercpp/api_export.cc +api_cpp2py_export.o: src/whispercpp/api_cpp2py_export.cc + $(CXX) $(CXXFLAGS) -o src/whispercpp/api_cpp2py_export.o -c src/whispercpp/api_cpp2py_export.cc -api: api_export.o context.o +api: api_cpp2py_export.o context.o @echo "Building pybind11 extension..." @cd ./extern/whispercpp && $(MAKE) ggml.o whisper.o && mv whisper.o ggml.o ../../src/whispercpp/ - $(CXX) $(CXXFLAGS) $(EXTRA_CXXFLAGS) -shared -o src/whispercpp/api.so src/whispercpp/*.o + $(CXX) $(CXXFLAGS) $(EXTRA_CXXFLAGS) -shared -o src/whispercpp/api_cpp2py_export.so src/whispercpp/*.o clean: rm -rf **/*.o **/*.so diff --git a/WORKSPACE b/WORKSPACE index 842d984..1977f3e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -4,23 +4,9 @@ load("//rules:deps.bzl", "internal_deps") internal_deps() -load("@com_github_bentoml_plugins//rules:deps.bzl", "plugins_dependencies") +load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies") -plugins_dependencies() - -# NOTE: external users wish to use BentoML workspace setup -# should always be loaded in this order. -load("@com_github_bentoml_plugins//rules:workspace0.bzl", "workspace0") - -workspace0() - -load("@com_github_bentoml_plugins//rules:workspace1.bzl", "workspace1") - -workspace1() - -load("@com_github_bentoml_plugins//rules:workspace2.bzl", "workspace2") - -workspace2() +rules_foreign_cc_dependencies() load("@com_grail_bazel_toolchain//toolchain:deps.bzl", "bazel_toolchain_dependencies") @@ -58,6 +44,14 @@ load("@llvm_toolchain//:toolchains.bzl", "llvm_register_toolchains") llvm_register_toolchains() +load("@rules_python//python:repositories.bzl", "py_repositories", "python_register_multi_toolchains") + +py_repositories() + +load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") + +pip_install_dependencies() + load("@rules_python//python:pip.bzl", "pip_parse") pip_parse( diff --git a/examples/bentoml/service.py b/examples/bentoml/service.py index d352b12..3f0d86f 100644 --- a/examples/bentoml/service.py +++ b/examples/bentoml/service.py @@ -36,6 +36,10 @@ def transcribe_file(self, p: str): raise FileNotFoundError(resolved) return self.model.transcribe_from_file(resolved) + @bentoml.Runnable.method(batchable=False) + def stream(self): + self.model.stream_transcribe() + cpp_runner = bentoml.Runner(WhisperCppRunnable, max_batch_size=30) diff --git a/examples/stream/stream.py b/examples/stream/stream.py index e99c639..a2dc88e 100644 --- a/examples/stream/stream.py +++ b/examples/stream/stream.py @@ -53,7 +53,7 @@ def main(**kwargs: t.Any): args = parser.parse_args() if args.list_audio_devices: - w.api.AudioCapture.list_available_devices() + w.utils.available_audio_devices() sys.exit(0) main(**vars(args)) diff --git a/package.json b/package.json new file mode 100644 index 0000000..0aaf605 --- /dev/null +++ b/package.json @@ -0,0 +1,10 @@ +{ + "name": "whispercpp-node-tools", + "version": "0.0.0", + "description": "JS tooling for whispercpp", + "author": "Aaron Pham", + "license": "Apache-2.0", + "dependencies": { + "pyright": "^1.1.296" + } +} diff --git a/pyproject.toml b/pyproject.toml index 3baa2f4..9b10cc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,29 @@ skip_glob = [ [tool.pyright] pythonVersion = "3.11" -exclude = ["bazel-*", "extern", "venv"] +exclude = [ + "bazel-*", + "extern", + "venv", + "node_modules", + "nix", + "setup.py", + "examples/bentoml/locustfile.py", +] typeCheckingMode = "strict" analysis.useLibraryCodeForTypes = true enableTypeIgnoreComments = true + +[tool.cibuildwheel] +test-requires = "pytest" +test-command = "pytest {project}/tests" +build-verbosity = "3" +environment = { CC = "g++", CXX = "clang++" } + +[tool.cibuildwheel.macos] +archs = ["x86_64", "arm64"] +environment = { RUNNER_OS = "MacOS" } + +[tool.cibuildwheel.linux] +archs = ["x86_64", "aarch64"] +environment = { RUNNER_OS = "Linux" } diff --git a/rules/deps.bzl b/rules/deps.bzl index 1ba00ea..1fd4b51 100644 --- a/rules/deps.bzl +++ b/rules/deps.bzl @@ -18,6 +18,12 @@ def internal_deps(): commit = "9da166c0d5af5543f6084bf3ae5223ea19f0e7ea", shallow_since = "1678069830 -0800", ) + # NOTE: uncomment the below line for debugging rules + # change the path to absolute path + # native.local_repository( + # name = "com_github_bentoml_plugins", + # path = "/Users/aarnphm/workspace/bentoml/ecosystem/", + # ) maybe( http_archive, diff --git a/setup.py b/setup.py index 2d5b1b6..676b8f6 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def update_submodules(directory: str): def compile_ext(): wd = path.dirname(path.abspath(__file__)) - if not path.exists(path.join(wd, "src", "whispercpp", "api.so")): + if not path.exists(path.join(wd, "src", "whispercpp", "api_cpp2py_export.so")): update_submodules(wd) print("Building pybind11 extension...") bazel_script = Path(wd) / "tools" / "bazel" diff --git a/src/whispercpp/BUILD b/src/whispercpp/BUILD index 0399941..4c6fd47 100644 --- a/src/whispercpp/BUILD +++ b/src/whispercpp/BUILD @@ -7,5 +7,8 @@ exports_files(glob(["*.cc"]) + glob(["*.h"]) + glob(["*.py"]) + glob(["*.pyi"])) py_library( name = "whispercpp_lib", srcs = glob(["*.py"]), - data = ["//:api.so"] + glob(["*.pyi"]), + data = [ + "//:api_cpp2py_export.so", + "//:audio_cpp2py_export.so", + ] + glob(["*.pyi"]), ) diff --git a/src/whispercpp/__init__.py b/src/whispercpp/__init__.py index 432a642..64a68f4 100644 --- a/src/whispercpp/__init__.py +++ b/src/whispercpp/__init__.py @@ -1,3 +1,22 @@ +""" +whispercpp: A Pybind11-based Python wrapper for whisper.cpp, a C++ implementation of OpenAI Whisper. + +Install with pip: + +.. code-block:: bash + + pip install whispercpp + +The binding provides a ``Whisper`` class: + +.. code-block:: python + + from whispercpp import Whisper + + w = Whisper.from_pretrained("tiny.en") + print(w.transcribe_from_file("test.wav")) +""" + from __future__ import annotations import typing as t @@ -9,14 +28,31 @@ import numpy as np from numpy.typing import NDArray - from . import api + # NOTE: We can safely ignore the following imports + # because they are only used for type checking. + from . import api # type: ignore + from . import audio # type: ignore else: - api = utils.LazyLoader("api", globals(), "whispercpp.api") + api = utils.LazyLoader("api", globals(), "whispercpp.api_cpp2py_export") + audio = utils.LazyLoader( + "audio", + globals(), + "whispercpp.audio_cpp2py_export", + exc_msg="Failed to import 'audio' extensions. Try to install whispercpp from source.", + ) @dataclass class Whisper: + """ + A wrapper class for Whisper C++ API. + + This class should only be instantiated using ``from_pretrained()``. + ``__init__()`` will raise a ``RuntimeError``. + """ + def __init__(self, *args: t.Any, **kwargs: t.Any): + """Empty init method. This will raise a ``RuntimeError``.""" raise RuntimeError( "Using '__init__()' is not allowed. Use 'from_pretrained()' instead." ) @@ -29,6 +65,21 @@ def __init__(self, *args: t.Any, **kwargs: t.Any): @staticmethod def from_pretrained(model_name: str, basedir: str | None = None): + """Load a preconverted model from a given model name. + + Currently it doesn't support custom preconverted models yet. PRs are welcome. + + Args: + model_name (str): Name of the preconverted model. + basedir (str, optional): Base directory to store the model. Defaults to None. + Default will be "$XDG_DATA_HOME/whispercpp" for directory. + + Returns: + A ``Whisper`` object. + + Raises: + RuntimeError: If the given model name is not a valid preconverted model. + """ if model_name not in utils.MODELS_URL: raise RuntimeError( f"'{model_name}' is not a valid preconverted model. Choose one of {list(utils.MODELS_URL)}" @@ -47,6 +98,16 @@ def from_pretrained(model_name: str, basedir: str | None = None): return _ref def transcribe(self, data: NDArray[np.float32], num_proc: int = 1): + """Transcribe audio from a given numpy array. + + Args: + data (np.ndarray): Audio data as a numpy array. + num_proc (int, optional): Number of processes to use for transcription. Defaults to 1. + Note that if num_proc > 1, transcription accuracy may decrease. + + Returns: + Transcribed text. + """ self.context.full_parallel(self.params, data, num_proc) return "".join( [ @@ -56,6 +117,20 @@ def transcribe(self, data: NDArray[np.float32], num_proc: int = 1): ) def transcribe_from_file(self, filename: str, num_proc: int = 1): + """Transcribe audio from a given file. This function uses a simple C++ implementation for loading audio file. + + Currently only WAV files are supported. PRs are welcome for other format supports. + + See ``Whisper.transcribe()`` for more details. + + Args: + filename (str): Path to the audio file. + num_proc (int, optional): Number of processes to use for transcription. Defaults to 1. + Note that if num_proc > 1, transcription accuracy may decrease. + + Returns: + Transcribed text. + """ return self.transcribe(api.load_wav_file(filename).mono, num_proc) def stream_transcribe( @@ -82,7 +157,7 @@ def stream_transcribe( if sample_rate is None: sample_rate = api.SAMPLE_RATE - ac = api.AudioCapture(length_ms) + ac = audio.AudioCapture(length_ms) if not ac.init_device(device_id, sample_rate): raise RuntimeError("Failed to initialize audio capture device.") @@ -100,4 +175,4 @@ def stream_transcribe( return ac.transcript -__all__ = ["Whisper", "api", "utils"] +__all__ = ["Whisper", "api", "utils", "audio"] diff --git a/src/whispercpp/__init__.pyi b/src/whispercpp/__init__.pyi index 06bb138..c86776f 100644 --- a/src/whispercpp/__init__.pyi +++ b/src/whispercpp/__init__.pyi @@ -5,6 +5,7 @@ from typing import Generator from typing import TYPE_CHECKING from . import api as api +from . import audio as audio from . import utils as utils if TYPE_CHECKING: diff --git a/src/whispercpp/api.pyi b/src/whispercpp/api.pyi index 6f6184f..236cc1b 100644 --- a/src/whispercpp/api.pyi +++ b/src/whispercpp/api.pyi @@ -114,27 +114,3 @@ class TokenData: def load_wav_file(filename: str) -> WavFile: ... def sdl_poll_events() -> bool: ... - -class AudioCapture: - transcript: list[str] - def __init__(self, length_ms: int) -> None: ... - @t.overload - def init_device(self) -> bool: ... - @t.overload - def init_device(self, device_id: int) -> bool: ... - @t.overload - def init_device(self, device_id: int, sample_rate: int) -> bool: ... - @staticmethod - def list_available_devices() -> list[int]: ... - @t.overload - def stream_transcribe( - self, context: Context, params: Params - ) -> t.Iterator[str]: ... - @t.overload - def stream_transcribe( - self, context: Context, params: Params, step_ms: int = ... - ) -> t.Iterator[str]: ... - def resume(self) -> bool: ... - def pause(self) -> bool: ... - def clear(self) -> bool: ... - def retrieve_audio(self, ms: int, audio: NDArray[np.float32]) -> None: ... diff --git a/src/whispercpp/api_export.cc b/src/whispercpp/api_cpp2py_export.cc similarity index 98% rename from src/whispercpp/api_export.cc rename to src/whispercpp/api_cpp2py_export.cc index df3e026..59b6efa 100644 --- a/src/whispercpp/api_export.cc +++ b/src/whispercpp/api_cpp2py_export.cc @@ -1,4 +1,4 @@ -#include "api_export.h" +#include "api_cpp2py_export.h" #include #include #include @@ -20,7 +20,7 @@ typedef std::function NewSegmentCallback; namespace whisper { -PYBIND11_MODULE(api, m) { +PYBIND11_MODULE(api_cpp2py_export, m) { m.doc() = "Python interface for whisper.cpp"; // NOTE: default attributes @@ -33,9 +33,6 @@ PYBIND11_MODULE(api, m) { // NOTE: export Context API ExportContextApi(m); - // NOTE: export AudioCapture API - ExportAudioApi(m); - m.def("load_wav_file", &WavFileWrapper::load_wav_file, "filename"_a, py::return_value_policy::reference); diff --git a/src/whispercpp/api_cpp2py_export.h b/src/whispercpp/api_cpp2py_export.h new file mode 100644 index 0000000..2951ff6 --- /dev/null +++ b/src/whispercpp/api_cpp2py_export.h @@ -0,0 +1,73 @@ +#pragma once + +#ifdef BAZEL_BUILD +#include "context.h" +#include "examples/common.h" +#include "pybind11/functional.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#else +#include "common.h" +#include "context.h" +#include "pybind11/functional.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#endif + +namespace py = pybind11; + +namespace whisper { +// std::make_unique for C++11 +// https://stackoverflow.com/a/17902439/8643197 +template struct _Unique_if { + typedef std::unique_ptr _Single_object; +}; + +template struct _Unique_if { + typedef std::unique_ptr _Unknown_bound; +}; + +template struct _Unique_if { + typedef void _Known_bound; +}; + +template +typename _Unique_if::_Single_object make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename _Unique_if::_Unknown_bound make_unique(size_t n) { + typedef typename std::remove_extent::type U; + return std::unique_ptr(new U[n]()); +} + +template +typename _Unique_if::_Known_bound make_unique(Args &&...) = delete; + +// Some black magic to make zero-copy numpy array +// See https://github.com/pybind/pybind11/issues/1042#issuecomment-642215028 +template +inline py::array_t as_pyarray(Sequence &&seq) { + auto size = seq.size(); + auto data = seq.data(); + std::unique_ptr seq_ptr = + whisper::make_unique(std::move(seq)); + auto capsule = py::capsule(seq_ptr.get(), [](void *p) { + std::unique_ptr(reinterpret_cast(p)); + }); + seq_ptr.release(); + return py::array(size, data, capsule); +} +} // namespace whisper + +struct WavFileWrapper { + py::array_t mono; + std::vector> stereo; + + WavFileWrapper(std::vector *mono, + std::vector> *stereo) + : mono(whisper::as_pyarray(std::move(*mono))), stereo(*stereo){}; + + static WavFileWrapper load_wav_file(const char *filename); +}; diff --git a/src/whispercpp/api_export.h b/src/whispercpp/api_export.h deleted file mode 100644 index f4ea809..0000000 --- a/src/whispercpp/api_export.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -#include "audio.h" -#ifdef BAZEL_BUILD -#include -#else -#include -#endif - -namespace py = pybind11; - -struct WavFileWrapper { - py::array_t mono; - std::vector> stereo; - - WavFileWrapper(std::vector *mono, - std::vector> *stereo) - : mono(whisper::as_pyarray(std::move(*mono))), stereo(*stereo){}; - - static WavFileWrapper load_wav_file(const char *filename); -}; diff --git a/src/whispercpp/audio.cc b/src/whispercpp/audio.cc index 53a8679..1e2cdc8 100644 --- a/src/whispercpp/audio.cc +++ b/src/whispercpp/audio.cc @@ -7,6 +7,11 @@ using namespace pybind11::literals; namespace whisper { +PYBIND11_MODULE(audio_cpp2py_export, m) { + m.doc() = "Experimental: Audio capture API"; + ExportAudioApi(m); +} + AudioCapture::~AudioCapture() { if (m_dev_id) { SDL_CloseAudioDevice(m_dev_id); diff --git a/src/whispercpp/audio.h b/src/whispercpp/audio.h index e7ea849..7472a9f 100644 --- a/src/whispercpp/audio.h +++ b/src/whispercpp/audio.h @@ -6,14 +6,12 @@ #include "context.h" #include "examples/common.h" #include "pybind11/numpy.h" -#include "pybind11/pybind11.h" #else #include "SDL.h" #include "SDL_audio.h" -#include "common.h" #include "context.h" +#include "examples/common.h" #include "pybind11/numpy.h" -#include "pybind11/pybind11.h" #endif #include @@ -26,50 +24,6 @@ namespace py = pybind11; namespace whisper { - -// std::make_unique for C++11 -// https://stackoverflow.com/a/17902439/8643197 -template struct _Unique_if { - typedef std::unique_ptr _Single_object; -}; - -template struct _Unique_if { - typedef std::unique_ptr _Unknown_bound; -}; - -template struct _Unique_if { - typedef void _Known_bound; -}; - -template -typename _Unique_if::_Single_object make_unique(Args &&...args) { - return std::unique_ptr(new T(std::forward(args)...)); -} - -template -typename _Unique_if::_Unknown_bound make_unique(size_t n) { - typedef typename std::remove_extent::type U; - return std::unique_ptr(new U[n]()); -} - -template -typename _Unique_if::_Known_bound make_unique(Args &&...) = delete; - -// Some black magic to make zero-copy numpy array -// See https://github.com/pybind/pybind11/issues/1042#issuecomment-642215028 -template -inline py::array_t as_pyarray(Sequence &&seq) { - auto size = seq.size(); - auto data = seq.data(); - std::unique_ptr seq_ptr = - whisper::make_unique(std::move(seq)); - auto capsule = py::capsule(seq_ptr.get(), [](void *p) { - std::unique_ptr(reinterpret_cast(p)); - }); - seq_ptr.release(); - return py::array(size, data, capsule); -} - class AudioCapture { public: AudioCapture(int length_ms) { diff --git a/src/whispercpp/audio.pyi b/src/whispercpp/audio.pyi new file mode 100644 index 0000000..e1e9fe8 --- /dev/null +++ b/src/whispercpp/audio.pyi @@ -0,0 +1,33 @@ +from __future__ import annotations + +import typing as t + +import numpy as np +from numpy.typing import NDArray + +from .api import Params +from .api import Context + +class AudioCapture: + transcript: list[str] + def __init__(self, length_ms: int) -> None: ... + @t.overload + def init_device(self) -> bool: ... + @t.overload + def init_device(self, device_id: int) -> bool: ... + @t.overload + def init_device(self, device_id: int, sample_rate: int) -> bool: ... + @staticmethod + def list_available_devices() -> list[int]: ... + @t.overload + def stream_transcribe( + self, context: Context, params: Params + ) -> t.Iterator[str]: ... + @t.overload + def stream_transcribe( + self, context: Context, params: Params, step_ms: int = ... + ) -> t.Iterator[str]: ... + def resume(self) -> bool: ... + def pause(self) -> bool: ... + def clear(self) -> bool: ... + def retrieve_audio(self, ms: int, audio: NDArray[np.float32]) -> None: ... diff --git a/src/whispercpp/utils.py b/src/whispercpp/utils.py index 7359ba4..9749f11 100644 --- a/src/whispercpp/utils.py +++ b/src/whispercpp/utils.py @@ -117,6 +117,6 @@ def __dir__(self) -> list[str]: @lru_cache(maxsize=1) def available_audio_devices() -> list[int]: """Returns a list of available audio devices on the system.""" - from whispercpp import api + from whispercpp import audio # type: ignore - return api.AudioCapture.list_available_devices() + return audio.AudioCapture.list_available_devices() diff --git a/tests/BUILD b/tests/BUILD index 79dead8..1c6965e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -5,7 +5,7 @@ py_test( name = "export", srcs = ["export_test.py"], data = [ - "//:api.so", + "//:api_cpp2py_export.so", "//samples:jfk.wav", ], deps = [ @@ -19,7 +19,7 @@ py_test( py_test( name = "utils", srcs = ["utils_test.py"], - data = ["//:api.so"], + data = ["//:api_cpp2py_export.so"], deps = [ "//src/whispercpp:whispercpp_lib", requirement("bazel-runfiles"), diff --git a/tests/export_test.py b/tests/export_test.py index a6b363b..33940db 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -6,17 +6,15 @@ import pytest as p -from whispercpp import api -from whispercpp import Whisper -from whispercpp.utils import LazyLoader +import whispercpp as w if t.TYPE_CHECKING: import numpy as np import ffmpeg from numpy.typing import NDArray else: - np = LazyLoader("np", globals(), "numpy") - ffmpeg = LazyLoader("ffmpeg", globals(), "ffmpeg") + np = w.utils.LazyLoader("np", globals(), "numpy") + ffmpeg = w.utils.LazyLoader("ffmpeg", globals(), "ffmpeg") ROOT = Path(__file__).parent.parent @@ -38,12 +36,12 @@ def preprocess(file: Path, sample_rate: int = 16000) -> NDArray[np.float32]: def test_invalid_models(): with p.raises(RuntimeError): - Whisper.from_pretrained("whisper_v0.1") + w.Whisper.from_pretrained("whisper_v0.1") def test_forbid_init(): with p.raises(RuntimeError): - Whisper() + w.Whisper() _EXPECTED = " And so my fellow Americans ask not what your country can do for you ask what you can do for your country" @@ -51,7 +49,7 @@ def test_forbid_init(): @p.mark.skipif(not s.which("ffmpeg"), reason="ffmpeg not found, skipping this test.") def test_from_pretrained(): - m = Whisper.from_pretrained("tiny.en") + m = w.Whisper.from_pretrained("tiny.en") assert _EXPECTED == m.transcribe(preprocess(ROOT / "samples" / "jfk.wav")) @@ -59,14 +57,14 @@ def test_from_pretrained(): def test_load_wav_file(): np.testing.assert_almost_equal( preprocess(ROOT / "samples" / "jfk.wav"), - api.load_wav_file( + w.api.load_wav_file( ROOT.joinpath("samples", "jfk.wav").resolve().__fspath__() ).mono, ) def test_transcribe_from_wav(): - m = Whisper.from_pretrained("tiny.en") + m = w.Whisper.from_pretrained("tiny.en") assert ( m.transcribe_from_file( ROOT.joinpath("samples", "jfk.wav").resolve().__fspath__() @@ -76,14 +74,14 @@ def test_transcribe_from_wav(): def test_callback(): - def handleNewSegment(context: api.Context, n_new: int, text: list[str]): + def handleNewSegment(context: w.api.Context, n_new: int, text: list[str]): segment = context.full_n_segments() - n_new while segment < context.full_n_segments(): text.append(context.full_get_segment_text(segment)) print(text) segment += 1 - m = Whisper.from_pretrained("tiny.en") + m = w.Whisper.from_pretrained("tiny.en") text = [] m.params.on_new_segment(handleNewSegment, text) diff --git a/yarn.lock b/yarn.lock new file mode 100644 index 0000000..45323ff --- /dev/null +++ b/yarn.lock @@ -0,0 +1,8 @@ +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + +pyright@^1.1.296: + version "1.1.296" + resolved "https://registry.yarnpkg.com/pyright/-/pyright-1.1.296.tgz#5bb4322b425d576e7a27882969123510d06023ca" + integrity sha512-T04flbRRbbzp37X4fdb8FedoavLYAY2pk5x3jxadnBXYssJn0WQCVWeIFryWGDe6v+a3iOrxbMt2EawNKkNVHQ==