+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Hugging Face Model
+
+See [automatic-speech-recognition](https://huggingface.co/tasks/automatic-speech-recognition)
+"""
+
+import pysubs2
+from pysubs2 import SSAFile, SSAEvent
+from subsai.models.abstract_model import AbstractModel
+from subsai.utils import _load_config, get_available_devices
+
+from transformers import pipeline
+
+
+devices = get_available_devices()
+
+class HuggingFaceModel(AbstractModel):
+    model_name = 'HuggingFaceModel'
+    config_schema = {
+        # load model config
+        'model_id': {
+            'type': str,
+            'description': 'The model id from the Hugging Face Hub.',
+            'options': None,
+            'default': 'openai/whisper-tiny'
+        },
+        'device': {
+            'type': list,
+            'description': 'PyTorch device',
+            'options': devices,
+            'default': devices[0]
+        },
+        'segment_type': {
+            'type': list,
+            'description': 'Sentence-level or word-level timestamps',
+            'options': ['sentence', 'word'],
+            'default': 'sentence'
+        },
+        'chunk_length_s': {
+            'type': float,
+            'description': 'Length in seconds of each audio chunk for long-form transcription. '
+                           'If `chunk_length_s = 0`, chunking is disabled.',
+            'options': None,
+            'default': 30
+        }
+    }
+
+    def __init__(self, model_config):
+        super(HuggingFaceModel, self).__init__(model_config=model_config,
+                                               model_name=self.model_name)
+        # config
+        self._model_id = _load_config('model_id', model_config, self.config_schema)
+        self._device = _load_config('device', model_config, self.config_schema)
+        self.segment_type = _load_config('segment_type', model_config, self.config_schema)
+        self._chunk_length_s = _load_config('chunk_length_s', model_config, self.config_schema)
+
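+        # build the ASR pipeline; weights for `model_id` are fetched from the Hub
+        # (or read from the local cache) and placed on the configured device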
+        self.model = pipeline(
+            "automatic-speech-recognition",
+            model=self._model_id,
+            device=self._device,
+        )
+
+    def transcribe(self, media_file):
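+        # return_timestamps=True yields sentence/segment-level timestamps,
+        # while 'word' yields word-level timestamps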
+        results = self.model(
+            media_file,
+            chunk_length_s=self._chunk_length_s,
+            return_timestamps=True if self.segment_type == 'sentence' else 'word',
+        )
+        subs = SSAFile()
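+        # each chunk is of the form {'text': ..., 'timestamp': (start_s, end_s)};
+        # pysubs2.make_time converts the second-based timestamps to milliseconds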
+        for chunk in results['chunks']:
+            event = SSAEvent(start=pysubs2.make_time(s=chunk['timestamp'][0]),
+                             end=pysubs2.make_time(s=chunk['timestamp'][1]))
+            event.plaintext = chunk['text']
+            subs.append(event)
+        return subs
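+
+
+# Minimal usage sketch (an illustration, not part of the library's public docs):
+# it assumes a local audio file 'audio.mp3'; the config keys mirror `config_schema`
+# above, and keys omitted from the dict are expected to fall back to their defaults.
+#
+#   model = HuggingFaceModel({'model_id': 'openai/whisper-tiny',
+#                             'device': 'cpu',
+#                             'segment_type': 'sentence'})
+#   subs = model.transcribe('audio.mp3')
+#   subs.save('audio.srt')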